From 5837991e5c4f24631e1c5db214d86ad3da7427d2 Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Sat, 3 Aug 2024 21:48:27 +0530 Subject: [PATCH] Fix config for vector store (#1637) --- mem0/vector_stores/chroma.py | 10 +++++----- mem0/vector_stores/configs.py | 13 ++++++++++--- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/mem0/vector_stores/chroma.py b/mem0/vector_stores/chroma.py index 3b3b78bd..1399f08b 100644 --- a/mem0/vector_stores/chroma.py +++ b/mem0/vector_stores/chroma.py @@ -28,13 +28,13 @@ class ChromaDB(VectorStoreBase): path ): """ - Initialize the Qdrant vector store. + Initialize the Chromadb vector store. Args: - client (QdrantClient, optional): Existing Qdrant client instance. - host (str, optional): Host address for Qdrant server. - port (int, optional): Port for Qdrant server. - path (str, optional): Path for local Qdrant database. + client (chromadb.Client, optional): Existing chromadb client instance. + host (str, optional): Host address for chromadb server. + port (int, optional): Port for chromadb server. + path (str, optional): Path for local chromadb database. """ if client: self.client = client diff --git a/mem0/vector_stores/configs.py b/mem0/vector_stores/configs.py index abfb7e36..6c647de3 100644 --- a/mem0/vector_stores/configs.py +++ b/mem0/vector_stores/configs.py @@ -1,7 +1,8 @@ from typing import Optional from pydantic import BaseModel, Field, field_validator, model_validator - +from qdrant_client import QdrantClient +from chromadb.api.client import Client as ChromaDbClient def create_default_config(provider: str): """Create a default configuration based on the provider.""" @@ -16,7 +17,7 @@ def create_default_config(provider: str): class QdrantConfig(BaseModel): collection_name: str = Field("mem0", description="Name of the collection") embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model") - client: Optional[str] = Field(None, description="Existing Qdrant client instance") + client: Optional[QdrantClient] = Field(None, description="Existing Qdrant client instance") host: Optional[str] = Field(None, description="Host address for Qdrant server") port: Optional[int] = Field(None, description="Port for Qdrant server") path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database") @@ -38,10 +39,13 @@ class QdrantConfig(BaseModel): ) return values + class Config: + arbitrary_types_allowed = True + class ChromaDbConfig(BaseModel): collection_name: str = Field("mem0", description="Default name for the collection") - client: Optional[str] = Field(None, description="Existing ChromaDB client instance") + client: Optional[ChromaDbClient] = Field(None, description="Existing ChromaDB client instance") path: Optional[str] = Field(None, description="Path to the database directory") host: Optional[str] = Field(None, description="Database connection remote host") port: Optional[str] = Field(None, description="Database connection remote port") @@ -52,6 +56,9 @@ class ChromaDbConfig(BaseModel): if not path and not (host and port): raise ValueError("Either 'host' and 'port' or 'path' must be provided.") return values + + class Config: + arbitrary_types_allowed = True class VectorStoreConfig(BaseModel):