Add ChromaDB support (#1612)

This commit is contained in:
Dev Khant
2024-08-01 22:16:35 +05:30
committed by GitHub
parent e585d3c1cc
commit 45ae1f0313
9 changed files with 452 additions and 148 deletions

View File

@@ -1,4 +1,3 @@
import json
import logging
import os
import time
@@ -21,8 +20,7 @@ from mem0.memory.utils import get_update_memory_messages
from mem0.vector_stores.configs import VectorStoreConfig
from mem0.llms.configs import LlmConfig
from mem0.embeddings.configs import EmbedderConfig
from mem0.vector_stores.qdrant import Qdrant
from mem0.utils.factory import LlmFactory, EmbedderFactory
from mem0.utils.factory import LlmFactory, EmbedderFactory, VectorStoreFactory
# Setup user config
setup_config()
@@ -57,37 +55,17 @@ class MemoryConfig(BaseModel):
description="Path to the history database",
default=os.path.join(mem0_dir, "history.db"),
)
collection_name: str = Field(default="mem0", description="Name of the collection")
embedding_model_dims: int = Field(
default=1536, description="Dimensions of the embedding model"
)
class Memory(MemoryBase):
def __init__(self, config: MemoryConfig = MemoryConfig()):
self.config = config
self.embedding_model = EmbedderFactory.create(self.config.embedder.provider)
# Initialize the appropriate vector store based on the configuration
vector_store_config = self.config.vector_store.config
if self.config.vector_store.provider == "qdrant":
self.vector_store = Qdrant(
host=vector_store_config.host,
port=vector_store_config.port,
path=vector_store_config.path,
url=vector_store_config.url,
api_key=vector_store_config.api_key,
)
else:
raise ValueError(
f"Unsupported vector store type: {self.config.vector_store_type}"
)
self.vector_store = VectorStoreFactory.create(self.config.vector_store.provider, self.config.vector_store.config)
self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config)
self.db = SQLiteManager(self.config.history_db_path)
self.collection_name = self.config.collection_name
self.vector_store.create_col(
name=self.collection_name, vector_size=self.embedding_model.dims
)
self.collection_name = self.config.vector_store.config.collection_name if "collection_name" in self.config.vector_store.config else "mem0"
capture_event("mem0.init", self)
@classmethod