Integrate Mem0 (#1462)
Co-authored-by: Deshraj Yadav <deshraj@gatech.edu>
@@ -9,19 +9,24 @@ import requests
import yaml
from tqdm import tqdm

from embedchain.cache import (Config, ExactMatchEvaluation,
                              SearchDistanceEvaluation, cache,
                              gptcache_data_manager, gptcache_pre_function)
from mem0 import Mem0
from embedchain.cache import (
    Config,
    ExactMatchEvaluation,
    SearchDistanceEvaluation,
    cache,
    gptcache_data_manager,
    gptcache_pre_function,
)
from embedchain.client import Client
from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
from embedchain.config import AppConfig, CacheConfig, ChunkerConfig, Mem0Config
from embedchain.core.db.database import get_session, init_db, setup_engine
from embedchain.core.db.models import DataSource
from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.evaluation.base import BaseMetric
from embedchain.evaluation.metrics import (AnswerRelevance, ContextRelevance,
                                           Groundedness)
from embedchain.evaluation.metrics import AnswerRelevance, ContextRelevance, Groundedness
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
@@ -55,6 +60,7 @@ class App(EmbedChain):
        auto_deploy: bool = False,
        chunker: ChunkerConfig = None,
        cache_config: CacheConfig = None,
        memory_config: Mem0Config = None,
        log_level: int = logging.WARN,
    ):
        """
@@ -95,6 +101,7 @@ class App(EmbedChain):
        self.id = None
        self.chunker = ChunkerConfig(**chunker) if chunker else None
        self.cache_config = cache_config
        self.memory_config = memory_config

        self.config = config or AppConfig()
        self.name = self.config.name
@@ -123,6 +130,11 @@ class App(EmbedChain):
        if self.cache_config is not None:
            self._init_cache()

        # If memory_config is provided, initialize the memory ...
        self.mem0_client = None
        if self.memory_config is not None:
            self.mem0_client = Mem0(api_key=self.memory_config.api_key)

        # Send anonymous telemetry
        self._telemetry_props = {"class": self.__class__.__name__}
        self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
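
For orientation, a minimal usage sketch (not part of the diff) of the new constructor parameter; the API key below is a placeholder and the remaining App defaults are assumed to be acceptable:

from embedchain import App
from embedchain.config import AppConfig, Mem0Config

# placeholder API key for illustration only
memory_config = Mem0Config(api_key="<MEM0_API_KEY>", top_k=5)
app = App(config=AppConfig(), memory_config=memory_config)
# with memory_config set, __init__ above creates self.mem0_client = Mem0(api_key=...)
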
@@ -365,11 +377,13 @@ class App(EmbedChain):
        app_config_data = config_data.get("app", {}).get("config", {})
        vector_db_config_data = config_data.get("vectordb", {})
        embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
        memory_config_data = config_data.get("memory", {})
        llm_config_data = config_data.get("llm", {})
        chunker_config_data = config_data.get("chunker", {})
        cache_config_data = config_data.get("cache", None)

        app_config = AppConfig(**app_config_data)
        memory_config = Mem0Config(**memory_config_data) if memory_config_data else None

        vector_db_provider = vector_db_config_data.get("provider", "chroma")
        vector_db = VectorDBFactory.create(vector_db_provider, vector_db_config_data.get("config", {}))
@@ -403,6 +417,7 @@ class App(EmbedChain):
            auto_deploy=auto_deploy,
            chunker=chunker_config_data,
            cache_config=cache_config,
            memory_config=memory_config,
        )

    def _eval(self, dataset: list[EvalData], metric: Union[BaseMetric, str]):
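
A sketch of the configuration shape this parsing accepts; the `memory` section keys come from the diff, while passing a plain dict to App.from_config (rather than a YAML path) is an assumed but common embedchain usage:

from embedchain import App

config = {
    "app": {"config": {"name": "memory-demo"}},
    "memory": {"api_key": "<MEM0_API_KEY>", "top_k": 5},  # placeholder key
}
app = App.from_config(config=config)
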
@@ -12,3 +12,4 @@ from .vectordb.chroma import ChromaDbConfig
from .vectordb.elasticsearch import ElasticsearchDBConfig
from .vectordb.opensearch import OpenSearchDBConfig
from .vectordb.zilliz import ZillizDBConfig
from .mem0_config import Mem0Config
@@ -50,6 +50,35 @@ Query: $query
Answer:
""" # noqa:E501

DEFAULT_PROMPT_WITH_MEM0_MEMORY = """
You are a Q&A expert system. Your responses must always be rooted in the context provided for each query. You are also provided with the conversation history and memories with the user. Make sure to use relevant context from conversation history and memories as needed.

Here are some guidelines to follow:

1. Refrain from explicitly mentioning the context provided in your response.
2. Take into consideration the conversation history and memories provided.
3. The context should silently guide your answers without being directly acknowledged.
4. Do not use phrases such as 'According to the context provided', 'Based on the context, ...' etc.

Context information:
----------------------
$context
----------------------

Conversation history:
----------------------
$history
----------------------

Memories/Preferences:
----------------------
$memories
----------------------

Query: $query
Answer:
""" # noqa:E501

DOCS_SITE_DEFAULT_PROMPT = """
You are an expert AI assistant for developer support product. Your responses must always be rooted in the context provided for each query. Wherever possible, give complete code snippet. Dont make up any code snippet on your own.

@@ -70,6 +99,7 @@ Answer:

DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT)
DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE = Template(DEFAULT_PROMPT_WITH_HISTORY)
DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE = Template(DEFAULT_PROMPT_WITH_MEM0_MEMORY)
DOCS_SITE_PROMPT_TEMPLATE = Template(DOCS_SITE_DEFAULT_PROMPT)
query_re = re.compile(r"\$\{*query\}*")
context_re = re.compile(r"\$\{*context\}*")
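
As a rough illustration of how the new template gets filled (the values below are invented; the real substitution happens in BaseLlm.generate_prompt further down in this diff):

from embedchain.config.llm.base import DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE

prompt = DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE.substitute(
    context="<retrieved document chunks>",
    history="<previous conversation turns>",
    memories="<newline-separated Mem0 search results>",
    query="What format do I prefer for answers?",
)
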
embedchain/config/mem0_config.py (new file, 21 lines)
@@ -0,0 +1,21 @@
from typing import Any, Optional

from embedchain.config.base_config import BaseConfig
from embedchain.helpers.json_serializable import register_deserializable


@register_deserializable
class Mem0Config(BaseConfig):
    def __init__(self, api_key: str, top_k: Optional[int] = 10):
        self.api_key = api_key
        self.top_k = top_k

    @staticmethod
    def from_config(config: Optional[dict[str, Any]]):
        if config is None:
            return Mem0Config()
        else:
            return Mem0Config(
                api_key=config.get("api_key", ""),
                top_k=config.get("top_k", 10),
            )
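
A small sketch of how this helper is typically exercised (assumed usage; the dict mirrors the `memory` section of an app config and the key is a placeholder):

from embedchain.config import Mem0Config

mem_cfg = Mem0Config.from_config({"api_key": "<MEM0_API_KEY>", "top_k": 20})  # placeholder key
assert mem_cfg.top_k == 20
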
@@ -52,6 +52,8 @@ class EmbedChain(JSONSerializable):
        """
        self.config = config
        self.cache_config = None
        self.memory_config = None
        self.mem0_client = None
        # Llm
        self.llm = llm
        # Database has support for config assignment for backwards compatibility
@@ -595,6 +597,12 @@ class EmbedChain(JSONSerializable):
        else:
            contexts_data_for_llm_query = contexts

        memories = None
        if self.mem0_client:
            memories = self.mem0_client.search(
                query=input_query, agent_id=self.config.id, session_id=session_id, limit=self.memory_config.top_k
            )

        # Update the history beforehand so that we can handle multiple chat sessions in the same python session
        self.llm.update_history(app_id=self.config.id, session_id=session_id)
@@ -615,13 +623,27 @@ class EmbedChain(JSONSerializable):
            logger.debug("Cache disabled. Running chat without cache.")
            if self.llm.config.token_usage:
                answer, token_info = self.llm.query(
                    input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
                    input_query=input_query,
                    contexts=contexts_data_for_llm_query,
                    config=config,
                    dry_run=dry_run,
                    memories=memories,
                )
            else:
                answer = self.llm.query(
                    input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
                    input_query=input_query,
                    contexts=contexts_data_for_llm_query,
                    config=config,
                    dry_run=dry_run,
                    memories=memories,
                )

        # Add to Mem0 memory if enabled
        # TODO: Might need to prepend with some text like:
        # "Remember user preferences from following user query: {input_query}"
        if self.mem0_client:
            self.mem0_client.add(data=input_query, agent_id=self.config.id, session_id=session_id)

        # add conversation in memory
        self.llm.add_history(self.config.id, input_query, answer, session_id=session_id)
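
Taken together, the chat-side behaviour added above can be pictured as follows (a sketch only; the session id and queries are made up, and `app` is assumed to be built with memory_config as in the earlier sketch):

answer = app.chat("I prefer answers as short bullet points", session_id="user-1")
# before the LLM call: mem0_client.search(query=..., agent_id=app.config.id, session_id="user-1", limit=top_k)
# after the LLM call: mem0_client.add(data=<the query>, agent_id=app.config.id, session_id="user-1")

answer = app.chat("Summarize the page I added earlier", session_id="user-1")
# the stored preference can now surface through the $memories block of the prompt
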
@@ -5,9 +5,12 @@ from typing import Any, Optional
from langchain.schema import BaseMessage as LCBaseMessage

from embedchain.config import BaseLlmConfig
from embedchain.config.llm.base import (DEFAULT_PROMPT,
                                        DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
                                        DOCS_SITE_PROMPT_TEMPLATE)
from embedchain.config.llm.base import (
    DEFAULT_PROMPT,
    DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
    DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE,
    DOCS_SITE_PROMPT_TEMPLATE,
)
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.memory.base import ChatHistory
from embedchain.memory.message import ChatMessage
@@ -74,6 +77,16 @@ class BaseLlm(JSONSerializable):
        """
        return "\n".join(self.history)

    def _format_memories(self, memories: list[dict]) -> str:
        """Format memories to be used in prompt

        :param memories: Memories to format
        :type memories: list[dict]
        :return: Formatted memories
        :rtype: str
        """
        return "\n".join([memory["text"] for memory in memories])
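
A tiny illustration of what this formatting produces; the memory dicts are invented, but the "text" key matches what the comprehension above expects from Mem0 search results:

memories = [{"text": "Prefers metric units"}, {"text": "Writes mostly TypeScript"}]
# "\n".join(memory["text"] for memory in memories) gives:
# Prefers metric units
# Writes mostly TypeScript
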
    def generate_prompt(self, input_query: str, contexts: list[str], **kwargs: dict[str, Any]) -> str:
        """
        Generates a prompt based on the given query and context, ready to be
@@ -88,6 +101,7 @@ class BaseLlm(JSONSerializable):
        """
        context_string = " | ".join(contexts)
        web_search_result = kwargs.get("web_search_result", "")
        memories = kwargs.get("memories", None)
        if web_search_result:
            context_string = self._append_search_and_context(context_string, web_search_result)

@@ -103,10 +117,19 @@ class BaseLlm(JSONSerializable):
                not self.config._validate_prompt_history(self.config.prompt)
                and self.config.prompt.template == DEFAULT_PROMPT
            ):
                # swap in the template with history
                prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
                    context=context_string, query=input_query, history=self._format_history()
                )
                if memories:
                    # swap in the template with Mem0 memory template
                    prompt = DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE.substitute(
                        context=context_string,
                        query=input_query,
                        history=self._format_history(),
                        memories=self._format_memories(memories),
                    )
                else:
                    # swap in the template with history
                    prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
                        context=context_string, query=input_query, history=self._format_history()
                    )
            else:
                # If we can't swap in the default, we still proceed but tell users that the history is ignored.
                logger.warning(
@@ -180,7 +203,7 @@ class BaseLlm(JSONSerializable):
            if token_info:
                logger.info(f"Token Info: {token_info}")

    def query(self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False):
    def query(self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False, memories=None):
        """
        Queries the vector database based on the given input query.
        Gets relevant doc based on the query and then passes it to an
@@ -216,6 +239,7 @@ class BaseLlm(JSONSerializable):
        k = {}
        if self.config.online:
            k["web_search_result"] = self.access_search_and_get_results(input_query)
        k["memories"] = memories
        prompt = self.generate_prompt(input_query, contexts, **k)
        logger.info(f"Prompt: {prompt}")
        if dry_run:
@@ -520,6 +520,10 @@ def validate_config(config_data):
                Optional("auto_flush"): int,
            },
        },
        Optional("memory"): {
            "api_key": str,
            Optional("top_k"): int,
        },
    }
)