From b245309242b38891ee55873eb12a270ca6cb8f89 Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Mon, 12 Aug 2024 16:09:01 +0530 Subject: [PATCH] Add embedder docs and config changes (#1684) --- docs/components/embedders.mdx | 64 ++++++++++++++++++++++++++++ docs/mint.json | 4 ++ mem0/configs/embeddings/base.py | 14 ++++-- mem0/configs/llms/base.py | 14 +++--- mem0/configs/vector_stores/chroma.py | 17 ++++++-- mem0/configs/vector_stores/qdrant.py | 23 +++++++--- mem0/embeddings/ollama.py | 2 +- mem0/embeddings/openai.py | 10 ++--- mem0/llms/openai.py | 5 ++- mem0/memory/main.py | 1 + mem0/vector_stores/configs.py | 2 +- pyproject.toml | 2 +- 12 files changed, 130 insertions(+), 28 deletions(-) create mode 100644 docs/components/embedders.mdx diff --git a/docs/components/embedders.mdx b/docs/components/embedders.mdx new file mode 100644 index 00000000..0f7242fa --- /dev/null +++ b/docs/components/embedders.mdx @@ -0,0 +1,64 @@ +--- +title: 🧩 Embedding models +--- + +## Overview + +Mem0 offers support for various embedding models, allowing users to choose the one that best suits their needs. + + + + + + +> When using `Qdrant` as a vector database, ensure you update the `embedding_model_dims` to match the dimensions of the embedding model you are using. + +## OpenAI + +To use OpenAI embedding models, set the `OPENAI_API_KEY` environment variable. You can obtain the OpenAI API key from the [OpenAI Platform](https://platform.openai.com/account/api-keys). + +Example of how to select the desired embedding model: + +```python +import os +from mem0 import Memory + +os.environ["OPENAI_API_KEY"] = "your_api_key" + +config = { + "embedder": { + "provider": "openai", + "config": { + "model": "text-embedding-3-large" + } + } +} + +m = Memory.from_config(config) +m.add("I'm visiting Paris", user_id="john") +``` + +## Ollama + +You can use embedding models from Ollama to run Mem0 locally. 
+ +Here's how to select it: + +```python +import os +from mem0 import Memory + +os.environ["OPENAI_API_KEY"] = "your_api_key" + +config = { + "embedder": { + "provider": "ollama", + "config": { + "model": "mxbai-embed-large" + } + } +} + +m = Memory.from_config(config) +m.add("I'm visiting Paris", user_id="john") +``` diff --git a/docs/mint.json b/docs/mint.json index 6778b0a3..742e855e 100644 --- a/docs/mint.json +++ b/docs/mint.json @@ -69,6 +69,10 @@ "group": "Vector Database", "pages": ["components/vectordb"] }, + { + "group": "Embedding Models", + "pages": ["components/embedders"] + }, { "group": "Features", "pages": ["features/openai_compatibility"] diff --git a/mem0/configs/embeddings/base.py b/mem0/configs/embeddings/base.py index 3b504513..79789289 100644 --- a/mem0/configs/embeddings/base.py +++ b/mem0/configs/embeddings/base.py @@ -9,9 +9,12 @@ class BaseEmbedderConfig(ABC): def __init__( self, model: Optional[str] = None, + api_key: Optional[str] = None, embedding_dims: Optional[int] = None, + # Ollama specific - base_url: Optional[str] = None, + ollama_base_url: Optional[str] = None, + # Huggingface specific model_kwargs: Optional[dict] = None ): @@ -20,20 +23,23 @@ class BaseEmbedderConfig(ABC): :param model: Embedding model to use, defaults to None :type model: Optional[str], optional + :param api_key: API key to use, defaults to None + :type api_key: Optional[str], optional :param embedding_dims: The number of dimensions in the embedding, defaults to None :type embedding_dims: Optional[int], optional - :param base_url: Base URL for the Ollama API, defaults to None - :type base_url: Optional[str], optional + :param ollama_base_url: Base URL for the Ollama API, defaults to None + :type ollama_base_url: Optional[str], optional :param model_kwargs: key-value arguments for the huggingface embedding model, defaults a dict inside init :type model_kwargs: Optional[Dict[str, Any]], defaults a dict inside init """ self.model = model + self.api_key = 
api_key self.embedding_dims = embedding_dims # Ollama specific - self.base_url = base_url + self.ollama_base_url = ollama_base_url # Huggingface specific self.model_kwargs = model_kwargs or {} diff --git a/mem0/configs/llms/base.py b/mem0/configs/llms/base.py index 103c247d..71855b01 100644 --- a/mem0/configs/llms/base.py +++ b/mem0/configs/llms/base.py @@ -10,6 +10,7 @@ class BaseLlmConfig(ABC): self, model: Optional[str] = None, temperature: float = 0, + api_key: Optional[str] = None, max_tokens: int = 3000, top_p: float = 0, top_k: int = 1, @@ -32,6 +33,8 @@ class BaseLlmConfig(ABC): :param temperature: Controls the randomness of the model's output. Higher values (closer to 1) make output more random, lower values make it more deterministic, defaults to 0 :type temperature: float, optional + :param api_key: OpenAI API key to use, defaults to None + :type api_key: Optional[str], optional :param max_tokens: Controls how many tokens are generated, defaults to 3000 :type max_tokens: int, optional :param top_p: Controls the diversity of words. Higher values (closer to 1) make word selection more diverse, @@ -39,15 +42,15 @@ class BaseLlmConfig(ABC): :type top_p: float, optional :param top_k: Controls the diversity of words. 
Higher values make word selection more diverse, defaults to 0 :type top_k: int, optional - :param models: Controls the Openrouter models used, defaults to None + :param models: Openrouter models to use, defaults to None :type models: Optional[list[str]], optional - :param route: Controls the Openrouter route used, defaults to "fallback" + :param route: Openrouter route to be used, defaults to "fallback" :type route: Optional[str], optional - :param openrouter_base_url: Controls the Openrouter base URL used, defaults to "https://openrouter.ai/api/v1" + :param openrouter_base_url: Openrouter base URL to use, defaults to "https://openrouter.ai/api/v1" :type openrouter_base_url: Optional[str], optional - :param site_url: Controls the Openrouter site URL used, defaults to None + :param site_url: Openrouter site URL to use, defaults to None :type site_url: Optional[str], optional - :param app_name: Controls the Openrouter app name used, defaults to None + :param app_name: Openrouter app name to use, defaults to None :type app_name: Optional[str], optional :param ollama_base_url: The base URL of the LLM, defaults to None :type ollama_base_url: Optional[str], optional @@ -55,6 +58,7 @@ class BaseLlmConfig(ABC): self.model = model self.temperature = temperature + self.api_key = api_key self.max_tokens = max_tokens self.top_p = top_p self.top_k = top_k diff --git a/mem0/configs/vector_stores/chroma.py b/mem0/configs/vector_stores/chroma.py index edafb5c2..86a7ddcf 100644 --- a/mem0/configs/vector_stores/chroma.py +++ b/mem0/configs/vector_stores/chroma.py @@ -1,4 +1,4 @@ -from typing import Optional,ClassVar +from typing import Optional, ClassVar, Dict, Any from pydantic import BaseModel, Field, model_validator @@ -21,6 +21,17 @@ class ChromaDbConfig(BaseModel): if not path and not (host and port): raise ValueError("Either 'host' and 'port' or 'path' must be provided.") return values + + @model_validator(mode="before") + @classmethod + def validate_extra_fields(cls, 
values: Dict[str, Any]) -> Dict[str, Any]: + allowed_fields = set(cls.model_fields.keys()) + input_fields = set(values.keys()) + extra_fields = input_fields - allowed_fields + if extra_fields: + raise ValueError(f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}") + return values - class Config: - arbitrary_types_allowed = True \ No newline at end of file + model_config = { + "arbitrary_types_allowed": True, + } \ No newline at end of file diff --git a/mem0/configs/vector_stores/qdrant.py b/mem0/configs/vector_stores/qdrant.py index 6c40f108..8d716dc8 100644 --- a/mem0/configs/vector_stores/qdrant.py +++ b/mem0/configs/vector_stores/qdrant.py @@ -1,6 +1,5 @@ -from typing import Optional,ClassVar - from pydantic import BaseModel, Field, model_validator +from typing import Optional, ClassVar, Dict, Any class QdrantConfig(BaseModel): from qdrant_client import QdrantClient @@ -14,10 +13,11 @@ class QdrantConfig(BaseModel): path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database") url: Optional[str] = Field(None, description="Full URL for Qdrant server") api_key: Optional[str] = Field(None, description="API key for Qdrant server") - on_disk: Optional[bool]= Field(False, description="Enables persistant storage") + on_disk: Optional[bool] = Field(False, description="Enables persistent storage") @model_validator(mode="before") - def check_host_port_or_path(cls, values): + @classmethod + def check_host_port_or_path(cls, values: Dict[str, Any]) -> Dict[str, Any]: host, port, path, url, api_key = ( values.get("host"), values.get("port"), @@ -31,5 +31,16 @@ class QdrantConfig(BaseModel): ) return values - class Config: - arbitrary_types_allowed = True + @model_validator(mode="before") + @classmethod + def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]: + allowed_fields = set(cls.model_fields.keys()) + input_fields = set(values.keys()) + extra_fields = 
input_fields - allowed_fields + if extra_fields: + raise ValueError(f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}") + return values + + model_config = { + "arbitrary_types_allowed": True, + } \ No newline at end of file diff --git a/mem0/embeddings/ollama.py b/mem0/embeddings/ollama.py index 904d8506..30d74be8 100644 --- a/mem0/embeddings/ollama.py +++ b/mem0/embeddings/ollama.py @@ -18,7 +18,7 @@ class OllamaEmbedding(EmbeddingBase): if not self.config.embedding_dims: self.config.embedding_dims=512 - self.client = Client(host=self.config.base_url) + self.client = Client(host=self.config.ollama_base_url) self._ensure_model_exists() def _ensure_model_exists(self): diff --git a/mem0/embeddings/openai.py b/mem0/embeddings/openai.py index 446fa478..f878b8c2 100644 --- a/mem0/embeddings/openai.py +++ b/mem0/embeddings/openai.py @@ -1,3 +1,4 @@ +import os from typing import Optional from openai import OpenAI @@ -9,12 +10,11 @@ class OpenAIEmbedding(EmbeddingBase): def __init__(self, config: Optional[BaseEmbedderConfig] = None): super().__init__(config) - if not self.config.model: - self.config.model="text-embedding-3-small" - if not self.config.embedding_dims: - self.config.embedding_dims=1536 + self.config.model = self.config.model or "text-embedding-3-small" + self.config.embedding_dims = self.config.embedding_dims or 1536 - self.client = OpenAI() + api_key = os.getenv("OPENAI_API_KEY") or self.config.api_key + self.client = OpenAI(api_key=api_key) def embed(self, text): """ diff --git a/mem0/llms/openai.py b/mem0/llms/openai.py index 57131d7c..f15424eb 100644 --- a/mem0/llms/openai.py +++ b/mem0/llms/openai.py @@ -14,10 +14,11 @@ class OpenAILLM(LLMBase): if not self.config.model: self.config.model="gpt-4o" - if os.environ.get("OPENROUTER_API_KEY"): + if os.environ.get("OPENROUTER_API_KEY"): # Use OpenRouter self.client = OpenAI(api_key=os.environ.get("OPENROUTER_API_KEY"), 
base_url=self.config.openrouter_base_url) else: - self.client = OpenAI() + api_key = os.getenv("OPENAI_API_KEY") or self.config.api_key + self.client = OpenAI(api_key=api_key) def _parse_response(self, response, tools): """ diff --git a/mem0/memory/main.py b/mem0/memory/main.py index 950959e7..81bbb65f 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -374,6 +374,7 @@ class Memory(MemoryBase): new_metadata = metadata or {} new_metadata["data"] = data + new_metadata["hash"] = existing_memory.payload.get("hash") new_metadata["created_at"] = existing_memory.payload.get("created_at") new_metadata["updated_at"] = datetime.now(pytz.timezone('US/Pacific')).isoformat() diff --git a/mem0/vector_stores/configs.py b/mem0/vector_stores/configs.py index 0d76705c..28e4e1c4 100644 --- a/mem0/vector_stores/configs.py +++ b/mem0/vector_stores/configs.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, Field, model_validator class VectorStoreConfig(BaseModel): provider: str = Field( - description="Provider of the vector store (e.g., 'qdrant', 'chromadb')", + description="Provider of the vector store (e.g., 'qdrant', 'chroma')", default="qdrant", ) config: Optional[Dict] = Field( diff --git a/pyproject.toml b/pyproject.toml index c8c2283f..aa243e6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "mem0ai" -version = "0.0.15" +version = "0.0.16" description = "Long-term memory for AI Agents" authors = ["Mem0 "] exclude = [