Added user_id support for graph memory

This commit is contained in:
Prateek Chhikara
2024-09-03 09:47:35 -07:00
committed by GitHub
parent d03ba0fc8a
commit 65056311a6
2 changed files with 33 additions and 31 deletions

View File

@@ -177,7 +177,7 @@ class Memory(MemoryBase):
self.graph.user_id = user_id
else:
self.graph.user_id = "USER"
added_entities = self.graph.add(data)
added_entities = self.graph.add(data, filters)
return {"message": "ok"}
@@ -281,7 +281,7 @@ class Memory(MemoryBase):
if self.version == "v1.1":
if self.enable_graph:
graph_entities = self.graph.get_all()
graph_entities = self.graph.get_all(filters)
return {"memories": all_memories, "entities": graph_entities}
else:
return {"memories" : all_memories}
@@ -374,7 +374,7 @@ class Memory(MemoryBase):
if self.version == "v1.1":
if self.enable_graph:
graph_entities = self.graph.search(query)
graph_entities = self.graph.search(query, filters)
return {"memories": original_memories, "entities": graph_entities}
else:
return {"memories" : original_memories}
@@ -442,7 +442,7 @@ class Memory(MemoryBase):
self._delete_memory_tool(memory.id)
if self.version == "v1.1" and self.enable_graph:
self.graph.delete_all()
self.graph.delete_all(filters)
return {'message': 'Memories deleted successfully!'}

View File

@@ -25,7 +25,7 @@ class MemoryGraph:
self.user_id = None
self.threshold = 0.7
def add(self, data):
def add(self, data, filters):
"""
Adds data to the graph.
@@ -38,7 +38,7 @@ class MemoryGraph:
"""
# retrieve the search results
search_output = self._search(data)
search_output = self._search(data, filters)
if self.config.graph_store.custom_prompt:
messages=[
@@ -74,7 +74,7 @@ class MemoryGraph:
if item['name'] == "add_graph_memory":
to_be_added.append(item['arguments'])
elif item['name'] == "update_graph_memory":
self._update_relationship(item['arguments']['source'], item['arguments']['destination'], item['arguments']['relationship'])
self._update_relationship(item['arguments']['source'], item['arguments']['destination'], item['arguments']['relationship'], filters)
elif item['name'] == "noop":
continue
@@ -91,10 +91,10 @@ class MemoryGraph:
# Updated Cypher query to include node types and embeddings
cypher = f"""
MERGE (n:{source_type} {{name: $source_name}})
MERGE (n:{source_type} {{name: $source_name, user_id: $user_id}})
ON CREATE SET n.created = timestamp(), n.embedding = $source_embedding
ON MATCH SET n.embedding = $source_embedding
MERGE (m:{destination_type} {{name: $dest_name}})
MERGE (m:{destination_type} {{name: $dest_name, user_id: $user_id}})
ON CREATE SET m.created = timestamp(), m.embedding = $dest_embedding
ON MATCH SET m.embedding = $dest_embedding
MERGE (n)-[rel:{relation}]->(m)
@@ -106,16 +106,17 @@ class MemoryGraph:
"source_name": source,
"dest_name": destination,
"source_embedding": source_embedding,
"dest_embedding": dest_embedding
"dest_embedding": dest_embedding,
"user_id": filters["user_id"]
}
_ = self.graph.query(cypher, params=params)
def _search(self, query):
def _search(self, query, filters):
search_results = self.llm.generate_response(
messages=[
{"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {self.user_id} as the source node. Extract the entities."},
{"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities."},
{"role": "user", "content": query},
],
tools = [SEARCH_TOOL]
@@ -142,7 +143,7 @@ class MemoryGraph:
cypher_query = """
MATCH (n)
WHERE n.embedding IS NOT NULL
WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
WITH n,
round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) /
(sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) *
@@ -152,7 +153,7 @@ class MemoryGraph:
RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relation, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id, similarity
UNION
MATCH (n)
WHERE n.embedding IS NOT NULL
WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
WITH n,
round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) /
(sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) *
@@ -162,14 +163,14 @@ class MemoryGraph:
RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relation, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity
ORDER BY similarity DESC
"""
params = {"n_embedding": n_embedding, "threshold": self.threshold}
params = {"n_embedding": n_embedding, "threshold": self.threshold, "user_id": filters["user_id"]}
ans = self.graph.query(cypher_query, params=params)
result_relations.extend(ans)
return result_relations
def search(self, query):
def search(self, query, filters):
"""
Search for memories and related graph data.
@@ -182,7 +183,7 @@ class MemoryGraph:
- "entities": List of related graph data based on the query.
"""
search_output = self._search(query)
search_output = self._search(query, filters)
if not search_output:
return []
@@ -204,15 +205,16 @@ class MemoryGraph:
return search_results
def delete_all(self):
def delete_all(self, filters):
cypher = """
MATCH (n)
MATCH (n {user_id: $user_id})
DETACH DELETE n
"""
self.graph.query(cypher)
params = {"user_id": filters["user_id"]}
self.graph.query(cypher, params=params)
def get_all(self):
def get_all(self, filters):
"""
Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
@@ -226,10 +228,10 @@ class MemoryGraph:
# return all nodes and relationships
query = """
MATCH (n)-[r]->(m)
MATCH (n {user_id: $user_id})-[r]->(m {user_id: $user_id})
RETURN n.name AS source, type(r) AS relationship, m.name AS target
"""
results = self.graph.query(query)
results = self.graph.query(query, params={"user_id": filters["user_id"]})
final_results = []
for result in results:
@@ -242,7 +244,7 @@ class MemoryGraph:
return final_results
def _update_relationship(self, source, target, relationship):
def _update_relationship(self, source, target, relationship, filters):
"""
Update or create a relationship between two nodes in the graph.
@@ -258,25 +260,25 @@ class MemoryGraph:
# Check if nodes exist and create them if they don't
check_and_create_query = """
MERGE (n1 {name: $source})
MERGE (n2 {name: $target})
MERGE (n1 {name: $source, user_id: $user_id})
MERGE (n2 {name: $target, user_id: $user_id})
"""
self.graph.query(check_and_create_query, params={"source": source, "target": target})
self.graph.query(check_and_create_query, params={"source": source, "target": target, "user_id": filters["user_id"]})
# Delete any existing relationship between the nodes
delete_query = """
MATCH (n1 {name: $source})-[r]->(n2 {name: $target})
MATCH (n1 {name: $source, user_id: $user_id})-[r]->(n2 {name: $target, user_id: $user_id})
DELETE r
"""
self.graph.query(delete_query, params={"source": source, "target": target})
self.graph.query(delete_query, params={"source": source, "target": target, "user_id": filters["user_id"]})
# Create the new relationship
create_query = f"""
MATCH (n1 {{name: $source}}), (n2 {{name: $target}})
MATCH (n1 {{name: $source, user_id: $user_id}}), (n2 {{name: $target, user_id: $user_id}})
CREATE (n1)-[r:{relationship}]->(n2)
RETURN n1, r, n2
"""
result = self.graph.query(create_query, params={"source": source, "target": target})
result = self.graph.query(create_query, params={"source": source, "target": target, "user_id": filters["user_id"]})
if not result:
raise Exception(f"Failed to update or create relationship between {source} and {target}")