[Mem0] Integrate Graph Memory (#1718)
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
@@ -4,9 +4,8 @@ import uuid
import pytz
from datetime import datetime
from typing import Any, Dict

import warnings
from pydantic import ValidationError

from mem0.llms.utils.tools import (
    ADD_MEMORY_TOOL,
    DELETE_MEMORY_TOOL,
@@ -37,7 +36,15 @@ class Memory(MemoryBase):
        self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config)
        self.db = SQLiteManager(self.config.history_db_path)
        self.collection_name = self.config.vector_store.config.collection_name
        self.version = self.config.version

        self.enable_graph = False

        if self.version == "v1.1" and self.config.graph_store.config:
            from mem0.memory.main_graph import MemoryGraph
            self.graph = MemoryGraph(self.config)
            self.enable_graph = True

        capture_event("mem0.init", self)

    @classmethod
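A note for reviewers: graph memory only activates when both the API version is "v1.1" and a graph_store is configured. A minimal sketch of enabling it, assuming a Neo4j backend and the Memory.from_config entry point (the "provider" key and connection values here are illustrative, not taken from this diff):

    from mem0 import Memory

    config = {
        "version": "v1.1",  # graph memory is gated on the v1.1 API
        "graph_store": {
            "provider": "neo4j",  # assumed provider name
            "config": {
                "url": "neo4j://localhost:7687",  # placeholder connection values
                "username": "neo4j",
                "password": "password",
            },
        },
    }

    m = Memory.from_config(config)
    assert m.enable_graph  # MemoryGraph was constructed in __init__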
@@ -164,6 +171,14 @@ class Memory(MemoryBase):
                {"memory_id": function_result, "function_name": function_name},
            )
        capture_event("mem0.add", self)

        if self.version == "v1.1" and self.enable_graph:
            if user_id:
                self.graph.user_id = user_id
            else:
                self.graph.user_id = "USER"
            added_entities = self.graph.add(data)

        return {"message": "ok"}

    def get(self, memory_id):
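With the graph enabled, add() now mirrors each memory into the graph, falling back to a generic "USER" node when no user_id is supplied. A hedged usage sketch (the entity names are hypothetical, not produced output):

    m.add("Alice works at Acme and lives in Berlin.", user_id="alice")
    # vector memory is written exactly as before; in addition MemoryGraph.add()
    # extracts triples such as ("alice")-[:works_at]->("acme") and merges them
    # into Neo4j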
@@ -234,16 +249,8 @@ class Memory(MemoryBase):
        capture_event("mem0.get_all", self, {"filters": len(filters), "limit": limit})
        memories = self.vector_store.list(filters=filters, limit=limit)

        excluded_keys = {
            "user_id",
            "agent_id",
            "run_id",
            "hash",
            "data",
            "created_at",
            "updated_at",
        }
        return [
        excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
        all_memories = [
            {
                **MemoryItem(
                    id=mem.id,
@@ -271,6 +278,23 @@ class Memory(MemoryBase):
            }
            for mem in memories[0]
        ]

        if self.version == "v1.1":
            if self.enable_graph:
                graph_entities = self.graph.get_all()
                return {"memories": all_memories, "entities": graph_entities}
            else:
                return {"memories": all_memories}
        else:
            warnings.warn(
                "The current get_all API output format is deprecated. "
                "To use the latest format, set `api_version='v1.1'`. "
                "The current format will be removed in mem0ai 1.1.0 and later versions.",
                category=DeprecationWarning,
                stacklevel=2,
            )
            return all_memories


    def search(
        self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None
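The v1.1 branch changes the get_all contract from a bare list to a dict keyed by "memories" (plus "entities" when the graph is enabled). Illustrative shapes, with placeholder values:

    # v1.0 (deprecated): a bare list of memory dicts
    [{"id": "...", "memory": "...", "user_id": "alice"}]

    # v1.1 with graph enabled
    {
        "memories": [...],  # the same items as above
        "entities": [{"source": "alice", "relationship": "works_at", "target": "acme"}],
    }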
@@ -302,7 +326,7 @@ class Memory(MemoryBase):
                "One of the filters: user_id, agent_id or run_id is required!"
            )

        capture_event("mem0.search", self, {"filters": len(filters), "limit": limit})
        capture_event("mem0.search", self, {"filters": len(filters), "limit": limit, "version": self.version})
        embeddings = self.embedding_model.embed(query)
        memories = self.vector_store.search(
            query=embeddings, limit=limit, filters=filters
@@ -318,7 +342,7 @@ class Memory(MemoryBase):
            "updated_at",
        }

        return [
        original_memories = [
            {
                **MemoryItem(
                    id=mem.id,
@@ -348,6 +372,22 @@ class Memory(MemoryBase):
            for mem in memories
        ]

        if self.version == "v1.1":
            if self.enable_graph:
                graph_entities = self.graph.search(query)
                return {"memories": original_memories, "entities": graph_entities}
            else:
                return {"memories": original_memories}
        else:
            warnings.warn(
                "The current search API output format is deprecated. "
                "To use the latest format, set `api_version='v1.1'`. "
                "The current format will be removed in mem0ai 1.1.0 and later versions.",
                category=DeprecationWarning,
                stacklevel=2,
            )
            return original_memories

    def update(self, memory_id, data):
        """
        Update a memory by ID.
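search() follows the same v1.1 split; graph hits come back as relation triples alongside the vector hits. A sketch (values illustrative):

    res = m.search("Where does Alice work?", user_id="alice")
    res["memories"]  # ranked vector-store results, as before
    res["entities"]  # e.g. [{"source": "alice", "relation": "works_at", "destination": "acme"}]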
@@ -400,7 +440,11 @@ class Memory(MemoryBase):
        memories = self.vector_store.list(filters=filters)[0]
        for memory in memories:
            self._delete_memory_tool(memory.id)
        return {"message": "Memories deleted successfully!"}

        if self.version == "v1.1" and self.enable_graph:
            self.graph.delete_all()

        return {'message': 'Memories deleted successfully!'}

    def history(self, memory_id):
        """
284  mem0/memory/main_graph.py  Normal file
@@ -0,0 +1,284 @@
from langchain_community.graphs import Neo4jGraph
from pydantic import BaseModel, Field
import json
from openai import OpenAI

from mem0.embeddings.openai import OpenAIEmbedding
from mem0.llms.openai import OpenAILLM
from mem0.graphs.utils import get_update_memory_messages, EXTRACT_ENTITIES_PROMPT
from mem0.graphs.tools import UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL

client = OpenAI()

class GraphData(BaseModel):
    source: str = Field(..., description="The source node of the relationship")
    target: str = Field(..., description="The target node of the relationship")
    relationship: str = Field(..., description="The type of the relationship")

class Entities(BaseModel):
    source_node: str
    source_type: str
    relation: str
    destination_node: str
    destination_type: str

class ADDQuery(BaseModel):
    entities: list[Entities]

class SEARCHQuery(BaseModel):
    nodes: list[str]
    relations: list[str]

def get_embedding(text):
    response = client.embeddings.create(
        model="text-embedding-3-small",
        input=text
    )
    return response.data[0].embedding
class MemoryGraph:
    def __init__(self, config):
        self.config = config
        self.graph = Neo4jGraph(self.config.graph_store.config.url, self.config.graph_store.config.username, self.config.graph_store.config.password)

        self.llm = OpenAILLM()
        self.embedding_model = OpenAIEmbedding()
        self.user_id = None
        self.threshold = 0.7
        self.model_name = "gpt-4o-2024-08-06"
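MemoryGraph reaches two external services at construction time, so both must be reachable before Memory is instantiated. A minimal environment sketch (values are placeholders):

    import os

    # the module-level OpenAI() client and OpenAIEmbedding both read this key
    os.environ["OPENAI_API_KEY"] = "sk-..."

    # Neo4jGraph is constructed from config.graph_store.config:
    # url="neo4j://localhost:7687", username="neo4j", password="..."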
    def add(self, data):
        """
        Adds data to the graph.

        Args:
            data (str): The data to add to the graph.

        Returns:
            None. Extracted entities are merged into the graph as a side effect.
        """

        # retrieve the search results
        search_output = self._search(data)

        extracted_entities = client.beta.chat.completions.parse(
            model=self.model_name,
            messages=[
                {"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id)},
                {"role": "user", "content": data},
            ],
            response_format=ADDQuery,
            temperature=0,
        ).choices[0].message.parsed.entities

        update_memory_prompt = get_update_memory_messages(search_output, extracted_entities)
        tools = [UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL]

        memory_updates = client.beta.chat.completions.parse(
            model=self.model_name,
            messages=update_memory_prompt,
            tools=tools,
            temperature=0,
        ).choices[0].message.tool_calls

        to_be_added = []
        for item in memory_updates:
            function_name = item.function.name
            arguments = json.loads(item.function.arguments)
            if function_name == "add_graph_memory":
                to_be_added.append(arguments)
            elif function_name == "update_graph_memory":
                self._update_relationship(arguments['source'], arguments['destination'], arguments['relationship'])
            elif function_name == "update_name":
                self._update_name(arguments['name'])
            elif function_name == "noop":
                continue

        new_relationships_response = []
        for item in to_be_added:
            source = item['source'].lower().replace(" ", "_")
            source_type = item['source_type'].lower().replace(" ", "_")
            relation = item['relationship'].lower().replace(" ", "_")
            destination = item['destination'].lower().replace(" ", "_")
            destination_type = item['destination_type'].lower().replace(" ", "_")

            # Create embeddings
            source_embedding = get_embedding(source)
            dest_embedding = get_embedding(destination)

            # Cypher query that merges both typed nodes (with embeddings) and the relationship
            cypher = f"""
            MERGE (n:{source_type} {{name: $source_name}})
            ON CREATE SET n.created = timestamp(), n.embedding = $source_embedding
            ON MATCH SET n.embedding = $source_embedding
            MERGE (m:{destination_type} {{name: $dest_name}})
            ON CREATE SET m.created = timestamp(), m.embedding = $dest_embedding
            ON MATCH SET m.embedding = $dest_embedding
            MERGE (n)-[rel:{relation}]->(m)
            ON CREATE SET rel.created = timestamp()
            RETURN n, rel, m
            """

            params = {
                "source_name": source,
                "dest_name": destination,
                "source_embedding": source_embedding,
                "dest_embedding": dest_embedding
            }

            result = self.graph.query(cypher, params=params)
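For review, the shape of each to_be_added item is fixed by the add_graph_memory tool schema; a hypothetical example of the normalization step above:

    # hypothetical arguments emitted by the LLM for "Alice works at Acme"
    arguments = {
        "source": "Alice",
        "source_type": "Person",
        "relationship": "works at",
        "destination": "Acme",
        "destination_type": "Organization",
    }
    # after lowercasing and underscoring, the MERGE above produces
    # (alice:person)-[:works_at]->(acme:organization), with both nodes
    # carrying text-embedding-3-small vectors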
    def _search(self, query):
        search_results = client.beta.chat.completions.parse(
            model="gpt-4o-2024-08-06",
            messages=[
                {"role": "system", "content": f"You are a smart assistant who understands entities, their types, and the relations between them in a given text. If the user message contains self-references such as 'I', 'me', or 'my', use {self.user_id} as the source node. Extract the entities."},
                {"role": "user", "content": query},
            ],
            response_format=SEARCHQuery,
        ).choices[0].message

        node_list = search_results.parsed.nodes
        relation_list = search_results.parsed.relations

        node_list = [node.lower().replace(" ", "_") for node in node_list]
        relation_list = [relation.lower().replace(" ", "_") for relation in relation_list]

        result_relations = []

        for node in node_list:
            n_embedding = get_embedding(node)

            # Score every embedded node by cosine similarity against the query node,
            # then return relationships in both directions, ordered by similarity.
            cypher_query = """
            MATCH (n)
            WHERE n.embedding IS NOT NULL
            WITH n,
                round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) /
                    (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) *
                    sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity
            WHERE similarity >= $threshold
            MATCH (n)-[r]->(m)
            RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relation, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id, similarity
            UNION
            MATCH (n)
            WHERE n.embedding IS NOT NULL
            WITH n,
                round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) /
                    (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) *
                    sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity
            WHERE similarity >= $threshold
            MATCH (m)-[r]->(n)
            RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relation, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity
            ORDER BY similarity DESC
            """
            params = {"n_embedding": n_embedding, "threshold": self.threshold}
            ans = self.graph.query(cypher_query, params=params)
            result_relations.extend(ans)

        return result_relations
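The nested reduce() calls above are just cosine similarity between the query-node embedding q and each stored embedding n, computed inline in Cypher and rounded to four decimals:

    \mathrm{similarity}(n, q) = \frac{\sum_i n_i q_i}{\sqrt{\sum_i n_i^2}\;\sqrt{\sum_i q_i^2}}

Nodes are kept when similarity >= self.threshold (0.7 by default), and the UNION matches edges in both directions.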
    def search(self, query):
        """
        Search the graph for relationships relevant to the query.

        Args:
            query (str): Query to search for.

        Returns:
            list: A list of dictionaries, each containing the "source",
            "relation", and "destination" of a matching relationship.
        """

        search_output = self._search(query)
        search_results = []
        for item in search_output:
            search_results.append({
                "source": item['source'],
                "relation": item['relation'],
                "destination": item['destination']
            })

        return search_results
    def delete_all(self):
        # Wipe the entire graph: delete every node along with its relationships
        cypher = """
        MATCH (n)
        DETACH DELETE n
        """
        self.graph.query(cypher)
    def get_all(self):
        """
        Retrieves all nodes and relationships from the graph database.

        Returns:
            list: A list of dictionaries, each containing the "source",
            "relationship", and "target" of one relationship in the graph.
        """

        # return all nodes and relationships
        query = """
        MATCH (n)-[r]->(m)
        RETURN n.name AS source, type(r) AS relationship, m.name AS target
        """
        results = self.graph.query(query)

        final_results = []
        for result in results:
            final_results.append({
                "source": result['source'],
                "relationship": result['relationship'],
                "target": result['target']
            })

        return final_results
    def _update_relationship(self, source, target, relationship):
        """
        Update or create a relationship between two nodes in the graph.

        Args:
            source (str): The name of the source node.
            target (str): The name of the target node.
            relationship (str): The type of the relationship.

        Raises:
            Exception: If the operation fails.
        """
        relationship = relationship.lower().replace(" ", "_")

        # Check if nodes exist and create them if they don't
        check_and_create_query = """
        MERGE (n1 {name: $source})
        MERGE (n2 {name: $target})
        """
        self.graph.query(check_and_create_query, params={"source": source, "target": target})

        # Delete any existing relationship between the nodes
        delete_query = """
        MATCH (n1 {name: $source})-[r]->(n2 {name: $target})
        DELETE r
        """
        self.graph.query(delete_query, params={"source": source, "target": target})

        # Create the new relationship
        create_query = f"""
        MATCH (n1 {{name: $source}}), (n2 {{name: $target}})
        CREATE (n1)-[r:{relationship}]->(n2)
        RETURN n1, r, n2
        """
        result = self.graph.query(create_query, params={"source": source, "target": target})

        if not result:
            raise Exception(f"Failed to update or create relationship between {source} and {target}")
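_update_relationship implements "update" as delete-then-create, since Cypher cannot retype an existing relationship in place. A hypothetical trigger, matching the update_graph_memory tool call handled in add():

    # mg is a MemoryGraph instance (name assumed for illustration)
    mg._update_relationship("alice", "acme", "consults for")
    # any existing alice->acme edge is deleted, then
    # (alice)-[:consults_for]->(acme) is created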
@@ -53,7 +53,7 @@ def capture_event(event_name, memory_instance, additional_data=None):
        "vector_store": f"{memory_instance.vector_store.__class__.__module__}.{memory_instance.vector_store.__class__.__name__}",
        "llm": f"{memory_instance.llm.__class__.__module__}.{memory_instance.llm.__class__.__name__}",
        "embedding_model": f"{memory_instance.embedding_model.__class__.__module__}.{memory_instance.embedding_model.__class__.__name__}",
        "function": f"{memory_instance.__class__.__module__}.{memory_instance.__class__.__name__}",
        "function": f"{memory_instance.__class__.__module__}.{memory_instance.__class__.__name__}.{memory_instance.version}",
    }
    if additional_data:
        event_data.update(additional_data)