# t6_mem0/mem0/memory/memgraph_memory.py

import logging
from mem0.memory.utils import format_entities
try:
from langchain_memgraph import Memgraph
except ImportError:
raise ImportError(
"langchain_memgraph is not installed. Please install it using pip install langchain-memgraph"
)
try:
from rank_bm25 import BM25Okapi
except ImportError:
raise ImportError(
"rank_bm25 is not installed. Please install it using pip install rank-bm25"
)
from mem0.graphs.tools import (
DELETE_MEMORY_STRUCT_TOOL_GRAPH,
DELETE_MEMORY_TOOL_GRAPH,
EXTRACT_ENTITIES_STRUCT_TOOL,
EXTRACT_ENTITIES_TOOL,
RELATIONS_STRUCT_TOOL,
RELATIONS_TOOL,
)
from mem0.graphs.utils import EXTRACT_RELATIONS_PROMPT, get_delete_messages
from mem0.utils.factory import EmbedderFactory, LlmFactory
logger = logging.getLogger(__name__)
class MemoryGraph:
def __init__(self, config):
self.config = config
self.graph = Memgraph(
self.config.graph_store.config.url,
self.config.graph_store.config.username,
self.config.graph_store.config.password,
)
self.embedding_model = EmbedderFactory.create(
self.config.embedder.provider,
self.config.embedder.config,
{"enable_embeddings": True},
)
self.llm_provider = "openai_structured"
if self.config.llm.provider:
self.llm_provider = self.config.llm.provider
if self.config.graph_store.llm:
self.llm_provider = self.config.graph_store.llm.provider
self.llm = LlmFactory.create(self.llm_provider, self.config.llm.config)
self.user_id = None
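        # Minimum cosine similarity for a node to count as a match in _search_graph_db.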
self.threshold = 0.7
        # Set up Memgraph:
        # 1. Create the vector index (the Entity label is added to every node so vector search works)
        # 2. Create label and label-property indices for performance
embedding_dims = self.config.embedder.config["embedding_dims"]
create_vector_index_query = f"CREATE VECTOR INDEX memzero ON :Entity(embedding) WITH CONFIG {{'dimension': {embedding_dims}, 'capacity': 1000, 'metric': 'cos'}};"
self.graph.query(create_vector_index_query, params={})
create_label_prop_index_query = "CREATE INDEX ON :Entity(user_id);"
self.graph.query(create_label_prop_index_query, params={})
create_label_index_query = "CREATE INDEX ON :Entity;"
self.graph.query(create_label_index_query, params={})
def add(self, data, filters):
"""
Adds data to the graph.
Args:
data (str): The data to add to the graph.
filters (dict): A dictionary containing filters to be applied during the addition.
"""
entity_type_map = self._retrieve_nodes_from_data(data, filters)
to_be_added = self._establish_nodes_relations_from_data(
data, filters, entity_type_map
)
search_output = self._search_graph_db(
node_list=list(entity_type_map.keys()), filters=filters
)
to_be_deleted = self._get_delete_entities_from_search_output(
search_output, data, filters
)
# TODO: Batch queries with APOC plugin
# TODO: Add more filter support
deleted_entities = self._delete_entities(to_be_deleted, filters["user_id"])
added_entities = self._add_entities(
to_be_added, filters["user_id"], entity_type_map
)
return {"deleted_entities": deleted_entities, "added_entities": added_entities}
def search(self, query, filters, limit=100):
"""
Search for memories and related graph data.
Args:
query (str): Query to search for.
filters (dict): A dictionary containing filters to be applied during the search.
limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100.
        Returns:
            list: A list of dictionaries, each containing:
                - "source": Name of the source node.
                - "relationship": Type of the relationship between the nodes.
                - "destination": Name of the destination node.
"""
entity_type_map = self._retrieve_nodes_from_data(query, filters)
search_output = self._search_graph_db(
node_list=list(entity_type_map.keys()), filters=filters
)
if not search_output:
return []
search_outputs_sequence = [
[item["source"], item["relationship"], item["destination"]]
for item in search_output
]
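        # Rerank with BM25: each (source, relationship, destination) triplet is a
        # three-token document; the query is tokenized on whitespace and the top
        # 5 triplets are kept.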
bm25 = BM25Okapi(search_outputs_sequence)
tokenized_query = query.split(" ")
reranked_results = bm25.get_top_n(tokenized_query, search_outputs_sequence, n=5)
search_results = []
for item in reranked_results:
search_results.append(
{"source": item[0], "relationship": item[1], "destination": item[2]}
)
logger.info(f"Returned {len(search_results)} search results")
return search_results
def delete_all(self, filters):
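        """Delete all nodes and relationships belonging to the given user."""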
cypher = """
MATCH (n {user_id: $user_id})
DETACH DELETE n
"""
params = {"user_id": filters["user_id"]}
self.graph.query(cypher, params=params)
def get_all(self, filters, limit=100):
"""
        Retrieves all nodes and relationships from the graph database for the given user, up to the limit.
Args:
filters (dict): A dictionary containing filters to be applied during the retrieval.
limit (int): The maximum number of nodes and relationships to retrieve. Defaults to 100.
        Returns:
            list: A list of dictionaries, each containing:
                - "source": Name of the source node.
                - "relationship": Type of the relationship between the nodes.
                - "target": Name of the target node.
"""
# return all nodes and relationships
query = """
MATCH (n:Entity {user_id: $user_id})-[r]->(m:Entity {user_id: $user_id})
RETURN n.name AS source, type(r) AS relationship, m.name AS target
LIMIT $limit
"""
results = self.graph.query(
query, params={"user_id": filters["user_id"], "limit": limit}
)
final_results = []
for result in results:
final_results.append(
{
"source": result["source"],
"relationship": result["relationship"],
"target": result["target"],
}
)
logger.info(f"Retrieved {len(final_results)} relationships")
return final_results
def _retrieve_nodes_from_data(self, data, filters):
"""Extracts all the entities mentioned in the query."""
_tools = [EXTRACT_ENTITIES_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [EXTRACT_ENTITIES_STRUCT_TOOL]
search_results = self.llm.generate_response(
messages=[
{
"role": "system",
"content": f"You are a smart assistant who understands entities and their types in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source entity. Extract all the entities from the text. ***DO NOT*** answer the question itself if the given text is a question.",
},
{"role": "user", "content": data},
],
tools=_tools,
)
entity_type_map = {}
try:
for tool_call in search_results["tool_calls"]:
if tool_call["name"] != "extract_entities":
continue
for item in tool_call["arguments"]["entities"]:
entity_type_map[item["entity"]] = item["entity_type"]
except Exception as e:
logger.exception(
f"Error in search tool: {e}, llm_provider={self.llm_provider}, search_results={search_results}"
)
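        # Normalize entity names and types to lowercase snake_case so they match
        # the node names stored in the graph.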
entity_type_map = {
k.lower().replace(" ", "_"): v.lower().replace(" ", "_")
for k, v in entity_type_map.items()
}
logger.debug(
f"Entity type map: {entity_type_map}\n search_results={search_results}"
)
return entity_type_map
def _establish_nodes_relations_from_data(self, data, filters, entity_type_map):
"""Eshtablish relations among the extracted nodes."""
if self.config.graph_store.custom_prompt:
messages = [
{
"role": "system",
"content": EXTRACT_RELATIONS_PROMPT.replace(
"USER_ID", filters["user_id"]
).replace(
"CUSTOM_PROMPT", f"4. {self.config.graph_store.custom_prompt}"
),
},
{"role": "user", "content": data},
]
else:
messages = [
{
"role": "system",
"content": EXTRACT_RELATIONS_PROMPT.replace(
"USER_ID", filters["user_id"]
),
},
{
"role": "user",
"content": f"List of entities: {list(entity_type_map.keys())}. \n\nText: {data}",
},
]
_tools = [RELATIONS_TOOL]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [RELATIONS_STRUCT_TOOL]
extracted_entities = self.llm.generate_response(
messages=messages,
tools=_tools,
)
entities = []
if extracted_entities["tool_calls"]:
entities = extracted_entities["tool_calls"][0]["arguments"]["entities"]
entities = self._remove_spaces_from_entities(entities)
logger.debug(f"Extracted entities: {entities}")
return entities
def _search_graph_db(self, node_list, filters, limit=100):
"""Search similar nodes among and their respective incoming and outgoing relations."""
result_relations = []
for node in node_list:
n_embedding = self.embedding_model.embed(node)
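            # The Cypher below computes pairwise cosine similarity between stored
            # node embeddings for the user's nodes, covering both relationship
            # directions via UNION, and keeps pairs above the threshold.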
cypher_query = """
MATCH (n:Entity {user_id: $user_id})-[r]->(m:Entity)
WHERE n.embedding IS NOT NULL
WITH collect(n) AS nodes1, collect(m) AS nodes2, r
CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2)
YIELD node1, node2, similarity
WITH node1, node2, similarity, r
WHERE similarity >= $threshold
            RETURN node1.name AS source, id(node1) AS source_id, type(r) AS relationship, id(r) AS relation_id, node2.name AS destination, id(node2) AS destination_id, similarity
UNION
MATCH (n:Entity {user_id: $user_id})<-[r]-(m:Entity)
WHERE n.embedding IS NOT NULL
WITH collect(n) AS nodes1, collect(m) AS nodes2, r
CALL node_similarity.cosine_pairwise("embedding", nodes1, nodes2)
YIELD node1, node2, similarity
WITH node1, node2, similarity, r
WHERE similarity >= $threshold
RETURN node2.name AS source, id(node2) AS source_id, type(r) AS relationship, id(r) AS relation_id, node1.name AS destination, id(node1) AS destination_id, similarity
ORDER BY similarity DESC
LIMIT $limit;
"""
params = {
"n_embedding": n_embedding,
"threshold": self.threshold,
"user_id": filters["user_id"],
"limit": limit,
}
ans = self.graph.query(cypher_query, params=params)
result_relations.extend(ans)
return result_relations
def _get_delete_entities_from_search_output(self, search_output, data, filters):
"""Get the entities to be deleted from the search output."""
search_output_string = format_entities(search_output)
system_prompt, user_prompt = get_delete_messages(
search_output_string, data, filters["user_id"]
)
_tools = [DELETE_MEMORY_TOOL_GRAPH]
if self.llm_provider in ["azure_openai_structured", "openai_structured"]:
_tools = [
DELETE_MEMORY_STRUCT_TOOL_GRAPH,
]
memory_updates = self.llm.generate_response(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
tools=_tools,
)
to_be_deleted = []
for item in memory_updates["tool_calls"]:
if item["name"] == "delete_graph_memory":
to_be_deleted.append(item["arguments"])
        # normalize the names in case the LLM output is not in the expected format
to_be_deleted = self._remove_spaces_from_entities(to_be_deleted)
logger.debug(f"Deleted relationships: {to_be_deleted}")
return to_be_deleted
def _delete_entities(self, to_be_deleted, user_id):
"""Delete the entities from the graph."""
results = []
for item in to_be_deleted:
source = item["source"]
destination = item["destination"]
relationship = item["relationship"]
# Delete the specific relationship between nodes
cypher = f"""
MATCH (n:Entity {{name: $source_name, user_id: $user_id}})
-[r:{relationship}]->
(m {{name: $dest_name, user_id: $user_id}})
DELETE r
RETURN
n.name AS source,
m.name AS target,
type(r) AS relationship
"""
params = {
"source_name": source,
"dest_name": destination,
"user_id": user_id,
}
result = self.graph.query(cypher, params=params)
results.append(result)
return results
    # The Entity label is added to every node so that vector search works across all nodes.
def _add_entities(self, to_be_added, user_id, entity_type_map):
"""Add the new entities to the graph. Merge the nodes if they already exist."""
results = []
for item in to_be_added:
# entities
source = item["source"]
destination = item["destination"]
relationship = item["relationship"]
# types
source_type = entity_type_map.get(source, "unknown")
destination_type = entity_type_map.get(destination, "unknown")
# embeddings
source_embedding = self.embedding_model.embed(source)
dest_embedding = self.embedding_model.embed(destination)
            # Search for existing nodes with the closest embeddings; this compares
            # one embedding against all embeddings in the graph, i.e. a vector
            # search with the cosine similarity metric.
source_node_search_result = self._search_source_node(
source_embedding, user_id, threshold=0.9
)
destination_node_search_result = self._search_destination_node(
dest_embedding, user_id, threshold=0.9
)
# TODO: Create a cypher query and common params for all the cases
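            # Four cases follow: only the source node already exists, only the
            # destination already exists, both exist, or neither exists.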
if not destination_node_search_result and source_node_search_result:
cypher = f"""
MATCH (source:Entity)
WHERE id(source) = $source_id
MERGE (destination:{destination_type}:Entity {{name: $destination_name, user_id: $user_id}})
ON CREATE SET
destination.created = timestamp(),
destination.embedding = $destination_embedding,
destination:Entity
MERGE (source)-[r:{relationship}]->(destination)
ON CREATE SET
r.created = timestamp()
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
"""
params = {
"source_id": source_node_search_result[0]["id(source_candidate)"],
"destination_name": destination,
"destination_embedding": dest_embedding,
"user_id": user_id,
}
elif destination_node_search_result and not source_node_search_result:
cypher = f"""
MATCH (destination:Entity)
WHERE id(destination) = $destination_id
MERGE (source:{source_type}:Entity {{name: $source_name, user_id: $user_id}})
ON CREATE SET
source.created = timestamp(),
source.embedding = $source_embedding,
source:Entity
MERGE (source)-[r:{relationship}]->(destination)
ON CREATE SET
r.created = timestamp()
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
"""
params = {
"destination_id": destination_node_search_result[0][
"id(destination_candidate)"
],
"source_name": source,
"source_embedding": source_embedding,
"user_id": user_id,
}
elif source_node_search_result and destination_node_search_result:
cypher = f"""
MATCH (source:Entity)
WHERE id(source) = $source_id
MATCH (destination:Entity)
WHERE id(destination) = $destination_id
MERGE (source)-[r:{relationship}]->(destination)
            ON CREATE SET
                r.created = timestamp()
RETURN source.name AS source, type(r) AS relationship, destination.name AS target
"""
params = {
"source_id": source_node_search_result[0]["id(source_candidate)"],
"destination_id": destination_node_search_result[0][
"id(destination_candidate)"
],
"user_id": user_id,
}
else:
cypher = f"""
MERGE (n:{source_type}:Entity {{name: $source_name, user_id: $user_id}})
ON CREATE SET n.created = timestamp(), n.embedding = $source_embedding, n:Entity
ON MATCH SET n.embedding = $source_embedding
MERGE (m:{destination_type}:Entity {{name: $dest_name, user_id: $user_id}})
ON CREATE SET m.created = timestamp(), m.embedding = $dest_embedding, m:Entity
ON MATCH SET m.embedding = $dest_embedding
MERGE (n)-[rel:{relationship}]->(m)
ON CREATE SET rel.created = timestamp()
RETURN n.name AS source, type(rel) AS relationship, m.name AS target
"""
params = {
"source_name": source,
"dest_name": destination,
"source_embedding": source_embedding,
"dest_embedding": dest_embedding,
"user_id": user_id,
}
result = self.graph.query(cypher, params=params)
results.append(result)
return results
def _remove_spaces_from_entities(self, entity_list):
for item in entity_list:
item["source"] = item["source"].lower().replace(" ", "_")
item["relationship"] = item["relationship"].lower().replace(" ", "_")
item["destination"] = item["destination"].lower().replace(" ", "_")
return entity_list
def _search_source_node(self, source_embedding, user_id, threshold=0.9):
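        # Top-1 lookup against the "memzero" vector index created in __init__,
        # filtered to the user's nodes and the similarity threshold.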
cypher = """
CALL vector_search.search("memzero", 1, $source_embedding)
YIELD distance, node, similarity
WITH node AS source_candidate, similarity
WHERE source_candidate.user_id = $user_id AND similarity >= $threshold
RETURN id(source_candidate);
"""
params = {
"source_embedding": source_embedding,
"user_id": user_id,
"threshold": threshold,
}
result = self.graph.query(cypher, params=params)
return result
def _search_destination_node(self, destination_embedding, user_id, threshold=0.9):
cypher = """
CALL vector_search.search("memzero", 1, $destination_embedding)
YIELD distance, node, similarity
WITH node AS destination_candidate, similarity
            WHERE destination_candidate.user_id = $user_id AND similarity >= $threshold
RETURN id(destination_candidate);
"""
params = {
"destination_embedding": destination_embedding,
"user_id": user_id,
"threshold": threshold,
}
result = self.graph.query(cypher, params=params)
return result
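# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the library): one way MemoryGraph
# could be wired up by hand. The config shape below is an assumption inferred
# from the attributes __init__ reads (graph_store.config.url/username/password,
# embedder.provider, embedder.config["embedding_dims"], llm.provider, ...);
# in mem0 itself this object comes from the library's own config classes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    # Hypothetical config mirroring only the attributes this module accesses.
    config = SimpleNamespace(
        graph_store=SimpleNamespace(
            config=SimpleNamespace(
                url="bolt://localhost:7687",  # assumed local Memgraph instance
                username="memgraph",
                password="memgraph",
            ),
            llm=None,  # fall back to config.llm.provider
            custom_prompt=None,
        ),
        embedder=SimpleNamespace(
            provider="openai",
            config={"embedding_dims": 1536},  # must match the vector index dimension
        ),
        llm=SimpleNamespace(provider="openai_structured", config={}),
    )

    memory = MemoryGraph(config)
    filters = {"user_id": "alice"}
    memory.add("Alice loves hiking in the Alps.", filters)
    print(memory.search("What does alice love?", filters))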