Fix config for vector store (#1637)
This commit is contained in:
@@ -28,13 +28,13 @@ class ChromaDB(VectorStoreBase):
|
||||
path
|
||||
):
|
||||
"""
|
||||
Initialize the Qdrant vector store.
|
||||
Initialize the Chromadb vector store.
|
||||
|
||||
Args:
|
||||
client (QdrantClient, optional): Existing Qdrant client instance.
|
||||
host (str, optional): Host address for Qdrant server.
|
||||
port (int, optional): Port for Qdrant server.
|
||||
path (str, optional): Path for local Qdrant database.
|
||||
client (chromadb.Client, optional): Existing chromadb client instance.
|
||||
host (str, optional): Host address for chromadb server.
|
||||
port (int, optional): Port for chromadb server.
|
||||
path (str, optional): Path for local chromadb database.
|
||||
"""
|
||||
if client:
|
||||
self.client = client
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from qdrant_client import QdrantClient
|
||||
from chromadb.api.client import Client as ChromaDbClient
|
||||
|
||||
def create_default_config(provider: str):
|
||||
"""Create a default configuration based on the provider."""
|
||||
@@ -16,7 +17,7 @@ def create_default_config(provider: str):
|
||||
class QdrantConfig(BaseModel):
|
||||
collection_name: str = Field("mem0", description="Name of the collection")
|
||||
embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
|
||||
client: Optional[str] = Field(None, description="Existing Qdrant client instance")
|
||||
client: Optional[QdrantClient] = Field(None, description="Existing Qdrant client instance")
|
||||
host: Optional[str] = Field(None, description="Host address for Qdrant server")
|
||||
port: Optional[int] = Field(None, description="Port for Qdrant server")
|
||||
path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database")
|
||||
@@ -38,10 +39,13 @@ class QdrantConfig(BaseModel):
|
||||
)
|
||||
return values
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class ChromaDbConfig(BaseModel):
|
||||
collection_name: str = Field("mem0", description="Default name for the collection")
|
||||
client: Optional[str] = Field(None, description="Existing ChromaDB client instance")
|
||||
client: Optional[ChromaDbClient] = Field(None, description="Existing ChromaDB client instance")
|
||||
path: Optional[str] = Field(None, description="Path to the database directory")
|
||||
host: Optional[str] = Field(None, description="Database connection remote host")
|
||||
port: Optional[str] = Field(None, description="Database connection remote port")
|
||||
@@ -52,6 +56,9 @@ class ChromaDbConfig(BaseModel):
|
||||
if not path and not (host and port):
|
||||
raise ValueError("Either 'host' and 'port' or 'path' must be provided.")
|
||||
return values
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class VectorStoreConfig(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user