49 lines
1.6 KiB
Python
49 lines
1.6 KiB
Python
import importlib
from typing import Dict, Optional

from pydantic import BaseModel, Field, model_validator
|
|
|
|
|
|
class VectorStoreConfig(BaseModel):
    """Select a vector store provider and build its provider-specific config.

    Once the fields have been validated, ``config`` is replaced by an instance
    of the provider's configuration class. That class is looked up by name in
    the module ``mem0.configs.vector_stores.<provider>``, which is imported
    only when the validator runs.
    """

    provider: str = Field(
        description="Provider of the vector store (e.g., 'qdrant', 'chroma')",
        default="qdrant",
    )
    config: Optional[Dict] = Field(
        description="Configuration for the specific vector store", default=None
    )

    # Supported provider name -> name of its config class inside
    # ``mem0.configs.vector_stores.<provider>``.
    _provider_configs: Dict[str, str] = {
        "qdrant": "QdrantConfig",
        "chroma": "ChromaDbConfig",
        "pgvector": "PGVectorConfig",
    }

    @model_validator(mode="after")
    def validate_and_create_config(self) -> "VectorStoreConfig":
        """Turn ``config`` into an instance of the provider's config class.

        Returns:
            This model, with ``config`` set to a provider config instance
            (or left as-is when it already was one).

        Raises:
            ValueError: If ``provider`` is not a key of ``_provider_configs``,
                or if ``config`` is neither a dict nor an instance of the
                provider's config class.
        """
        provider = self.provider
        config = self.config

        # Single lookup: a missing key means the provider is unsupported.
        config_class_name = self._provider_configs.get(provider)
        if config_class_name is None:
            raise ValueError(f"Unsupported vector store provider: {provider}")

        # importlib.import_module returns the submodule itself, which is what
        # the discouraged ``__import__(..., fromlist=[...])`` form was used for.
        module = importlib.import_module(f"mem0.configs.vector_stores.{provider}")
        config_class = getattr(module, config_class_name)

        if config is None:
            config = {}

        if not isinstance(config, dict):
            # An already-built provider config is accepted unchanged.
            if not isinstance(config, config_class):
                raise ValueError(f"Invalid config type for provider {provider}")
            return self

        # TODO: also check whether "path" is among the allowed keys of the
        # pydantic model, and whether extra config fields are allowed.
        # NOTE(review): __annotations__ only lists fields declared directly on
        # config_class, not inherited ones - confirm this is sufficient.
        if "path" not in config and "path" in config_class.__annotations__:
            config["path"] = f"/tmp/{provider}"

        self.config = config_class(**config)
        return self
|