Integrate Mem0 (#1462)
Co-authored-by: Deshraj Yadav <deshraj@gatech.edu>
This commit is contained in:
@@ -5,9 +5,12 @@ from typing import Any, Optional
|
||||
from langchain.schema import BaseMessage as LCBaseMessage
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.config.llm.base import (DEFAULT_PROMPT,
|
||||
DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
|
||||
DOCS_SITE_PROMPT_TEMPLATE)
|
||||
from embedchain.config.llm.base import (
|
||||
DEFAULT_PROMPT,
|
||||
DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
|
||||
DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE,
|
||||
DOCS_SITE_PROMPT_TEMPLATE,
|
||||
)
|
||||
from embedchain.helpers.json_serializable import JSONSerializable
|
||||
from embedchain.memory.base import ChatHistory
|
||||
from embedchain.memory.message import ChatMessage
|
||||
@@ -74,6 +77,16 @@ class BaseLlm(JSONSerializable):
|
||||
"""
|
||||
return "\n".join(self.history)
|
||||
|
||||
def _format_memories(self, memories: list[dict]) -> str:
|
||||
"""Format memories to be used in prompt
|
||||
|
||||
:param memories: Memories to format
|
||||
:type memories: list[dict]
|
||||
:return: Formatted memories
|
||||
:rtype: str
|
||||
"""
|
||||
return "\n".join([memory["text"] for memory in memories])
|
||||
|
||||
def generate_prompt(self, input_query: str, contexts: list[str], **kwargs: dict[str, Any]) -> str:
|
||||
"""
|
||||
Generates a prompt based on the given query and context, ready to be
|
||||
@@ -88,6 +101,7 @@ class BaseLlm(JSONSerializable):
|
||||
"""
|
||||
context_string = " | ".join(contexts)
|
||||
web_search_result = kwargs.get("web_search_result", "")
|
||||
memories = kwargs.get("memories", None)
|
||||
if web_search_result:
|
||||
context_string = self._append_search_and_context(context_string, web_search_result)
|
||||
|
||||
@@ -103,10 +117,19 @@ class BaseLlm(JSONSerializable):
|
||||
not self.config._validate_prompt_history(self.config.prompt)
|
||||
and self.config.prompt.template == DEFAULT_PROMPT
|
||||
):
|
||||
# swap in the template with history
|
||||
prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
|
||||
context=context_string, query=input_query, history=self._format_history()
|
||||
)
|
||||
if memories:
|
||||
# swap in the template with Mem0 memory template
|
||||
prompt = DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE.substitute(
|
||||
context=context_string,
|
||||
query=input_query,
|
||||
history=self._format_history(),
|
||||
memories=self._format_memories(memories),
|
||||
)
|
||||
else:
|
||||
# swap in the template with history
|
||||
prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
|
||||
context=context_string, query=input_query, history=self._format_history()
|
||||
)
|
||||
else:
|
||||
# If we can't swap in the default, we still proceed but tell users that the history is ignored.
|
||||
logger.warning(
|
||||
@@ -180,7 +203,7 @@ class BaseLlm(JSONSerializable):
|
||||
if token_info:
|
||||
logger.info(f"Token Info: {token_info}")
|
||||
|
||||
def query(self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False):
|
||||
def query(self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False, memories=None):
|
||||
"""
|
||||
Queries the vector database based on the given input query.
|
||||
Gets relevant doc based on the query and then passes it to an
|
||||
@@ -216,6 +239,7 @@ class BaseLlm(JSONSerializable):
|
||||
k = {}
|
||||
if self.config.online:
|
||||
k["web_search_result"] = self.access_search_and_get_results(input_query)
|
||||
k["memories"] = memories
|
||||
prompt = self.generate_prompt(input_query, contexts, **k)
|
||||
logger.info(f"Prompt: {prompt}")
|
||||
if dry_run:
|
||||
|
||||
Reference in New Issue
Block a user