[improvement]: Graph memory support for non-structured models. (#1823)

This commit is contained in:
Mayank
2024-09-08 01:56:43 +05:30
committed by GitHub
parent a972d2fb07
commit 51c4f2aae8
5 changed files with 265 additions and 14 deletions

View File

@@ -9,6 +9,11 @@ from mem0.graphs.tools import (
NOOP_TOOL,
SEARCH_TOOL,
UPDATE_MEMORY_TOOL_GRAPH,
UPDATE_MEMORY_STRUCT_TOOL_GRAPH,
ADD_MEMORY_STRUCT_TOOL_GRAPH,
NOOP_STRUCT_TOOL,
ADD_MESSAGE_STRUCT_TOOL,
SEARCH_STRUCT_TOOL
)
from mem0.graphs.utils import EXTRACT_ENTITIES_PROMPT, get_update_memory_messages
from mem0.utils.factory import EmbedderFactory, LlmFactory
@@ -23,14 +28,13 @@ class MemoryGraph:
self.config.embedder.provider, self.config.embedder.config
)
self.llm_provider = "openai_structured"
if self.config.llm.provider:
llm_provider = self.config.llm.provider
self.llm_provider = self.config.llm.provider
if self.config.graph_store.llm:
llm_provider = self.config.graph_store.llm.provider
else:
llm_provider = "openai_structured"
self.llm_provider = self.config.graph_store.llm.provider
self.llm = LlmFactory.create(llm_provider, self.config.llm.config)
self.llm = LlmFactory.create(self.llm_provider, self.config.llm.config)
self.user_id = None
self.threshold = 0.7
@@ -60,9 +64,13 @@ class MemoryGraph:
{"role": "user", "content": data},
]
_tools = [ADD_MESSAGE_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [ADD_MESSAGE_STRUCT_TOOL]
extracted_entities = self.llm.generate_response(
messages=messages,
tools = [ADD_MESSAGE_TOOL],
tools = _tools,
)
if extracted_entities['tool_calls']:
@@ -74,9 +82,13 @@ class MemoryGraph:
update_memory_prompt = get_update_memory_messages(search_output, extracted_entities)
_tools=[UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL]
if self.llm_provider in ["azure_openai_structured","openai_structured"]:
_tools = [UPDATE_MEMORY_STRUCT_TOOL_GRAPH, ADD_MEMORY_STRUCT_TOOL_GRAPH, NOOP_STRUCT_TOOL]
memory_updates = self.llm.generate_response(
messages=update_memory_prompt,
tools=[UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL],
tools=_tools,
)
to_be_added = []
@@ -127,12 +139,15 @@ class MemoryGraph:
def _search(self, query, filters):
_tools = [SEARCH_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [SEARCH_STRUCT_TOOL]
search_results = self.llm.generate_response(
messages=[
{"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities."},
{"role": "user", "content": query},
],
tools = [SEARCH_TOOL]
tools = _tools
)
node_list = []