Integrate Mem0 (#1462)

Co-authored-by: Deshraj Yadav <deshraj@gatech.edu>
This commit is contained in:
Dev Khant
2024-07-07 00:57:01 +05:30
committed by GitHub
parent bd654e7aac
commit bbe56107fb
11 changed files with 195 additions and 34 deletions

View File

@@ -9,19 +9,24 @@ import requests
import yaml
from tqdm import tqdm
from embedchain.cache import (Config, ExactMatchEvaluation,
SearchDistanceEvaluation, cache,
gptcache_data_manager, gptcache_pre_function)
from mem0 import Mem0
from embedchain.cache import (
Config,
ExactMatchEvaluation,
SearchDistanceEvaluation,
cache,
gptcache_data_manager,
gptcache_pre_function,
)
from embedchain.client import Client
from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
from embedchain.config import AppConfig, CacheConfig, ChunkerConfig, Mem0Config
from embedchain.core.db.database import get_session, init_db, setup_engine
from embedchain.core.db.models import DataSource
from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.evaluation.base import BaseMetric
from embedchain.evaluation.metrics import (AnswerRelevance, ContextRelevance,
Groundedness)
from embedchain.evaluation.metrics import AnswerRelevance, ContextRelevance, Groundedness
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
@@ -55,6 +60,7 @@ class App(EmbedChain):
auto_deploy: bool = False,
chunker: ChunkerConfig = None,
cache_config: CacheConfig = None,
memory_config: Mem0Config = None,
log_level: int = logging.WARN,
):
"""
@@ -95,6 +101,7 @@ class App(EmbedChain):
self.id = None
self.chunker = ChunkerConfig(**chunker) if chunker else None
self.cache_config = cache_config
self.memory_config = memory_config
self.config = config or AppConfig()
self.name = self.config.name
@@ -123,6 +130,11 @@ class App(EmbedChain):
if self.cache_config is not None:
self._init_cache()
# If memory_config is provided, initializing the memory ...
self.mem0_client = None
if self.memory_config is not None:
self.mem0_client = Mem0(api_key=self.memory_config.api_key)
# Send anonymous telemetry
self._telemetry_props = {"class": self.__class__.__name__}
self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
@@ -365,11 +377,13 @@ class App(EmbedChain):
app_config_data = config_data.get("app", {}).get("config", {})
vector_db_config_data = config_data.get("vectordb", {})
embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
memory_config_data = config_data.get("memory", {})
llm_config_data = config_data.get("llm", {})
chunker_config_data = config_data.get("chunker", {})
cache_config_data = config_data.get("cache", None)
app_config = AppConfig(**app_config_data)
memory_config = Mem0Config(**memory_config_data) if memory_config_data else None
vector_db_provider = vector_db_config_data.get("provider", "chroma")
vector_db = VectorDBFactory.create(vector_db_provider, vector_db_config_data.get("config", {}))
@@ -403,6 +417,7 @@ class App(EmbedChain):
auto_deploy=auto_deploy,
chunker=chunker_config_data,
cache_config=cache_config,
memory_config=memory_config,
)
def _eval(self, dataset: list[EvalData], metric: Union[BaseMetric, str]):