From 5aa7bedabe26a4da724d9fa9e3b0f78b366c053f Mon Sep 17 00:00:00 2001
From: Dev Khant
Date: Sun, 4 Aug 2024 00:07:15 +0530
Subject: [PATCH] Handle chromadb dep and version bump (#1638)

---
Notes (free text after the three-dash line; not part of the commit message):

What the patch does
  * Moves QdrantConfig and ChromaDbConfig out of mem0/vector_stores/configs.py
    into mem0/configs/vector_stores/{qdrant,chroma}.py. Each config class
    imports its client library inside the class body, so chromadb is only
    imported when the chroma config module is loaded.
  * VectorStoreConfig now resolves the config class by provider name through
    `_provider_configs` and imports mem0.configs.vector_stores.<provider>
    on demand.
  * Renames the provider key "chromadb" -> "chroma" (docs, factory, configs).
    A config that still says "chromadb" is rejected with
    "Unsupported vector store provider: chromadb".
  * The default local path is now f"/tmp/{provider}", i.e. "/tmp/chroma"
    where the removed code used "/tmp/chromadb".
  * Bumps the package version 0.0.12 -> 0.0.13.

Review changes applied to the added (+) lines of this copy
  * VectorStoreConfig.provider description lists 'chroma' (the key that is
    actually accepted) instead of the stale 'chromadb'.
  * Dropped imports that the added code never uses: `field_validator`
    (chroma.py, qdrant.py) and `Type` (configs.py).
  * `Optional,ClassVar` -> `Optional, ClassVar`.
  * chroma.py and the new configs.py end with a newline.
  No line counts changed, so the hunk headers and the diffstat below are
  still accurate. The `index` lines keep the blob ids of the original commit.

Caveats
  * Line breaks and indentation were rebuilt from the hunk line counts.
    Context-line indentation in docs/components/vectordb.mdx and
    pyproject.toml is assumed - TODO confirm against the tree before applying.
  * NOTE(review): the author header above and the `authors = [...]` context
    line in pyproject.toml look like they lost an angle-bracketed address in
    this copy - confirm and restore if so.

 docs/components/vectordb.mdx           |   2 +-
 mem0/configs/vector_stores/__init__.py |   0
 mem0/configs/vector_stores/chroma.py   |  26 +++++
 mem0/configs/vector_stores/qdrant.py   |  34 +++++++
 mem0/utils/factory.py                  |   2 +-
 mem0/vector_stores/configs.py          | 126 ++++++-------------------
 pyproject.toml                         |   2 +-
 7 files changed, 93 insertions(+), 99 deletions(-)
 create mode 100644 mem0/configs/vector_stores/__init__.py
 create mode 100644 mem0/configs/vector_stores/chroma.py
 create mode 100644 mem0/configs/vector_stores/qdrant.py

diff --git a/docs/components/vectordb.mdx b/docs/components/vectordb.mdx
index 0631e26d..051177b2 100644
--- a/docs/components/vectordb.mdx
+++ b/docs/components/vectordb.mdx
@@ -51,7 +51,7 @@ from mem0 import Memory
 
 config = {
     "vectordb": {
-        "provider": "chromadb",
+        "provider": "chroma",
         "config": {
             "collection_name": "test",
             "path": "db",
diff --git a/mem0/configs/vector_stores/__init__.py b/mem0/configs/vector_stores/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/mem0/configs/vector_stores/chroma.py b/mem0/configs/vector_stores/chroma.py
new file mode 100644
index 00000000..ad4d7a29
--- /dev/null
+++ b/mem0/configs/vector_stores/chroma.py
@@ -0,0 +1,26 @@
+from typing import Optional, ClassVar
+
+from pydantic import BaseModel, Field, model_validator
+
+class ChromaDbConfig(BaseModel):
+    try:
+        from chromadb.api.client import Client
+    except ImportError:
+        raise ImportError("Chromadb requires extra dependencies. Install with `pip install chromadb`") from None
+    Client: ClassVar[type] = Client
+
+    collection_name: str = Field("mem0", description="Default name for the collection")
+    client: Optional[Client] = Field(None, description="Existing ChromaDB client instance")
+    path: Optional[str] = Field(None, description="Path to the database directory")
+    host: Optional[str] = Field(None, description="Database connection remote host")
+    port: Optional[str] = Field(None, description="Database connection remote port")
+
+    @model_validator(mode="before")
+    def check_host_port_or_path(cls, values):
+        host, port, path = values.get("host"), values.get("port"), values.get("path")
+        if not path and not (host and port):
+            raise ValueError("Either 'host' and 'port' or 'path' must be provided.")
+        return values
+
+    class Config:
+        arbitrary_types_allowed = True
diff --git a/mem0/configs/vector_stores/qdrant.py b/mem0/configs/vector_stores/qdrant.py
new file mode 100644
index 00000000..f4aa2fd9
--- /dev/null
+++ b/mem0/configs/vector_stores/qdrant.py
@@ -0,0 +1,34 @@
+from typing import Optional, ClassVar
+
+from pydantic import BaseModel, Field, model_validator
+
+class QdrantConfig(BaseModel):
+    from qdrant_client import QdrantClient
+    QdrantClient: ClassVar[type] = QdrantClient
+
+    collection_name: str = Field("mem0", description="Name of the collection")
+    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
+    client: Optional[QdrantClient] = Field(None, description="Existing Qdrant client instance")
+    host: Optional[str] = Field(None, description="Host address for Qdrant server")
+    port: Optional[int] = Field(None, description="Port for Qdrant server")
+    path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database")
+    url: Optional[str] = Field(None, description="Full URL for Qdrant server")
+    api_key: Optional[str] = Field(None, description="API key for Qdrant server")
+
+    @model_validator(mode="before")
+    def check_host_port_or_path(cls, values):
+        host, port, path, url, api_key = (
+            values.get("host"),
+            values.get("port"),
+            values.get("path"),
+            values.get("url"),
+            values.get("api_key"),
+        )
+        if not path and not (host and port) and not (url and api_key):
+            raise ValueError(
+                "Either 'host' and 'port' or 'url' and 'api_key' or 'path' must be provided."
+            )
+        return values
+
+    class Config:
+        arbitrary_types_allowed = True
diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py
index f518b917..2224d29f 100644
--- a/mem0/utils/factory.py
+++ b/mem0/utils/factory.py
@@ -49,7 +49,7 @@ class EmbedderFactory:
 class VectorStoreFactory:
     provider_to_class = {
         "qdrant": "mem0.vector_stores.qdrant.Qdrant",
-        "chromadb": "mem0.vector_stores.chroma.ChromaDB",
+        "chroma": "mem0.vector_stores.chroma.ChromaDB",
     }
 
     @classmethod
diff --git a/mem0/vector_stores/configs.py b/mem0/vector_stores/configs.py
index 6c647de3..628684a5 100644
--- a/mem0/vector_stores/configs.py
+++ b/mem0/vector_stores/configs.py
@@ -1,108 +1,42 @@
-from typing import Optional
-
-from pydantic import BaseModel, Field, field_validator, model_validator
-from qdrant_client import QdrantClient
-from chromadb.api.client import Client as ChromaDbClient
-
-def create_default_config(provider: str):
-    """Create a default configuration based on the provider."""
-    if provider == "qdrant":
-        return QdrantConfig(path="/tmp/qdrant")
-    elif provider == "chromadb":
-        return ChromaDbConfig(path="/tmp/chromadb")
-    else:
-        raise ValueError(f"Unsupported vector store provider: {provider}")
-
-
-class QdrantConfig(BaseModel):
-    collection_name: str = Field("mem0", description="Name of the collection")
-    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
-    client: Optional[QdrantClient] = Field(None, description="Existing Qdrant client instance")
-    host: Optional[str] = Field(None, description="Host address for Qdrant server")
-    port: Optional[int] = Field(None, description="Port for Qdrant server")
-    path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database")
-    url: Optional[str] = Field(None, description="Full URL for Qdrant server")
-    api_key: Optional[str] = Field(None, description="API key for Qdrant server")
-
-    @model_validator(mode="before")
-    def check_host_port_or_path(cls, values):
-        host, port, path, url, api_key = (
-            values.get("host"),
-            values.get("port"),
-            values.get("path"),
-            values.get("url"),
-            values.get("api_key"),
-        )
-        if not path and not (host and port) and not (url and api_key):
-            raise ValueError(
-                "Either 'host' and 'port' or 'url' and 'api_key' or 'path' must be provided."
-            )
-        return values
-
-    class Config:
-        arbitrary_types_allowed = True
-
-
-class ChromaDbConfig(BaseModel):
-    collection_name: str = Field("mem0", description="Default name for the collection")
-    client: Optional[ChromaDbClient] = Field(None, description="Existing ChromaDB client instance")
-    path: Optional[str] = Field(None, description="Path to the database directory")
-    host: Optional[str] = Field(None, description="Database connection remote host")
-    port: Optional[str] = Field(None, description="Database connection remote port")
-
-    @model_validator(mode="before")
-    def check_host_port_or_path(cls, values):
-        host, port, path = values.get("host"), values.get("port"), values.get("path")
-        if not path and not (host and port):
-            raise ValueError("Either 'host' and 'port' or 'path' must be provided.")
-        return values
-
-    class Config:
-        arbitrary_types_allowed = True
-
+from typing import Optional, Dict
+from pydantic import BaseModel, Field, model_validator
 
 class VectorStoreConfig(BaseModel):
     provider: str = Field(
-        description="Provider of the vector store (e.g., 'qdrant', 'chromadb', 'elasticsearch')",
+        description="Provider of the vector store (e.g., 'qdrant', 'chroma')",
         default="qdrant",
     )
-    config: Optional[dict] = Field(
+    config: Optional[Dict] = Field(
         description="Configuration for the specific vector store",
         default=None
     )
 
-    @field_validator("config")
-    def validate_config(cls, v, values):
-        provider = values.data.get("provider")
-
-        if v is None:
-            return create_default_config(provider)
-
-        if isinstance(v, dict):
-            if provider == "qdrant":
-                if "path" not in v:
-                    v["path"] = "/tmp/qdrant"
-                return QdrantConfig(**v)
-            elif provider == "chromadb":
-                if "path" not in v:
-                    v["path"] = "/tmp/chromadb"
-                return ChromaDbConfig(**v)
-
-        return v
+    _provider_configs: Dict[str, str] = {
+        "qdrant": "QdrantConfig",
+        "chroma": "ChromaDbConfig"
+    }
 
     @model_validator(mode="after")
-    def ensure_config_type(cls, values):
-        provider = values.provider
-        config = values.config
-
+    def validate_and_create_config(self) -> 'VectorStoreConfig':
+        provider = self.provider
+        config = self.config
+
+        if provider not in self._provider_configs:
+            raise ValueError(f"Unsupported vector store provider: {provider}")
+
+        module = __import__(f"mem0.configs.vector_stores.{provider}", fromlist=[self._provider_configs[provider]])
+        config_class = getattr(module, self._provider_configs[provider])
+
         if config is None:
-            values.config = create_default_config(provider)
-        elif isinstance(config, dict):
-            if provider == "qdrant":
-                values.config = QdrantConfig(**config)
-            elif provider == "chromadb":
-                values.config = ChromaDbConfig(**config)
-        elif not isinstance(config, (QdrantConfig, ChromaDbConfig)):
-            raise ValueError(f"Invalid config type for provider {provider}")
-
-        return values
\ No newline at end of file
+            config = {}
+
+        if not isinstance(config, dict):
+            if not isinstance(config, config_class):
+                raise ValueError(f"Invalid config type for provider {provider}")
+            return self
+
+        if "path" not in config:
+            config["path"] = f"/tmp/{provider}"
+
+        self.config = config_class(**config)
+        return self
diff --git a/pyproject.toml b/pyproject.toml
index 5cfd0efb..8fb0bb98 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mem0ai"
-version = "0.0.12"
+version = "0.0.13"
 description = "Long-term memory for AI Agents"
 authors = ["Mem0 "]
 exclude = [