Integrate Mem0 (#1462)
Co-authored-by: Deshraj Yadav <deshraj@gatech.edu>
@@ -9,19 +9,24 @@ import requests
 import yaml
 from tqdm import tqdm
 
-from embedchain.cache import (Config, ExactMatchEvaluation,
-                              SearchDistanceEvaluation, cache,
-                              gptcache_data_manager, gptcache_pre_function)
+from mem0 import Mem0
+from embedchain.cache import (
+    Config,
+    ExactMatchEvaluation,
+    SearchDistanceEvaluation,
+    cache,
+    gptcache_data_manager,
+    gptcache_pre_function,
+)
 from embedchain.client import Client
-from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
+from embedchain.config import AppConfig, CacheConfig, ChunkerConfig, Mem0Config
 from embedchain.core.db.database import get_session, init_db, setup_engine
 from embedchain.core.db.models import DataSource
 from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.openai import OpenAIEmbedder
 from embedchain.evaluation.base import BaseMetric
-from embedchain.evaluation.metrics import (AnswerRelevance, ContextRelevance,
-                                           Groundedness)
+from embedchain.evaluation.metrics import AnswerRelevance, ContextRelevance, Groundedness
 from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
@@ -55,6 +60,7 @@ class App(EmbedChain):
         auto_deploy: bool = False,
         chunker: ChunkerConfig = None,
         cache_config: CacheConfig = None,
+        memory_config: Mem0Config = None,
         log_level: int = logging.WARN,
     ):
         """
@@ -95,6 +101,7 @@ class App(EmbedChain):
         self.id = None
         self.chunker = ChunkerConfig(**chunker) if chunker else None
         self.cache_config = cache_config
+        self.memory_config = memory_config
 
         self.config = config or AppConfig()
         self.name = self.config.name
@@ -123,6 +130,11 @@ class App(EmbedChain):
         if self.cache_config is not None:
             self._init_cache()
 
+        # If memory_config is provided, initializing the memory ...
+        self.mem0_client = None
+        if self.memory_config is not None:
+            self.mem0_client = Mem0(api_key=self.memory_config.api_key)
+
         # Send anonymous telemetry
         self._telemetry_props = {"class": self.__class__.__name__}
         self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
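Taken together, the constructor changes above let a caller opt into Mem0-backed memory by passing a Mem0Config directly to App. A minimal sketch of that path, assuming Mem0Config accepts an api_key field (as the Mem0(api_key=...) call above implies) and that credentials for the default embedder/LLM are already configured:

# Sketch only: direct use of the new memory_config parameter.
# The "m0-..." value is a placeholder, not a real credential.
from embedchain import App
from embedchain.config import Mem0Config

app = App(memory_config=Mem0Config(api_key="m0-..."))

# With memory_config set, __init__ creates the Mem0 client,
# so app.mem0_client is populated; without it, it stays None.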
@@ -365,11 +377,13 @@ class App(EmbedChain):
         app_config_data = config_data.get("app", {}).get("config", {})
         vector_db_config_data = config_data.get("vectordb", {})
         embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
+        memory_config_data = config_data.get("memory", {})
         llm_config_data = config_data.get("llm", {})
         chunker_config_data = config_data.get("chunker", {})
         cache_config_data = config_data.get("cache", None)
 
         app_config = AppConfig(**app_config_data)
+        memory_config = Mem0Config(**memory_config_data) if memory_config_data else None
 
         vector_db_provider = vector_db_config_data.get("provider", "chroma")
         vector_db = VectorDBFactory.create(vector_db_provider, vector_db_config_data.get("config", {}))
@@ -403,6 +417,7 @@ class App(EmbedChain):
             auto_deploy=auto_deploy,
             chunker=chunker_config_data,
             cache_config=cache_config,
+            memory_config=memory_config,
         )
 
     def _eval(self, dataset: list[EvalData], metric: Union[BaseMetric, str]):
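On the declarative path, from_config now reads an optional "memory" section and forwards it to Mem0Config before passing it into the constructor, as shown in the last two hunks. A hedged sketch of how that could be used, assuming App.from_config accepts an in-memory config dict; the section names mirror the config_data.get(...) calls in the diff, and api_key is the only Mem0Config field this change confirms:

# Sketch only: the "memory" section maps onto Mem0Config via
# Mem0Config(**memory_config_data); the other sections are the ones
# from_config already parses.
from embedchain import App

config = {
    "app": {"config": {"name": "mem0-demo"}},
    "llm": {"provider": "openai"},
    "vectordb": {"provider": "chroma"},  # "chroma" is the default provider in the diff
    "memory": {"api_key": "m0-..."},     # placeholder key
}

app = App.from_config(config=config)
# from_config builds Mem0Config(api_key="m0-...") and hands it to the
# constructor via memory_config=memory_config, initializing the Mem0 client.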