[graph_memory]: improve delete/add graph memory (#2073)

This commit is contained in:
Mayank
2025-01-03 22:21:05 +05:30
committed by GitHub
parent 542153ad4f
commit 78a2ef41d7
7 changed files with 439 additions and 225 deletions

View File

@@ -13,18 +13,14 @@ except ImportError:
raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")
from mem0.graphs.tools import (
ADD_MEMORY_STRUCT_TOOL_GRAPH,
ADD_MEMORY_TOOL_GRAPH,
DELETE_MEMORY_STRUCT_TOOL_GRAPH,
DELETE_MEMORY_TOOL_GRAPH,
EXTRACT_ENTITIES_STRUCT_TOOL,
EXTRACT_ENTITIES_TOOL,
NOOP_STRUCT_TOOL,
NOOP_TOOL,
RELATIONS_STRUCT_TOOL,
RELATIONS_TOOL,
UPDATE_MEMORY_STRUCT_TOOL_GRAPH,
UPDATE_MEMORY_TOOL_GRAPH,
)
from mem0.graphs.utils import EXTRACT_RELATIONS_PROMPT, get_update_memory_messages
from mem0.graphs.utils import EXTRACT_RELATIONS_PROMPT, get_delete_messages
from mem0.utils.factory import EmbedderFactory, LlmFactory
logger = logging.getLogger(__name__)
@@ -58,150 +54,17 @@ class MemoryGraph:
data (str): The data to add to the graph.
filters (dict): A dictionary containing filters to be applied during the addition.
"""
# retrieve the search results
search_output, entity_type_map = self._search(data, filters)
# extract relations
extracted_relations = self._extract_relations(data, filters, entity_type_map)
entity_type_map = self._retrieve_nodes_from_data(data, filters)
to_be_added = self._establish_nodes_relations_from_data(data, filters, entity_type_map)
search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
to_be_deleted = self._get_delete_entities_from_search_output(search_output, data, filters)
search_output_string = format_entities(search_output)
extracted_relations_string = format_entities(extracted_relations)
update_memory_prompt = get_update_memory_messages(search_output_string, extracted_relations_string)
_tools = [UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [
UPDATE_MEMORY_STRUCT_TOOL_GRAPH,
ADD_MEMORY_STRUCT_TOOL_GRAPH,
NOOP_STRUCT_TOOL,
]
memory_updates = self.llm.generate_response(
messages=update_memory_prompt,
tools=_tools,
)
to_be_added = []
for item in memory_updates["tool_calls"]:
if item["name"] == "add_graph_memory":
to_be_added.append(item["arguments"])
elif item["name"] == "update_graph_memory":
self._update_relationship(
item["arguments"]["source"],
item["arguments"]["destination"],
item["arguments"]["relationship"],
filters,
)
elif item["name"] == "noop":
continue
returned_entities = []
for item in to_be_added:
source = item["source"].lower().replace(" ", "_")
source_type = item["source_type"].lower().replace(" ", "_")
relation = item["relationship"].lower().replace(" ", "_")
destination = item["destination"].lower().replace(" ", "_")
destination_type = item["destination_type"].lower().replace(" ", "_")
returned_entities.append({"source": source, "relationship": relation, "target": destination})
# Create embeddings
source_embedding = self.embedding_model.embed(source)
dest_embedding = self.embedding_model.embed(destination)
# Updated Cypher query to include node types and embeddings
cypher = f"""
MERGE (n:{source_type} {{name: $source_name, user_id: $user_id}})
ON CREATE SET n.created = timestamp(), n.embedding = $source_embedding
ON MATCH SET n.embedding = $source_embedding
MERGE (m:{destination_type} {{name: $dest_name, user_id: $user_id}})
ON CREATE SET m.created = timestamp(), m.embedding = $dest_embedding
ON MATCH SET m.embedding = $dest_embedding
MERGE (n)-[rel:{relation}]->(m)
ON CREATE SET rel.created = timestamp()
RETURN n, rel, m
"""
params = {
"source_name": source,
"dest_name": destination,
"source_embedding": source_embedding,
"dest_embedding": dest_embedding,
"user_id": filters["user_id"],
}
_ = self.graph.query(cypher, params=params)
logger.info(f"Added {len(to_be_added)} new memories to the graph")
return returned_entities
def _search(self, query, filters, limit=100):
_tools = [EXTRACT_ENTITIES_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [EXTRACT_ENTITIES_STRUCT_TOOL]
search_results = self.llm.generate_response(
messages=[
{
"role": "system",
"content": f"You are a smart assistant who understands entities and their types in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source entity. Extract all the entities from the text. ***DO NOT*** answer the question itself if the given text is a question.",
},
{"role": "user", "content": query},
],
tools=_tools,
)
entity_type_map = {}
try:
for item in search_results["tool_calls"][0]["arguments"]["entities"]:
entity_type_map[item["entity"]] = item["entity_type"]
except Exception as e:
logger.error(f"Error in search tool: {e}")
logger.debug(f"Entity type map: {entity_type_map}")
result_relations = []
for node in list(entity_type_map.keys()):
n_embedding = self.embedding_model.embed(node)
cypher_query = """
MATCH (n)
WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
WITH n,
round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) /
(sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) *
sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity
WHERE similarity >= $threshold
MATCH (n)-[r]->(m)
RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relation, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id, similarity
UNION
MATCH (n)
WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
WITH n,
round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) /
(sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) *
sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity
WHERE similarity >= $threshold
MATCH (m)-[r]->(n)
RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relation, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity
ORDER BY similarity DESC
LIMIT $limit
"""
params = {
"n_embedding": n_embedding,
"threshold": self.threshold,
"user_id": filters["user_id"],
"limit": limit,
}
ans = self.graph.query(cypher_query, params=params)
result_relations.extend(ans)
return result_relations, entity_type_map
#TODO: Batch queries with APOC plugin
#TODO: Add more filter support
deleted_entities = self._delete_entities(to_be_deleted, filters["user_id"])
added_entities = self._add_entities(to_be_added, filters["user_id"], entity_type_map)
return {"deleted_entities": deleted_entities, "added_entities": added_entities}
def search(self, query, filters, limit=100):
"""
@@ -217,13 +80,13 @@ class MemoryGraph:
- "contexts": List of search results from the base data store.
- "entities": List of related graph data based on the query.
"""
search_output, entity_type_map = self._search(query, filters, limit)
entity_type_map = self._retrieve_nodes_from_data(query, filters)
search_output = self._search_graph_db(node_list=list(entity_type_map.keys()), filters=filters)
if not search_output:
return []
search_outputs_sequence = [[item["source"], item["relation"], item["destination"]] for item in search_output]
search_outputs_sequence = [[item["source"], item["relatationship"], item["destination"]] for item in search_output]
bm25 = BM25Okapi(search_outputs_sequence)
tokenized_query = query.split(" ")
@@ -231,7 +94,7 @@ class MemoryGraph:
search_results = []
for item in reranked_results:
search_results.append({"source": item[0], "relationship": item[1], "target": item[2]})
search_results.append({"source": item[0], "relationship": item[1], "destination": item[2]})
logger.info(f"Returned {len(search_results)} search results")
@@ -279,9 +142,37 @@ class MemoryGraph:
logger.info(f"Retrieved {len(final_results)} relationships")
return final_results
def _retrieve_nodes_from_data(self, data, filters):
"""Extracts all the entities mentioned in the query."""
_tools = [EXTRACT_ENTITIES_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [EXTRACT_ENTITIES_STRUCT_TOOL]
search_results = self.llm.generate_response(
messages=[
{
"role": "system",
"content": f"You are a smart assistant who understands entities and their types in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source entity. Extract all the entities from the text. ***DO NOT*** answer the question itself if the given text is a question.",
},
{"role": "user", "content": data},
],
tools=_tools,
)
def _extract_relations(self, data, filters, entity_type_map, limit=100):
entity_type_map = {}
try:
for item in search_results["tool_calls"][0]["arguments"]["entities"]:
entity_type_map[item["entity"]] = item["entity_type"]
except Exception as e:
logger.error(f"Error in search tool: {e}")
entity_type_map = {k.lower().replace(" ", "_"): v.lower().replace(" ", "_") for k, v in entity_type_map.items()}
logger.debug(f"Entity type map: {entity_type_map}")
return entity_type_map
def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
"""Eshtablish relations among the extracted nodes."""
if self.config.graph_store.custom_prompt:
messages = [
{
@@ -315,57 +206,292 @@ class MemoryGraph:
else:
extracted_entities = []
extracted_entities = self._remove_spaces_from_entities(extracted_entities)
logger.debug(f"Extracted entities: {extracted_entities}")
return extracted_entities
def _search_graph_db(self, node_list, filters, limit=100):
"""Search similar nodes among and their respective incoming and outgoing relations."""
result_relations = []
def _update_relationship(self, source, target, relationship, filters):
"""
Update or create a relationship between two nodes in the graph.
for node in node_list:
n_embedding = self.embedding_model.embed(node)
Args:
source (str): The name of the source node.
target (str): The name of the target node.
relationship (str): The type of the relationship.
filters (dict): A dictionary containing filters to be applied during the update.
cypher_query = """
MATCH (n)
WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
WITH n,
round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) /
(sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) *
sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity
WHERE similarity >= $threshold
MATCH (n)-[r]->(m)
RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relatationship, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id, similarity
UNION
MATCH (n)
WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
WITH n,
round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) /
(sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) *
sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity
WHERE similarity >= $threshold
MATCH (m)-[r]->(n)
RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relatationship, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity
ORDER BY similarity DESC
LIMIT $limit
"""
params = {
"n_embedding": n_embedding,
"threshold": self.threshold,
"user_id": filters["user_id"],
"limit": limit,
}
ans = self.graph.query(cypher_query, params=params)
result_relations.extend(ans)
Raises:
Exception: If the operation fails.
"""
logger.info(f"Updating relationship: {source} -{relationship}-> {target}")
return result_relations
def _get_delete_entities_from_search_output(self, search_output, data, filters):
"""Get the entities to be deleted from the search output."""
search_output_string = format_entities(search_output)
system_prompt, user_prompt = get_delete_messages(search_output_string, data, filters["user_id"])
relationship = relationship.lower().replace(" ", "_")
_tools = [DELETE_MEMORY_TOOL_GRAPH]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [
DELETE_MEMORY_STRUCT_TOOL_GRAPH,
]
# Check if nodes exist and create them if they don't
check_and_create_query = """
MERGE (n1 {name: $source, user_id: $user_id})
MERGE (n2 {name: $target, user_id: $user_id})
"""
self.graph.query(
check_and_create_query,
params={"source": source, "target": target, "user_id": filters["user_id"]},
memory_updates = self.llm.generate_response(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
tools=_tools,
)
to_be_deleted = []
for item in memory_updates["tool_calls"]:
if item["name"] == "delete_graph_memory":
to_be_deleted.append(item["arguments"])
#in case if it is not in the correct format
to_be_deleted = self._remove_spaces_from_entities(to_be_deleted)
logger.debug(f"Deleted relationships: {to_be_deleted}")
return to_be_deleted
def _delete_entities(self, to_be_deleted, user_id):
"""Delete the entities from the graph."""
results = []
for item in to_be_deleted:
source = item["source"]
destination = item["destination"]
relatationship = item["relationship"]
# Delete the specific relationship between nodes
cypher = f"""
MATCH (n {{name: $source_name, user_id: $user_id}})
-[r:{relatationship}]->
(m {{name: $dest_name, user_id: $user_id}})
DELETE r
RETURN
n.name AS source,
m.name AS target,
type(r) AS relationship
"""
params = {
"source_name": source,
"dest_name": destination,
"user_id": user_id,
}
result = self.graph.query(cypher, params=params)
results.append(result)
return results
def _add_entities(self, to_be_added, user_id, entity_type_map):
"""Add the new entities to the graph. Merge the nodes if they already exist."""
results = []
for item in to_be_added:
#entities
source = item["source"]
destination = item["destination"]
relationship = item["relationship"]
# Delete any existing relationship between the nodes
delete_query = """
MATCH (n1 {name: $source, user_id: $user_id})-[r]->(n2 {name: $target, user_id: $user_id})
DELETE r
"""
self.graph.query(
delete_query,
params={"source": source, "target": target, "user_id": filters["user_id"]},
)
#types
source_type = entity_type_map.get(source, "unknown")
destination_type = entity_type_map.get(destination, "unknown")
#embeddings
source_embedding = self.embedding_model.embed(source)
dest_embedding = self.embedding_model.embed(destination)
#search for the nodes with the closest embeddings
source_node_search_result = self._search_source_node(source_embedding, user_id, threshold=0.9)
destination_node_search_result = self._search_destination_node(dest_embedding, user_id, threshold=0.9)
# Create the new relationship
create_query = f"""
MATCH (n1 {{name: $source, user_id: $user_id}}), (n2 {{name: $target, user_id: $user_id}})
CREATE (n1)-[r:{relationship}]->(n2)
RETURN n1, r, n2
"""
result = self.graph.query(
create_query,
params={"source": source, "target": target, "user_id": filters["user_id"]},
)
#TODO: Create a cypher query and common params for all the cases
if not destination_node_search_result and source_node_search_result:
cypher = f"""
MATCH (source)
WHERE elementId(source) = $source_id
MERGE (destination:{destination_type} {{name: $destination_name, user_id: $user_id}})
ON CREATE SET
destination.created = timestamp(),
destination.embedding = $destination_embedding
MERGE (source)-[r:{relationship}]->(destination)
ON CREATE SET
r.created = timestamp()
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
"""
if not result:
raise Exception(f"Failed to update or create relationship between {source} and {target}")
params = {
"source_id": source_node_search_result[0]['elementId(source_candidate)'],
"destination_name": destination,
"relationship": relationship,
"destination_type": destination_type,
"destination_embedding": dest_embedding,
"user_id": user_id,
}
resp = self.graph.query(cypher, params=params)
results.append(resp)
elif destination_node_search_result and not source_node_search_result:
cypher = f"""
MATCH (destination)
WHERE elementId(destination) = $destination_id
MERGE (source:{source_type} {{name: $source_name, user_id: $user_id}})
ON CREATE SET
source.created = timestamp(),
source.embedding = $source_embedding
MERGE (source)-[r:{relationship}]->(destination)
ON CREATE SET
r.created = timestamp()
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
"""
params = {
"destination_id": destination_node_search_result[0]['elementId(destination_candidate)'],
"source_name": source,
"relationship": relationship,
"source_type": source_type,
"source_embedding": source_embedding,
"user_id": user_id,
}
resp = self.graph.query(cypher, params=params)
results.append(resp)
elif source_node_search_result and destination_node_search_result:
cypher = f"""
MATCH (source)
WHERE elementId(source) = $source_id
MATCH (destination)
WHERE elementId(destination) = $destination_id
MERGE (source)-[r:{relationship}]->(destination)
ON CREATE SET
r.created_at = timestamp(),
r.updated_at = timestamp()
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
"""
params = {
"source_id": source_node_search_result[0]['elementId(source_candidate)'],
"destination_id": destination_node_search_result[0]['elementId(destination_candidate)'],
"user_id": user_id,
"relationship": relationship,
}
resp = self.graph.query(cypher, params=params)
results.append(resp)
elif not source_node_search_result and not destination_node_search_result:
cypher = f"""
MERGE (n:{source_type} {{name: $source_name, user_id: $user_id}})
ON CREATE SET n.created = timestamp(), n.embedding = $source_embedding
ON MATCH SET n.embedding = $source_embedding
MERGE (m:{destination_type} {{name: $dest_name, user_id: $user_id}})
ON CREATE SET m.created = timestamp(), m.embedding = $dest_embedding
ON MATCH SET m.embedding = $dest_embedding
MERGE (n)-[rel:{relationship}]->(m)
ON CREATE SET rel.created = timestamp()
RETURN n.name AS source, type(rel) AS relationship, m.name AS target
"""
params = {
"source_name": source,
"source_type": source_type,
"dest_name": destination,
"destination_type": destination_type,
"source_embedding": source_embedding,
"dest_embedding": dest_embedding,
"user_id": user_id,
}
resp = self.graph.query(cypher, params=params)
results.append(resp)
return results
def _remove_spaces_from_entities(self, entity_list):
for item in entity_list:
item["source"] = item["source"].lower().replace(" ", "_")
item["relationship"] = item["relationship"].lower().replace(" ", "_")
item["destination"] = item["destination"].lower().replace(" ", "_")
return entity_list
def _search_source_node(self, source_embedding, user_id, threshold=0.9):
cypher = f"""
MATCH (source_candidate)
WHERE source_candidate.embedding IS NOT NULL
AND source_candidate.user_id = $user_id
WITH source_candidate,
round(
reduce(dot = 0.0, i IN range(0, size(source_candidate.embedding)-1) |
dot + source_candidate.embedding[i] * $source_embedding[i]) /
(sqrt(reduce(l2 = 0.0, i IN range(0, size(source_candidate.embedding)-1) |
l2 + source_candidate.embedding[i] * source_candidate.embedding[i])) *
sqrt(reduce(l2 = 0.0, i IN range(0, size($source_embedding)-1) |
l2 + $source_embedding[i] * $source_embedding[i])))
, 4) AS source_similarity
WHERE source_similarity >= $threshold
WITH source_candidate, source_similarity
ORDER BY source_similarity DESC
LIMIT 1
RETURN elementId(source_candidate)
"""
params = {
"source_embedding": source_embedding,
"user_id": user_id,
"threshold": threshold,
}
result = self.graph.query(cypher, params=params)
return result
def _search_destination_node(self, destination_embedding, user_id, threshold=0.9):
cypher = f"""
MATCH (destination_candidate)
WHERE destination_candidate.embedding IS NOT NULL
AND destination_candidate.user_id = $user_id
WITH destination_candidate,
round(
reduce(dot = 0.0, i IN range(0, size(destination_candidate.embedding)-1) |
dot + destination_candidate.embedding[i] * $destination_embedding[i]) /
(sqrt(reduce(l2 = 0.0, i IN range(0, size(destination_candidate.embedding)-1) |
l2 + destination_candidate.embedding[i] * destination_candidate.embedding[i])) *
sqrt(reduce(l2 = 0.0, i IN range(0, size($destination_embedding)-1) |
l2 + $destination_embedding[i] * $destination_embedding[i])))
, 4) AS destination_similarity
WHERE destination_similarity >= $threshold
WITH destination_candidate, destination_similarity
ORDER BY destination_similarity DESC
LIMIT 1
RETURN elementId(destination_candidate)
"""
params = {
"destination_embedding": destination_embedding,
"user_id": user_id,
"threshold": threshold,
}
result = self.graph.query(cypher, params=params)
return result