Integrate Mem0 (#1462)

Co-authored-by: Deshraj Yadav <deshraj@gatech.edu>
Dev Khant authored on 2024-07-07 00:57:01 +05:30, committed by GitHub
parent bd654e7aac
commit bbe56107fb
11 changed files with 195 additions and 34 deletions

embedchain/llm/base.py

@@ -5,9 +5,12 @@ from typing import Any, Optional
 from langchain.schema import BaseMessage as LCBaseMessage
 
 from embedchain.config import BaseLlmConfig
-from embedchain.config.llm.base import (DEFAULT_PROMPT,
-                                        DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
-                                        DOCS_SITE_PROMPT_TEMPLATE)
+from embedchain.config.llm.base import (
+    DEFAULT_PROMPT,
+    DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
+    DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE,
+    DOCS_SITE_PROMPT_TEMPLATE,
+)
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.memory.base import ChatHistory
 from embedchain.memory.message import ChatMessage
@@ -74,6 +77,16 @@ class BaseLlm(JSONSerializable):
         """
         return "\n".join(self.history)
 
+    def _format_memories(self, memories: list[dict]) -> str:
+        """Format memories to be used in prompt
+
+        :param memories: Memories to format
+        :type memories: list[dict]
+        :return: Formatted memories
+        :rtype: str
+        """
+        return "\n".join([memory["text"] for memory in memories])
+
     def generate_prompt(self, input_query: str, contexts: list[str], **kwargs: dict[str, Any]) -> str:
         """
         Generates a prompt based on the given query and context, ready to be
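The new _format_memories helper flattens whatever memory payload Mem0 hands back into newline-separated text for the prompt. A minimal sketch of its behavior, using a hypothetical payload; the diff only assumes each memory dict carries a "text" key:

    # Hypothetical Mem0-style payload; only the "text" field is read.
    memories = [
        {"text": "User prefers metric units."},
        {"text": "User is based in Berlin."},
    ]

    # Mirrors the helper: join the "text" field of each memory with newlines.
    print("\n".join([memory["text"] for memory in memories]))
    # User prefers metric units.
    # User is based in Berlin.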
@@ -88,6 +101,7 @@ class BaseLlm(JSONSerializable):
         """
         context_string = " | ".join(contexts)
         web_search_result = kwargs.get("web_search_result", "")
+        memories = kwargs.get("memories", None)
 
         if web_search_result:
             context_string = self._append_search_and_context(context_string, web_search_result)
@@ -103,10 +117,19 @@ class BaseLlm(JSONSerializable):
                 not self.config._validate_prompt_history(self.config.prompt)
                 and self.config.prompt.template == DEFAULT_PROMPT
             ):
-                # swap in the template with history
-                prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
-                    context=context_string, query=input_query, history=self._format_history()
-                )
+                if memories:
+                    # swap in the template with Mem0 memory template
+                    prompt = DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE.substitute(
+                        context=context_string,
+                        query=input_query,
+                        history=self._format_history(),
+                        memories=self._format_memories(memories),
+                    )
+                else:
+                    # swap in the template with history
+                    prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
+                        context=context_string, query=input_query, history=self._format_history()
+                    )
             else:
                 # If we can't swap in the default, we still proceed but tell users that the history is ignored.
                 logger.warning(
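DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE itself lives in embedchain/config/llm/base.py and is not shown in this diff; all the hunk requires is a string.Template exposing context, query, history, and memories fields. A hedged stand-in to illustrate the substitution, not the real template:

    from string import Template

    # Hypothetical stand-in for DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE; it
    # only needs to accept the same four substitution fields used above.
    MEM0_TEMPLATE = Template(
        "Answer the query using the context, conversation history, and memories.\n"
        "Context: $context\nHistory: $history\nMemories: $memories\nQuery: $query"
    )

    prompt = MEM0_TEMPLATE.substitute(
        context="Paris is the capital of France.",
        query="Plan my trip.",
        history="user: hi\nassistant: hello",
        memories="User prefers train travel.",
    )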
@@ -180,7 +203,7 @@ class BaseLlm(JSONSerializable):
         if token_info:
             logger.info(f"Token Info: {token_info}")
 
-    def query(self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False):
+    def query(self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False, memories=None):
        """
        Queries the vector database based on the given input query.
        Gets relevant doc based on the query and then passes it to an
@@ -216,6 +239,7 @@ class BaseLlm(JSONSerializable):
             k = {}
             if self.config.online:
                 k["web_search_result"] = self.access_search_and_get_results(input_query)
+            k["memories"] = memories
             prompt = self.generate_prompt(input_query, contexts, **k)
             logger.info(f"Prompt: {prompt}")
             if dry_run:
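Taken together with the new query() parameter above, callers can now thread Mem0 memories down into prompt generation. A minimal usage sketch, assuming OpenAILlm (a concrete BaseLlm subclass shipped by embedchain) and a hand-built memory payload; everything below is illustrative rather than taken from the diff:

    from embedchain.config import BaseLlmConfig
    from embedchain.llm.openai import OpenAILlm  # assumed concrete BaseLlm subclass

    llm = OpenAILlm(config=BaseLlmConfig())

    # dry_run=True makes query() return the rendered prompt without calling the
    # model. Note that per the template-selection hunk above, the Mem0 template
    # is only swapped in when chat history is also present; with an empty
    # history, generate_prompt() ignores the memories.
    prompt = llm.query(
        input_query="Where should I travel next?",
        contexts=["Paris is the capital of France."],
        memories=[{"text": "User prefers train travel."}],  # shape _format_memories expects
        dry_run=True,
    )
    print(prompt)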