Improvements to Graph Memory (#1779)
This commit is contained in:
@@ -1,51 +1,29 @@
|
||||
from langchain_community.graphs import Neo4jGraph
|
||||
from pydantic import BaseModel, Field
|
||||
import json
|
||||
from openai import OpenAI
|
||||
|
||||
from mem0.embeddings.openai import OpenAIEmbedding
|
||||
from mem0.llms.openai import OpenAILLM
|
||||
from rank_bm25 import BM25Okapi
|
||||
from mem0.utils.factory import LlmFactory, EmbedderFactory
|
||||
from mem0.graphs.utils import get_update_memory_messages, EXTRACT_ENTITIES_PROMPT
|
||||
from mem0.graphs.tools import UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL
|
||||
from mem0.graphs.tools import UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL, ADD_MESSAGE_TOOL, SEARCH_TOOL
|
||||
|
||||
client = OpenAI()
|
||||
|
||||
class GraphData(BaseModel):
|
||||
source: str = Field(..., description="The source node of the relationship")
|
||||
target: str = Field(..., description="The target node of the relationship")
|
||||
relationship: str = Field(..., description="The type of the relationship")
|
||||
|
||||
class Entities(BaseModel):
|
||||
source_node: str
|
||||
source_type: str
|
||||
relation: str
|
||||
destination_node: str
|
||||
destination_type: str
|
||||
|
||||
class ADDQuery(BaseModel):
|
||||
entities: list[Entities]
|
||||
|
||||
class SEARCHQuery(BaseModel):
|
||||
nodes: list[str]
|
||||
relations: list[str]
|
||||
|
||||
def get_embedding(text):
|
||||
response = client.embeddings.create(
|
||||
model="text-embedding-3-small",
|
||||
input=text
|
||||
)
|
||||
return response.data[0].embedding
|
||||
|
||||
class MemoryGraph:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.graph = Neo4jGraph(self.config.graph_store.config.url, self.config.graph_store.config.username, self.config.graph_store.config.password)
|
||||
self.embedding_model = EmbedderFactory.create(
|
||||
self.config.embedder.provider, self.config.embedder.config
|
||||
)
|
||||
|
||||
self.llm = OpenAILLM()
|
||||
self.embedding_model = OpenAIEmbedding()
|
||||
if self.config.llm.provider:
|
||||
llm_provider = self.config.llm.provider
|
||||
if self.config.graph_store.llm:
|
||||
llm_provider = self.config.graph_store.llm.provider
|
||||
else:
|
||||
llm_provider = "openai_structured"
|
||||
|
||||
self.llm = LlmFactory.create(llm_provider, self.config.llm.config)
|
||||
self.user_id = None
|
||||
self.threshold = 0.7
|
||||
self.model_name = "gpt-4o-2024-08-06"
|
||||
|
||||
def add(self, data):
|
||||
"""
|
||||
@@ -61,41 +39,45 @@ class MemoryGraph:
|
||||
|
||||
# retrieve the search results
|
||||
search_output = self._search(data)
|
||||
|
||||
extracted_entities = client.beta.chat.completions.parse(
|
||||
model=self.model_name,
|
||||
|
||||
if self.config.graph_store.custom_prompt:
|
||||
messages=[
|
||||
{"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id).replace("CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}")},
|
||||
{"role": "user", "content": data},
|
||||
]
|
||||
else:
|
||||
messages=[
|
||||
{"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id)},
|
||||
{"role": "user", "content": data},
|
||||
],
|
||||
response_format=ADDQuery,
|
||||
temperature=0,
|
||||
).choices[0].message.parsed.entities
|
||||
]
|
||||
|
||||
extracted_entities = self.llm.generate_response(
|
||||
messages=messages,
|
||||
tools = [ADD_MESSAGE_TOOL],
|
||||
)
|
||||
|
||||
if extracted_entities['tool_calls']:
|
||||
extracted_entities = extracted_entities['tool_calls'][0]['arguments']['entities']
|
||||
else:
|
||||
extracted_entities = []
|
||||
|
||||
update_memory_prompt = get_update_memory_messages(search_output, extracted_entities)
|
||||
tools = [UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL]
|
||||
|
||||
memory_updates = client.beta.chat.completions.parse(
|
||||
model=self.model_name,
|
||||
memory_updates = self.llm.generate_response(
|
||||
messages=update_memory_prompt,
|
||||
tools=tools,
|
||||
temperature=0,
|
||||
).choices[0].message.tool_calls
|
||||
tools=[UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL],
|
||||
)
|
||||
|
||||
to_be_added = []
|
||||
for item in memory_updates:
|
||||
function_name = item.function.name
|
||||
arguments = json.loads(item.function.arguments)
|
||||
if function_name == "add_graph_memory":
|
||||
to_be_added.append(arguments)
|
||||
elif function_name == "update_graph_memory":
|
||||
self._update_relationship(arguments['source'], arguments['destination'], arguments['relationship'])
|
||||
elif function_name == "update_name":
|
||||
self._update_name(arguments['name'])
|
||||
elif function_name == "noop":
|
||||
|
||||
for item in memory_updates['tool_calls']:
|
||||
if item['name'] == "add_graph_memory":
|
||||
to_be_added.append(item['arguments'])
|
||||
elif item['name'] == "update_graph_memory":
|
||||
self._update_relationship(item['arguments']['source'], item['arguments']['destination'], item['arguments']['relationship'])
|
||||
elif item['name'] == "noop":
|
||||
continue
|
||||
|
||||
new_relationships_response = []
|
||||
for item in to_be_added:
|
||||
source = item['source'].lower().replace(" ", "_")
|
||||
source_type = item['source_type'].lower().replace(" ", "_")
|
||||
@@ -104,8 +86,8 @@ class MemoryGraph:
|
||||
destination_type = item['destination_type'].lower().replace(" ", "_")
|
||||
|
||||
# Create embeddings
|
||||
source_embedding = get_embedding(source)
|
||||
dest_embedding = get_embedding(destination)
|
||||
source_embedding = self.embedding_model.embed(source)
|
||||
dest_embedding = self.embedding_model.embed(destination)
|
||||
|
||||
# Updated Cypher query to include node types and embeddings
|
||||
cypher = f"""
|
||||
@@ -127,22 +109,28 @@ class MemoryGraph:
|
||||
"dest_embedding": dest_embedding
|
||||
}
|
||||
|
||||
result = self.graph.query(cypher, params=params)
|
||||
|
||||
_ = self.graph.query(cypher, params=params)
|
||||
|
||||
|
||||
def _search(self, query):
|
||||
search_results = client.beta.chat.completions.parse(
|
||||
model="gpt-4o-2024-08-06",
|
||||
search_results = self.llm.generate_response(
|
||||
messages=[
|
||||
{"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {self.user_id} as the source node. Extract the entities."},
|
||||
{"role": "user", "content": query},
|
||||
],
|
||||
response_format=SEARCHQuery,
|
||||
).choices[0].message
|
||||
|
||||
node_list = search_results.parsed.nodes
|
||||
relation_list = search_results.parsed.relations
|
||||
tools = [SEARCH_TOOL]
|
||||
)
|
||||
|
||||
node_list = []
|
||||
relation_list = []
|
||||
|
||||
for item in search_results['tool_calls']:
|
||||
if item['name'] == "search":
|
||||
node_list.extend(item['arguments']['nodes'])
|
||||
relation_list.extend(item['arguments']['relations'])
|
||||
|
||||
node_list = list(set(node_list))
|
||||
relation_list = list(set(relation_list))
|
||||
|
||||
node_list = [node.lower().replace(" ", "_") for node in node_list]
|
||||
relation_list = [relation.lower().replace(" ", "_") for relation in relation_list]
|
||||
@@ -150,7 +138,7 @@ class MemoryGraph:
|
||||
result_relations = []
|
||||
|
||||
for node in node_list:
|
||||
n_embedding = get_embedding(node)
|
||||
n_embedding = self.embedding_model.embed(node)
|
||||
|
||||
cypher_query = """
|
||||
MATCH (n)
|
||||
@@ -195,12 +183,22 @@ class MemoryGraph:
|
||||
"""
|
||||
|
||||
search_output = self._search(query)
|
||||
|
||||
if not search_output:
|
||||
return []
|
||||
|
||||
search_outputs_sequence = [[item["source"], item["relation"], item["destination"]] for item in search_output]
|
||||
bm25 = BM25Okapi(search_outputs_sequence)
|
||||
|
||||
tokenized_query = query.split(" ")
|
||||
reranked_results = bm25.get_top_n(tokenized_query, search_outputs_sequence, n=5)
|
||||
|
||||
search_results = []
|
||||
for item in search_output:
|
||||
for item in reranked_results:
|
||||
search_results.append({
|
||||
"source": item['source'],
|
||||
"relation": item['relation'],
|
||||
"destination": item['destination']
|
||||
"source": item[0],
|
||||
"relation": item[1],
|
||||
"destination": item[2]
|
||||
})
|
||||
|
||||
return search_results
|
||||
|
||||
Reference in New Issue
Block a user