Fix config for vector store (#1637)

This commit is contained in:
Dev Khant
2024-08-03 21:48:27 +05:30
committed by GitHub
parent 81b4431c9b
commit 5837991e5c
2 changed files with 15 additions and 8 deletions

View File

@@ -28,13 +28,13 @@ class ChromaDB(VectorStoreBase):
path path
): ):
""" """
Initialize the Qdrant vector store. Initialize the Chromadb vector store.
Args: Args:
client (QdrantClient, optional): Existing Qdrant client instance. client (chromadb.Client, optional): Existing chromadb client instance.
host (str, optional): Host address for Qdrant server. host (str, optional): Host address for chromadb server.
port (int, optional): Port for Qdrant server. port (int, optional): Port for chromadb server.
path (str, optional): Path for local Qdrant database. path (str, optional): Path for local chromadb database.
""" """
if client: if client:
self.client = client self.client = client

View File

@@ -1,7 +1,8 @@
from typing import Optional from typing import Optional
from pydantic import BaseModel, Field, field_validator, model_validator from pydantic import BaseModel, Field, field_validator, model_validator
from qdrant_client import QdrantClient
from chromadb.api.client import Client as ChromaDbClient
def create_default_config(provider: str): def create_default_config(provider: str):
"""Create a default configuration based on the provider.""" """Create a default configuration based on the provider."""
@@ -16,7 +17,7 @@ def create_default_config(provider: str):
class QdrantConfig(BaseModel): class QdrantConfig(BaseModel):
collection_name: str = Field("mem0", description="Name of the collection") collection_name: str = Field("mem0", description="Name of the collection")
embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model") embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
client: Optional[str] = Field(None, description="Existing Qdrant client instance") client: Optional[QdrantClient] = Field(None, description="Existing Qdrant client instance")
host: Optional[str] = Field(None, description="Host address for Qdrant server") host: Optional[str] = Field(None, description="Host address for Qdrant server")
port: Optional[int] = Field(None, description="Port for Qdrant server") port: Optional[int] = Field(None, description="Port for Qdrant server")
path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database") path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database")
@@ -38,10 +39,13 @@ class QdrantConfig(BaseModel):
) )
return values return values
class Config:
arbitrary_types_allowed = True
class ChromaDbConfig(BaseModel): class ChromaDbConfig(BaseModel):
collection_name: str = Field("mem0", description="Default name for the collection") collection_name: str = Field("mem0", description="Default name for the collection")
client: Optional[str] = Field(None, description="Existing ChromaDB client instance") client: Optional[ChromaDbClient] = Field(None, description="Existing ChromaDB client instance")
path: Optional[str] = Field(None, description="Path to the database directory") path: Optional[str] = Field(None, description="Path to the database directory")
host: Optional[str] = Field(None, description="Database connection remote host") host: Optional[str] = Field(None, description="Database connection remote host")
port: Optional[str] = Field(None, description="Database connection remote port") port: Optional[str] = Field(None, description="Database connection remote port")
@@ -52,6 +56,9 @@ class ChromaDbConfig(BaseModel):
if not path and not (host and port): if not path and not (host and port):
raise ValueError("Either 'host' and 'port' or 'path' must be provided.") raise ValueError("Either 'host' and 'port' or 'path' must be provided.")
return values return values
class Config:
arbitrary_types_allowed = True
class VectorStoreConfig(BaseModel): class VectorStoreConfig(BaseModel):