Improvements to Graph Memory (#1779)

This commit is contained in:
Prateek Chhikara
2024-08-29 22:17:08 -07:00
committed by GitHub
parent 28bc4fe05b
commit 822a8acedb
10 changed files with 246 additions and 79 deletions

View File

@@ -1,51 +1,29 @@
from langchain_community.graphs import Neo4jGraph
from pydantic import BaseModel, Field
import json
from openai import OpenAI
from mem0.embeddings.openai import OpenAIEmbedding
from mem0.llms.openai import OpenAILLM
from rank_bm25 import BM25Okapi
from mem0.utils.factory import LlmFactory, EmbedderFactory
from mem0.graphs.utils import get_update_memory_messages, EXTRACT_ENTITIES_PROMPT
from mem0.graphs.tools import UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL
from mem0.graphs.tools import UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL, ADD_MESSAGE_TOOL, SEARCH_TOOL
client = OpenAI()
class GraphData(BaseModel):
source: str = Field(..., description="The source node of the relationship")
target: str = Field(..., description="The target node of the relationship")
relationship: str = Field(..., description="The type of the relationship")
class Entities(BaseModel):
source_node: str
source_type: str
relation: str
destination_node: str
destination_type: str
class ADDQuery(BaseModel):
entities: list[Entities]
class SEARCHQuery(BaseModel):
nodes: list[str]
relations: list[str]
def get_embedding(text):
response = client.embeddings.create(
model="text-embedding-3-small",
input=text
)
return response.data[0].embedding
class MemoryGraph:
def __init__(self, config):
self.config = config
self.graph = Neo4jGraph(self.config.graph_store.config.url, self.config.graph_store.config.username, self.config.graph_store.config.password)
self.embedding_model = EmbedderFactory.create(
self.config.embedder.provider, self.config.embedder.config
)
self.llm = OpenAILLM()
self.embedding_model = OpenAIEmbedding()
if self.config.llm.provider:
llm_provider = self.config.llm.provider
if self.config.graph_store.llm:
llm_provider = self.config.graph_store.llm.provider
else:
llm_provider = "openai_structured"
self.llm = LlmFactory.create(llm_provider, self.config.llm.config)
self.user_id = None
self.threshold = 0.7
self.model_name = "gpt-4o-2024-08-06"
def add(self, data):
"""
@@ -61,41 +39,45 @@ class MemoryGraph:
# retrieve the search results
search_output = self._search(data)
extracted_entities = client.beta.chat.completions.parse(
model=self.model_name,
if self.config.graph_store.custom_prompt:
messages=[
{"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id).replace("CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}")},
{"role": "user", "content": data},
]
else:
messages=[
{"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id)},
{"role": "user", "content": data},
],
response_format=ADDQuery,
temperature=0,
).choices[0].message.parsed.entities
]
extracted_entities = self.llm.generate_response(
messages=messages,
tools = [ADD_MESSAGE_TOOL],
)
if extracted_entities['tool_calls']:
extracted_entities = extracted_entities['tool_calls'][0]['arguments']['entities']
else:
extracted_entities = []
update_memory_prompt = get_update_memory_messages(search_output, extracted_entities)
tools = [UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL]
memory_updates = client.beta.chat.completions.parse(
model=self.model_name,
memory_updates = self.llm.generate_response(
messages=update_memory_prompt,
tools=tools,
temperature=0,
).choices[0].message.tool_calls
tools=[UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL],
)
to_be_added = []
for item in memory_updates:
function_name = item.function.name
arguments = json.loads(item.function.arguments)
if function_name == "add_graph_memory":
to_be_added.append(arguments)
elif function_name == "update_graph_memory":
self._update_relationship(arguments['source'], arguments['destination'], arguments['relationship'])
elif function_name == "update_name":
self._update_name(arguments['name'])
elif function_name == "noop":
for item in memory_updates['tool_calls']:
if item['name'] == "add_graph_memory":
to_be_added.append(item['arguments'])
elif item['name'] == "update_graph_memory":
self._update_relationship(item['arguments']['source'], item['arguments']['destination'], item['arguments']['relationship'])
elif item['name'] == "noop":
continue
new_relationships_response = []
for item in to_be_added:
source = item['source'].lower().replace(" ", "_")
source_type = item['source_type'].lower().replace(" ", "_")
@@ -104,8 +86,8 @@ class MemoryGraph:
destination_type = item['destination_type'].lower().replace(" ", "_")
# Create embeddings
source_embedding = get_embedding(source)
dest_embedding = get_embedding(destination)
source_embedding = self.embedding_model.embed(source)
dest_embedding = self.embedding_model.embed(destination)
# Updated Cypher query to include node types and embeddings
cypher = f"""
@@ -127,22 +109,28 @@ class MemoryGraph:
"dest_embedding": dest_embedding
}
result = self.graph.query(cypher, params=params)
_ = self.graph.query(cypher, params=params)
def _search(self, query):
search_results = client.beta.chat.completions.parse(
model="gpt-4o-2024-08-06",
search_results = self.llm.generate_response(
messages=[
{"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {self.user_id} as the source node. Extract the entities."},
{"role": "user", "content": query},
],
response_format=SEARCHQuery,
).choices[0].message
node_list = search_results.parsed.nodes
relation_list = search_results.parsed.relations
tools = [SEARCH_TOOL]
)
node_list = []
relation_list = []
for item in search_results['tool_calls']:
if item['name'] == "search":
node_list.extend(item['arguments']['nodes'])
relation_list.extend(item['arguments']['relations'])
node_list = list(set(node_list))
relation_list = list(set(relation_list))
node_list = [node.lower().replace(" ", "_") for node in node_list]
relation_list = [relation.lower().replace(" ", "_") for relation in relation_list]
@@ -150,7 +138,7 @@ class MemoryGraph:
result_relations = []
for node in node_list:
n_embedding = get_embedding(node)
n_embedding = self.embedding_model.embed(node)
cypher_query = """
MATCH (n)
@@ -195,12 +183,22 @@ class MemoryGraph:
"""
search_output = self._search(query)
if not search_output:
return []
search_outputs_sequence = [[item["source"], item["relation"], item["destination"]] for item in search_output]
bm25 = BM25Okapi(search_outputs_sequence)
tokenized_query = query.split(" ")
reranked_results = bm25.get_top_n(tokenized_query, search_outputs_sequence, n=5)
search_results = []
for item in search_output:
for item in reranked_results:
search_results.append({
"source": item['source'],
"relation": item['relation'],
"destination": item['destination']
"source": item[0],
"relation": item[1],
"destination": item[2]
})
return search_results