Fix config for vector store (#1637)
This commit is contained in:
@@ -28,13 +28,13 @@ class ChromaDB(VectorStoreBase):
|
|||||||
path
|
path
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the Qdrant vector store.
|
Initialize the Chromadb vector store.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
client (QdrantClient, optional): Existing Qdrant client instance.
|
client (chromadb.Client, optional): Existing chromadb client instance.
|
||||||
host (str, optional): Host address for Qdrant server.
|
host (str, optional): Host address for chromadb server.
|
||||||
port (int, optional): Port for Qdrant server.
|
port (int, optional): Port for chromadb server.
|
||||||
path (str, optional): Path for local Qdrant database.
|
path (str, optional): Path for local chromadb database.
|
||||||
"""
|
"""
|
||||||
if client:
|
if client:
|
||||||
self.client = client
|
self.client = client
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
|
from qdrant_client import QdrantClient
|
||||||
|
from chromadb.api.client import Client as ChromaDbClient
|
||||||
|
|
||||||
def create_default_config(provider: str):
|
def create_default_config(provider: str):
|
||||||
"""Create a default configuration based on the provider."""
|
"""Create a default configuration based on the provider."""
|
||||||
@@ -16,7 +17,7 @@ def create_default_config(provider: str):
|
|||||||
class QdrantConfig(BaseModel):
|
class QdrantConfig(BaseModel):
|
||||||
collection_name: str = Field("mem0", description="Name of the collection")
|
collection_name: str = Field("mem0", description="Name of the collection")
|
||||||
embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
|
embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
|
||||||
client: Optional[str] = Field(None, description="Existing Qdrant client instance")
|
client: Optional[QdrantClient] = Field(None, description="Existing Qdrant client instance")
|
||||||
host: Optional[str] = Field(None, description="Host address for Qdrant server")
|
host: Optional[str] = Field(None, description="Host address for Qdrant server")
|
||||||
port: Optional[int] = Field(None, description="Port for Qdrant server")
|
port: Optional[int] = Field(None, description="Port for Qdrant server")
|
||||||
path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database")
|
path: Optional[str] = Field("/tmp/qdrant", description="Path for local Qdrant database")
|
||||||
@@ -38,10 +39,13 @@ class QdrantConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
class ChromaDbConfig(BaseModel):
|
class ChromaDbConfig(BaseModel):
|
||||||
collection_name: str = Field("mem0", description="Default name for the collection")
|
collection_name: str = Field("mem0", description="Default name for the collection")
|
||||||
client: Optional[str] = Field(None, description="Existing ChromaDB client instance")
|
client: Optional[ChromaDbClient] = Field(None, description="Existing ChromaDB client instance")
|
||||||
path: Optional[str] = Field(None, description="Path to the database directory")
|
path: Optional[str] = Field(None, description="Path to the database directory")
|
||||||
host: Optional[str] = Field(None, description="Database connection remote host")
|
host: Optional[str] = Field(None, description="Database connection remote host")
|
||||||
port: Optional[str] = Field(None, description="Database connection remote port")
|
port: Optional[str] = Field(None, description="Database connection remote port")
|
||||||
@@ -52,6 +56,9 @@ class ChromaDbConfig(BaseModel):
|
|||||||
if not path and not (host and port):
|
if not path and not (host and port):
|
||||||
raise ValueError("Either 'host' and 'port' or 'path' must be provided.")
|
raise ValueError("Either 'host' and 'port' or 'path' must be provided.")
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
class VectorStoreConfig(BaseModel):
|
class VectorStoreConfig(BaseModel):
|
||||||
|
|||||||
Reference in New Issue
Block a user