[Mem0] Integrate Graph Memory (#1718)

Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
This commit is contained in:
Prateek Chhikara
2024-08-20 16:37:38 -07:00
committed by GitHub
parent 9b7a882d57
commit c64e0824da
22 changed files with 867 additions and 26 deletions

View File

@@ -4,9 +4,8 @@ import uuid
import pytz
from datetime import datetime
from typing import Any, Dict
import warnings
from pydantic import ValidationError
from mem0.llms.utils.tools import (
ADD_MEMORY_TOOL,
DELETE_MEMORY_TOOL,
@@ -37,7 +36,15 @@ class Memory(MemoryBase):
self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config)
self.db = SQLiteManager(self.config.history_db_path)
self.collection_name = self.config.vector_store.config.collection_name
self.version = self.config.version
self.enable_graph = False
if self.version == "v1.1" and self.config.graph_store.config:
from mem0.memory.main_graph import MemoryGraph
self.graph = MemoryGraph(self.config)
self.enable_graph = True
capture_event("mem0.init", self)
@classmethod
@@ -164,6 +171,14 @@ class Memory(MemoryBase):
{"memory_id": function_result, "function_name": function_name},
)
capture_event("mem0.add", self)
if self.version == "v1.1" and self.enable_graph:
if user_id:
self.graph.user_id = user_id
else:
self.graph.user_id = "USER"
added_entities = self.graph.add(data)
return {"message": "ok"}
def get(self, memory_id):
@@ -234,16 +249,8 @@ class Memory(MemoryBase):
capture_event("mem0.get_all", self, {"filters": len(filters), "limit": limit})
memories = self.vector_store.list(filters=filters, limit=limit)
excluded_keys = {
"user_id",
"agent_id",
"run_id",
"hash",
"data",
"created_at",
"updated_at",
}
return [
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
all_memories = [
{
**MemoryItem(
id=mem.id,
@@ -271,6 +278,23 @@ class Memory(MemoryBase):
}
for mem in memories[0]
]
if self.version == "v1.1":
if self.enable_graph:
graph_entities = self.graph.get_all()
return {"memories": all_memories, "entities": graph_entities}
else:
return {"memories" : all_memories}
else:
warnings.warn(
"The current get_all API output format is deprecated. "
"To use the latest format, set `api_version='v1.1'`. "
"The current format will be removed in mem0ai 1.1.0 and later versions.",
category=DeprecationWarning,
stacklevel=2
)
return all_memories
def search(
self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None
@@ -302,7 +326,7 @@ class Memory(MemoryBase):
"One of the filters: user_id, agent_id or run_id is required!"
)
capture_event("mem0.search", self, {"filters": len(filters), "limit": limit})
capture_event("mem0.search", self, {"filters": len(filters), "limit": limit, "version": self.version})
embeddings = self.embedding_model.embed(query)
memories = self.vector_store.search(
query=embeddings, limit=limit, filters=filters
@@ -318,7 +342,7 @@ class Memory(MemoryBase):
"updated_at",
}
return [
original_memories = [
{
**MemoryItem(
id=mem.id,
@@ -348,6 +372,22 @@ class Memory(MemoryBase):
for mem in memories
]
if self.version == "v1.1":
if self.enable_graph:
graph_entities = self.graph.search(query)
return {"memories": original_memories, "entities": graph_entities}
else:
return {"memories" : original_memories}
else:
warnings.warn(
"The current get_all API output format is deprecated. "
"To use the latest format, set `api_version='v1.1'`. "
"The current format will be removed in mem0ai 1.1.0 and later versions.",
category=DeprecationWarning,
stacklevel=2
)
return original_memories
def update(self, memory_id, data):
"""
Update a memory by ID.
@@ -400,7 +440,11 @@ class Memory(MemoryBase):
memories = self.vector_store.list(filters=filters)[0]
for memory in memories:
self._delete_memory_tool(memory.id)
return {"message": "Memories deleted successfully!"}
if self.version == "v1.1" and self.enable_graph:
self.graph.delete_all()
return {'message': 'Memories deleted successfully!'}
def history(self, memory_id):
"""

284
mem0/memory/main_graph.py Normal file
View File

@@ -0,0 +1,284 @@
from langchain_community.graphs import Neo4jGraph
from pydantic import BaseModel, Field
import json
from openai import OpenAI
from mem0.embeddings.openai import OpenAIEmbedding
from mem0.llms.openai import OpenAILLM
from mem0.graphs.utils import get_update_memory_messages, EXTRACT_ENTITIES_PROMPT
from mem0.graphs.tools import UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL
# Module-level OpenAI client shared by entity extraction and embedding calls.
client = OpenAI()
# A single directed relationship triple between two named nodes.
# NOTE(review): GraphData is not referenced anywhere in this module — confirm
# it is consumed elsewhere or remove as dead code.
class GraphData(BaseModel):
    source: str = Field(..., description="The source node of the relationship")
    target: str = Field(..., description="The target node of the relationship")
    relationship: str = Field(..., description="The type of the relationship")
# One extracted (source)-[relation]->(destination) triple, including a free-form
# type label for each endpoint (used as the Neo4j node label in MemoryGraph.add).
class Entities(BaseModel):
    source_node: str
    source_type: str
    relation: str
    destination_node: str
    destination_type: str
# Structured-output schema for entity extraction: a list of relationship triples.
class ADDQuery(BaseModel):
    entities: list[Entities]
# Structured-output schema for search: candidate node names and relation types
# extracted from a user query.
class SEARCHQuery(BaseModel):
    nodes: list[str]
    relations: list[str]
def get_embedding(text, model="text-embedding-3-small"):
    """Return the embedding vector for *text* via the OpenAI embeddings API.

    Args:
        text (str): Text to embed.
        model (str): Embedding model name. Defaults to the model this module
            has always used, so existing callers are unaffected.

    Returns:
        list[float]: The embedding vector for the input text.
    """
    response = client.embeddings.create(model=model, input=text)
    return response.data[0].embedding
class MemoryGraph:
    """Graph memory backed by Neo4j.

    Uses an OpenAI model to extract entity/relationship triples from free
    text, stores them as typed nodes and relationships (with embeddings) in
    Neo4j, and answers searches via cosine similarity over node embeddings.
    """

    def __init__(self, config):
        """
        Args:
            config: Mem0 config object; graph_store.config must provide
                url/username/password for the Neo4j connection.
        """
        self.config = config
        self.graph = Neo4jGraph(
            self.config.graph_store.config.url,
            self.config.graph_store.config.username,
            self.config.graph_store.config.password,
        )
        self.llm = OpenAILLM()
        self.embedding_model = OpenAIEmbedding()
        # Set by the caller (e.g. Memory.add) before add/search are used.
        self.user_id = None
        # Minimum cosine similarity for a node to count as a match in _search.
        self.threshold = 0.7
        self.model_name = "gpt-4o-2024-08-06"

    def add(self, data):
        """
        Adds data to the graph.

        Args:
            data (str): The data to add to the graph.

        Returns:
            list: Query results for every relationship newly merged into the
            graph (empty if nothing was added).
        """
        # retrieve the search results
        search_output = self._search(data)

        extracted_entities = client.beta.chat.completions.parse(
            model=self.model_name,
            messages=[
                {"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id)},
                {"role": "user", "content": data},
            ],
            response_format=ADDQuery,
            temperature=0,
        ).choices[0].message.parsed.entities

        update_memory_prompt = get_update_memory_messages(search_output, extracted_entities)
        tools = [UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL]
        memory_updates = client.beta.chat.completions.parse(
            model=self.model_name,
            messages=update_memory_prompt,
            tools=tools,
            temperature=0,
        ).choices[0].message.tool_calls

        to_be_added = []
        # tool_calls is None when the model makes no tool call; guard the loop.
        for item in memory_updates or []:
            function_name = item.function.name
            arguments = json.loads(item.function.arguments)
            if function_name == "add_graph_memory":
                to_be_added.append(arguments)
            elif function_name == "update_graph_memory":
                self._update_relationship(arguments["source"], arguments["destination"], arguments["relationship"])
            elif function_name == "update_name":
                # NOTE(review): _update_name is not defined on this class, so
                # this branch raises AttributeError if ever taken — confirm the
                # intended handler.
                self._update_name(arguments["name"])
            elif function_name == "noop":
                continue

        new_relationships_response = []
        for item in to_be_added:
            # Normalize names/types so they are usable as Cypher identifiers.
            source = item["source"].lower().replace(" ", "_")
            source_type = item["source_type"].lower().replace(" ", "_")
            relation = item["relationship"].lower().replace(" ", "_")
            destination = item["destination"].lower().replace(" ", "_")
            destination_type = item["destination_type"].lower().replace(" ", "_")

            # Create embeddings
            source_embedding = get_embedding(source)
            dest_embedding = get_embedding(destination)

            # Updated Cypher query to include node types and embeddings.
            # Labels/relationship types cannot be parameterized in Cypher,
            # hence the f-string interpolation of the normalized names.
            cypher = f"""
            MERGE (n:{source_type} {{name: $source_name}})
            ON CREATE SET n.created = timestamp(), n.embedding = $source_embedding
            ON MATCH SET n.embedding = $source_embedding
            MERGE (m:{destination_type} {{name: $dest_name}})
            ON CREATE SET m.created = timestamp(), m.embedding = $dest_embedding
            ON MATCH SET m.embedding = $dest_embedding
            MERGE (n)-[rel:{relation}]->(m)
            ON CREATE SET rel.created = timestamp()
            RETURN n, rel, m
            """
            params = {
                "source_name": source,
                "dest_name": destination,
                "source_embedding": source_embedding,
                "dest_embedding": dest_embedding,
            }
            result = self.graph.query(cypher, params=params)
            # Fix: results were previously computed and discarded while
            # new_relationships_response stayed empty and add() returned None.
            new_relationships_response.extend(result)

        return new_relationships_response

    def _search(self, query):
        """Extract node candidates from *query* and fetch similar relations.

        Args:
            query (str): Free-text query or new memory text.

        Returns:
            list[dict]: Rows with source/relation/destination names and
            element ids, ordered by descending embedding similarity.
        """
        search_results = client.beta.chat.completions.parse(
            # Fix: was hard-coded "gpt-4o-2024-08-06"; keep in sync with add().
            model=self.model_name,
            messages=[
                {"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {self.user_id} as the source node. Extract the entities."},
                {"role": "user", "content": query},
            ],
            response_format=SEARCHQuery,
        ).choices[0].message

        # NOTE(review): parsed.relations is extracted but never used below.
        node_list = [node.lower().replace(" ", "_") for node in search_results.parsed.nodes]

        result_relations = []
        for node in node_list:
            n_embedding = get_embedding(node)
            # Cosine similarity against every embedded node; the UNION covers
            # both outgoing and incoming relationships of matching nodes.
            cypher_query = """
            MATCH (n)
            WHERE n.embedding IS NOT NULL
            WITH n,
                round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) /
                (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) *
                sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity
            WHERE similarity >= $threshold
            MATCH (n)-[r]->(m)
            RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relation, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id, similarity
            UNION
            MATCH (n)
            WHERE n.embedding IS NOT NULL
            WITH n,
                round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) /
                (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) *
                sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity
            WHERE similarity >= $threshold
            MATCH (m)-[r]->(n)
            RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relation, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity
            ORDER BY similarity DESC
            """
            params = {"n_embedding": n_embedding, "threshold": self.threshold}
            ans = self.graph.query(cypher_query, params=params)
            result_relations.extend(ans)

        return result_relations

    def search(self, query):
        """
        Search the graph for relations related to a query.

        Args:
            query (str): Query to search for.

        Returns:
            list[dict]: One dict per matching relationship with "source",
            "relation" and "destination" keys. (Docstring fixed: the previous
            version described a dict return that this method never produced.)
        """
        search_output = self._search(query)

        search_results = []
        for item in search_output:
            search_results.append({
                "source": item["source"],
                "relation": item["relation"],
                "destination": item["destination"],
            })
        return search_results

    def delete_all(self):
        """Delete every node and relationship in the graph database."""
        cypher = """
        MATCH (n)
        DETACH DELETE n
        """
        self.graph.query(cypher)

    def get_all(self):
        """
        Retrieves all relationships from the graph database.

        Returns:
            list[dict]: One dict per relationship with "source",
            "relationship" and "target" node names.
        """
        # return all nodes and relationships
        query = """
        MATCH (n)-[r]->(m)
        RETURN n.name AS source, type(r) AS relationship, m.name AS target
        """
        results = self.graph.query(query)

        final_results = []
        for result in results:
            final_results.append({
                "source": result["source"],
                "relationship": result["relationship"],
                "target": result["target"],
            })
        return final_results

    def _update_relationship(self, source, target, relationship):
        """
        Update or create a relationship between two nodes in the graph.

        Replaces any existing relationship between the two named nodes with a
        single new relationship of the given type, creating the nodes first
        if needed.

        Args:
            source (str): The name of the source node.
            target (str): The name of the target node.
            relationship (str): The type of the relationship.

        Raises:
            Exception: If the operation fails.
        """
        relationship = relationship.lower().replace(" ", "_")

        # Check if nodes exist and create them if they don't
        check_and_create_query = """
        MERGE (n1 {name: $source})
        MERGE (n2 {name: $target})
        """
        self.graph.query(check_and_create_query, params={"source": source, "target": target})

        # Delete any existing relationship between the nodes
        delete_query = """
        MATCH (n1 {name: $source})-[r]->(n2 {name: $target})
        DELETE r
        """
        self.graph.query(delete_query, params={"source": source, "target": target})

        # Create the new relationship. Relationship types cannot be
        # parameterized in Cypher, hence the f-string interpolation of the
        # normalized type name.
        create_query = f"""
        MATCH (n1 {{name: $source}}), (n2 {{name: $target}})
        CREATE (n1)-[r:{relationship}]->(n2)
        RETURN n1, r, n2
        """
        result = self.graph.query(create_query, params={"source": source, "target": target})

        if not result:
            raise Exception(f"Failed to update or create relationship between {source} and {target}")

View File

@@ -53,7 +53,7 @@ def capture_event(event_name, memory_instance, additional_data=None):
"vector_store": f"{memory_instance.vector_store.__class__.__module__}.{memory_instance.vector_store.__class__.__name__}",
"llm": f"{memory_instance.llm.__class__.__module__}.{memory_instance.llm.__class__.__name__}",
"embedding_model": f"{memory_instance.embedding_model.__class__.__module__}.{memory_instance.embedding_model.__class__.__name__}",
"function": f"{memory_instance.__class__.__module__}.{memory_instance.__class__.__name__}",
"function": f"{memory_instance.__class__.__module__}.{memory_instance.__class__.__name__}.{memory_instance.version}",
}
if additional_data:
event_data.update(additional_data)