Files
t6_mem0/mem0/memory/main_graph.py
Prateek Chhikara c64e0824da [Mem0] Integrate Graph Memory (#1718)
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
2024-08-20 16:37:38 -07:00

285 lines
11 KiB
Python

from langchain_community.graphs import Neo4jGraph
from pydantic import BaseModel, Field
import json
from openai import OpenAI
from mem0.embeddings.openai import OpenAIEmbedding
from mem0.llms.openai import OpenAILLM
from mem0.graphs.utils import get_update_memory_messages, EXTRACT_ENTITIES_PROMPT
from mem0.graphs.tools import UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL
# Module-level OpenAI client, constructed once when this module is imported.
# get_embedding() and MemoryGraph use it for embedding and chat-completion calls.
client = OpenAI()
class GraphData(BaseModel):
    # One directed edge: a source node, a target node and the type of the
    # relationship that joins them. All three fields are required.
    source: str = Field(
        ...,
        description="The source node of the relationship",
    )
    target: str = Field(
        ...,
        description="The target node of the relationship",
    )
    relationship: str = Field(
        ...,
        description="The type of the relationship",
    )
class Entities(BaseModel):
    # A single extracted triple: a typed source node, the relation, and a
    # typed destination node. Documented with comments rather than a docstring
    # because pydantic copies a class docstring into the generated JSON schema.

    # Source endpoint of the relation and its entity type.
    source_node: str
    source_type: str

    # Name of the relation linking source to destination.
    relation: str

    # Destination endpoint of the relation and its entity type.
    destination_node: str
    destination_type: str
class ADDQuery(BaseModel):
    # Structured-output schema passed as `response_format` when extracting
    # entities in MemoryGraph.add(). Kept as a comment (not a docstring) so the
    # JSON schema generated from this model is unchanged.
    entities: list[Entities]
class SEARCHQuery(BaseModel):
    # Structured-output schema passed as `response_format` in
    # MemoryGraph._search(). Kept as a comment (not a docstring) so the JSON
    # schema generated from this model is unchanged.

    # Entity names mentioned in the query text.
    nodes: list[str]

    # Relation names mentioned in the query text.
    relations: list[str]
def get_embedding(text, model="text-embedding-3-small"):
    """Return the embedding vector for ``text`` from the OpenAI embeddings API.

    Args:
        text (str): Text to embed.
        model (str): Name of the embedding model. Defaults to the model this
            module has always used. Vectors produced by different models are
            not comparable, so embeddings that will be compared with each
            other must come from the same model.

    Returns:
        list[float]: Embedding of the single input, taken from the first
        entry of the API response.
    """
    response = client.embeddings.create(model=model, input=text)
    return response.data[0].embedding
class MemoryGraph:
    """Graph memory backed by Neo4j, populated and queried with OpenAI models.

    Text is turned into (source)-[relation]->(destination) triples by an LLM,
    stored as Neo4j nodes and relationships, and retrieved again by cosine
    similarity between stored node-name embeddings and query embeddings.
    """

    def __init__(self, config):
        """Open the Neo4j connection and set default model settings.

        Args:
            config: Configuration object whose ``graph_store.config`` exposes
                ``url``, ``username`` and ``password``.
        """
        self.config = config
        store = self.config.graph_store.config
        self.graph = Neo4jGraph(store.url, store.username, store.password)
        self.llm = OpenAILLM()
        self.embedding_model = OpenAIEmbedding()
        # NOTE(review): presumably assigned by the caller before use. add()
        # passes it to str.replace(), which raises TypeError for None.
        self.user_id = None
        # Minimum cosine similarity for a stored node to match in _search().
        self.threshold = 0.7
        self.model_name = "gpt-4o-2024-08-06"

    @staticmethod
    def _normalize(text):
        """Lower-case ``text`` and replace spaces with underscores."""
        return text.lower().replace(" ", "_")

    @staticmethod
    def _quote_identifier(name):
        """Return ``name`` as a backtick-quoted Cypher identifier.

        Labels and relationship types are interpolated into the query text
        rather than passed as parameters. Quoting keeps LLM-produced values
        that contain characters such as ``-`` or ``'`` from breaking or
        altering the query. Embedded backticks are escaped by doubling them.
        """
        return "`" + str(name).replace("`", "``") + "`"

    def add(self, data):
        """Extract entities from ``data`` and merge them into the graph.

        The method first retrieves graph data related to ``data``, asks the
        LLM for the entities in ``data``, then asks it (via tool calls) which
        relationships to add or update, and applies those decisions.

        Args:
            data (str): The text to add to the graph.

        Returns:
            None
        """
        # Existing graph data related to the input, used as update context.
        search_output = self._search(data)

        extraction = client.beta.chat.completions.parse(
            model=self.model_name,
            messages=[
                {"role": "system", "content": EXTRACT_ENTITIES_PROMPT.replace("USER_ID", self.user_id)},
                {"role": "user", "content": data},
            ],
            response_format=ADDQuery,
            temperature=0,
        )
        extracted_entities = extraction.choices[0].message.parsed.entities

        update_memory_prompt = get_update_memory_messages(search_output, extracted_entities)
        tools = [UPDATE_MEMORY_TOOL_GRAPH, ADD_MEMORY_TOOL_GRAPH, NOOP_TOOL]
        decision = client.beta.chat.completions.parse(
            model=self.model_name,
            messages=update_memory_prompt,
            tools=tools,
            temperature=0,
        )
        # tool_calls is None when the model answers without calling a tool.
        memory_updates = decision.choices[0].message.tool_calls or []

        to_be_added = []
        for call in memory_updates:
            function_name = call.function.name
            arguments = json.loads(call.function.arguments)
            if function_name == "add_graph_memory":
                to_be_added.append(arguments)
            elif function_name == "update_graph_memory":
                self._update_relationship(
                    arguments["source"],
                    arguments["destination"],
                    arguments["relationship"],
                )
            elif function_name == "update_name":
                # NOTE(review): no `_update_name` method is visible in this
                # class and no matching tool is offered above - confirm
                # before relying on this branch.
                self._update_name(arguments["name"])
            # "noop" and any other tool name require no action.

        for item in to_be_added:
            source = self._normalize(item["source"])
            destination = self._normalize(item["destination"])
            source_label = self._quote_identifier(self._normalize(item["source_type"]))
            destination_label = self._quote_identifier(self._normalize(item["destination_type"]))
            relation_type = self._quote_identifier(self._normalize(item["relationship"]))

            # Embeddings of the normalized names are what _search() compares.
            source_embedding = get_embedding(source)
            dest_embedding = get_embedding(destination)

            # MERGE keeps nodes and relationships unique; embeddings are
            # refreshed whether the node was created or already existed.
            cypher = f"""
            MERGE (n:{source_label} {{name: $source_name}})
            ON CREATE SET n.created = timestamp(), n.embedding = $source_embedding
            ON MATCH SET n.embedding = $source_embedding
            MERGE (m:{destination_label} {{name: $dest_name}})
            ON CREATE SET m.created = timestamp(), m.embedding = $dest_embedding
            ON MATCH SET m.embedding = $dest_embedding
            MERGE (n)-[rel:{relation_type}]->(m)
            ON CREATE SET rel.created = timestamp()
            RETURN n, rel, m
            """
            params = {
                "source_name": source,
                "dest_name": destination,
                "source_embedding": source_embedding,
                "dest_embedding": dest_embedding,
            }
            self.graph.query(cypher, params=params)

    def _search(self, query):
        """Find stored relationships whose endpoints resemble entities in ``query``.

        The LLM lists the entities mentioned in ``query``. Each one is
        embedded and compared, by cosine similarity computed inside Cypher,
        with every node that has an embedding. Relationships touching a node
        at or above ``self.threshold`` are returned.

        Args:
            query (str): Free text to search with.

        Returns:
            list: Rows with ``source``, ``source_id``, ``relation``,
            ``relation_id``, ``destination``, ``destination_id`` and
            ``similarity``, concatenated over all extracted entities.
        """
        search_results = client.beta.chat.completions.parse(
            model=self.model_name,
            messages=[
                {"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {self.user_id} as the source node. Extract the entities."},
                {"role": "user", "content": query},
            ],
            response_format=SEARCHQuery,
        ).choices[0].message
        node_list = [self._normalize(node) for node in search_results.parsed.nodes]

        # First half: matching node is the source (outgoing relationships).
        # Second half: matching node is the destination (incoming ones).
        cypher_query = """
        MATCH (n)
        WHERE n.embedding IS NOT NULL
        WITH n,
        round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) /
        (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) *
        sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity
        WHERE similarity >= $threshold
        MATCH (n)-[r]->(m)
        RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relation, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id, similarity
        UNION
        MATCH (n)
        WHERE n.embedding IS NOT NULL
        WITH n,
        round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) /
        (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) *
        sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity
        WHERE similarity >= $threshold
        MATCH (m)-[r]->(n)
        RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relation, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity
        ORDER BY similarity DESC
        """

        result_relations = []
        for node in node_list:
            params = {"n_embedding": get_embedding(node), "threshold": self.threshold}
            result_relations.extend(self.graph.query(cypher_query, params=params))
        return result_relations

    def search(self, query):
        """Search the graph for relationships related to ``query``.

        Args:
            query (str): Query to search for.

        Returns:
            list[dict]: One dictionary per matching relationship, with the
            keys ``source``, ``relation`` and ``destination``.
        """
        return [
            {
                "source": row["source"],
                "relation": row["relation"],
                "destination": row["destination"],
            }
            for row in self._search(query)
        ]

    def delete_all(self):
        """Delete every node in the database together with its relationships."""
        cypher = """
        MATCH (n)
        DETACH DELETE n
        """
        self.graph.query(cypher)

    def get_all(self):
        """Return every relationship stored in the graph.

        Returns:
            list[dict]: One dictionary per relationship, with the keys
            ``source``, ``relationship`` and ``target``.
        """
        query = """
        MATCH (n)-[r]->(m)
        RETURN n.name AS source, type(r) AS relationship, m.name AS target
        """
        return [
            {
                "source": row["source"],
                "relationship": row["relationship"],
                "target": row["target"],
            }
            for row in self.graph.query(query)
        ]

    def _update_relationship(self, source, target, relationship):
        """Replace the relationship from ``source`` to ``target``.

        Both nodes are created if missing, every existing relationship from
        source to target is deleted, and a single new one is created.

        Args:
            source (str): The name of the source node.
            target (str): The name of the target node.
            relationship (str): The type of the relationship; it is
                lower-cased and spaces become underscores.

        Raises:
            Exception: If the final create query returns no rows.
        """
        rel_type = self._quote_identifier(self._normalize(relationship))
        endpoints = {"source": source, "target": target}

        # Make sure both endpoints exist.
        check_and_create_query = """
        MERGE (n1 {name: $source})
        MERGE (n2 {name: $target})
        """
        self.graph.query(check_and_create_query, params=dict(endpoints))

        # Drop whatever currently links source -> target.
        delete_query = """
        MATCH (n1 {name: $source})-[r]->(n2 {name: $target})
        DELETE r
        """
        self.graph.query(delete_query, params=dict(endpoints))

        # Create the replacement relationship.
        create_query = f"""
        MATCH (n1 {{name: $source}}), (n2 {{name: $target}})
        CREATE (n1)-[r:{rel_type}]->(n2)
        RETURN n1, r, n2
        """
        result = self.graph.query(create_query, params=dict(endpoints))
        if not result:
            raise Exception(f"Failed to update or create relationship between {source} and {target}")