From c09c4926a727d70c4ca096d9eb6d99a5896ea818 Mon Sep 17 00:00:00 2001
From: Parshva Daftari <89991302+parshvadaftari@users.noreply.github.com>
Date: Thu, 3 Oct 2024 21:30:46 +0530
Subject: [PATCH] [ Refactored ] embedding models and [ Update ] documentation
 for Gemini model (#1931)

---
 docs/components/embedders/models/gemini.mdx | 13 ++++---------
 mem0/embeddings/gemini.py                   |  9 ++++++---
 mem0/embeddings/huggingface.py              |  6 ++----
 mem0/embeddings/ollama.py                   |  6 ++----
 4 files changed, 14 insertions(+), 20 deletions(-)

diff --git a/docs/components/embedders/models/gemini.mdx b/docs/components/embedders/models/gemini.mdx
index 0913a8a7..125e3160 100644
--- a/docs/components/embedders/models/gemini.mdx
+++ b/docs/components/embedders/models/gemini.mdx
@@ -16,16 +16,9 @@ config = {
     "embedder": {
         "provider": "gemini",
         "config": {
-            "model": "models/text-embedding-004"
+            "model": "models/text-embedding-004",
         }
-    },
-    "vector_store": {
-        "provider": "qdrant",
-        "config": {
-            "collection_name": "test",
-            "embedding_model_dims": 768,
-        }
-    },
+    }
 }
 
 m = Memory.from_config(config)
@@ -39,3 +32,5 @@ Here are the parameters available for configuring Gemini embedder:
 | Parameter | Description | Default Value |
 | --- | --- | --- |
 | `model` | The name of the embedding model to use | `models/text-embedding-004` |
+| `embedding_dims` | Dimensions of the embedding model | `768` |
+| `api_key` | The Gemini API key | `None` |
diff --git a/mem0/embeddings/gemini.py b/mem0/embeddings/gemini.py
index 06efde83..7ef429a9 100644
--- a/mem0/embeddings/gemini.py
+++ b/mem0/embeddings/gemini.py
@@ -9,10 +9,13 @@ from mem0.embeddings.base import EmbeddingBase
 class GoogleGenAIEmbedding(EmbeddingBase):
     def __init__(self, config: Optional[BaseEmbedderConfig] = None):
         super().__init__(config)
-        if self.config.model is None:
-            self.config.model = "models/text-embedding-004"  # embedding-dim = 768
+
+        self.config.model = self.config.model or "models/text-embedding-004"
+        self.config.embedding_dims = self.config.embedding_dims or 768
 
-        genai.configure(api_key=self.config.api_key or os.getenv("GOOGLE_API_KEY"))
+        api_key = self.config.api_key or os.getenv("GOOGLE_API_KEY")
+
+        genai.configure(api_key=api_key)
 
     def embed(self, text):
         """
diff --git a/mem0/embeddings/huggingface.py b/mem0/embeddings/huggingface.py
index 56d6a072..d2bf5b82 100644
--- a/mem0/embeddings/huggingface.py
+++ b/mem0/embeddings/huggingface.py
@@ -10,13 +10,11 @@ class HuggingFaceEmbedding(EmbeddingBase):
     def __init__(self, config: Optional[BaseEmbedderConfig] = None):
         super().__init__(config)
 
-        if self.config.model is None:
-            self.config.model = "multi-qa-MiniLM-L6-cos-v1"
+        self.config.model = self.config.model or "multi-qa-MiniLM-L6-cos-v1"
 
         self.model = SentenceTransformer(self.config.model, **self.config.model_kwargs)
 
-        if self.config.embedding_dims is None:
-            self.config.embedding_dims = self.model.get_sentence_embedding_dimension()
+        self.config.embedding_dims = self.config.embedding_dims or self.model.get_sentence_embedding_dimension()
 
     def embed(self, text):
         """
diff --git a/mem0/embeddings/ollama.py b/mem0/embeddings/ollama.py
index ae00368e..30034365 100644
--- a/mem0/embeddings/ollama.py
+++ b/mem0/embeddings/ollama.py
@@ -25,10 +25,8 @@ class OllamaEmbedding(EmbeddingBase):
     def __init__(self, config: Optional[BaseEmbedderConfig] = None):
         super().__init__(config)
 
-        if not self.config.model:
-            self.config.model = "nomic-embed-text"
-        if not self.config.embedding_dims:
-            self.config.embedding_dims = 512
+        self.config.model = self.config.model or "nomic-embed-text"
+        self.config.embedding_dims = self.config.embedding_dims or 512
 
         self.client = Client(host=self.config.ollama_base_url)
         self._ensure_model_exists()