Added user_id support for graph memory

This commit is contained in:
Prateek Chhikara
2024-09-03 09:47:35 -07:00
committed by GitHub
parent d03ba0fc8a
commit 65056311a6
2 changed files with 33 additions and 31 deletions

View File

@@ -177,7 +177,7 @@ class Memory(MemoryBase):
self.graph.user_id = user_id self.graph.user_id = user_id
else: else:
self.graph.user_id = "USER" self.graph.user_id = "USER"
added_entities = self.graph.add(data) added_entities = self.graph.add(data, filters)
return {"message": "ok"} return {"message": "ok"}
@@ -281,7 +281,7 @@ class Memory(MemoryBase):
if self.version == "v1.1": if self.version == "v1.1":
if self.enable_graph: if self.enable_graph:
graph_entities = self.graph.get_all() graph_entities = self.graph.get_all(filters)
return {"memories": all_memories, "entities": graph_entities} return {"memories": all_memories, "entities": graph_entities}
else: else:
return {"memories" : all_memories} return {"memories" : all_memories}
@@ -374,7 +374,7 @@ class Memory(MemoryBase):
if self.version == "v1.1": if self.version == "v1.1":
if self.enable_graph: if self.enable_graph:
graph_entities = self.graph.search(query) graph_entities = self.graph.search(query, filters)
return {"memories": original_memories, "entities": graph_entities} return {"memories": original_memories, "entities": graph_entities}
else: else:
return {"memories" : original_memories} return {"memories" : original_memories}
@@ -442,7 +442,7 @@ class Memory(MemoryBase):
self._delete_memory_tool(memory.id) self._delete_memory_tool(memory.id)
if self.version == "v1.1" and self.enable_graph: if self.version == "v1.1" and self.enable_graph:
self.graph.delete_all() self.graph.delete_all(filters)
return {'message': 'Memories deleted successfully!'} return {'message': 'Memories deleted successfully!'}

View File

@@ -25,7 +25,7 @@ class MemoryGraph:
self.user_id = None self.user_id = None
self.threshold = 0.7 self.threshold = 0.7
def add(self, data): def add(self, data, filters):
""" """
Adds data to the graph. Adds data to the graph.
@@ -38,7 +38,7 @@ class MemoryGraph:
""" """
# retrieve the search results # retrieve the search results
search_output = self._search(data) search_output = self._search(data, filters)
if self.config.graph_store.custom_prompt: if self.config.graph_store.custom_prompt:
messages=[ messages=[
@@ -74,7 +74,7 @@ class MemoryGraph:
if item['name'] == "add_graph_memory": if item['name'] == "add_graph_memory":
to_be_added.append(item['arguments']) to_be_added.append(item['arguments'])
elif item['name'] == "update_graph_memory": elif item['name'] == "update_graph_memory":
self._update_relationship(item['arguments']['source'], item['arguments']['destination'], item['arguments']['relationship']) self._update_relationship(item['arguments']['source'], item['arguments']['destination'], item['arguments']['relationship'], filters)
elif item['name'] == "noop": elif item['name'] == "noop":
continue continue
@@ -91,10 +91,10 @@ class MemoryGraph:
# Updated Cypher query to include node types and embeddings # Updated Cypher query to include node types and embeddings
cypher = f""" cypher = f"""
MERGE (n:{source_type} {{name: $source_name}}) MERGE (n:{source_type} {{name: $source_name, user_id: $user_id}})
ON CREATE SET n.created = timestamp(), n.embedding = $source_embedding ON CREATE SET n.created = timestamp(), n.embedding = $source_embedding
ON MATCH SET n.embedding = $source_embedding ON MATCH SET n.embedding = $source_embedding
MERGE (m:{destination_type} {{name: $dest_name}}) MERGE (m:{destination_type} {{name: $dest_name, user_id: $user_id}})
ON CREATE SET m.created = timestamp(), m.embedding = $dest_embedding ON CREATE SET m.created = timestamp(), m.embedding = $dest_embedding
ON MATCH SET m.embedding = $dest_embedding ON MATCH SET m.embedding = $dest_embedding
MERGE (n)-[rel:{relation}]->(m) MERGE (n)-[rel:{relation}]->(m)
@@ -106,16 +106,17 @@ class MemoryGraph:
"source_name": source, "source_name": source,
"dest_name": destination, "dest_name": destination,
"source_embedding": source_embedding, "source_embedding": source_embedding,
"dest_embedding": dest_embedding "dest_embedding": dest_embedding,
"user_id": filters["user_id"]
} }
_ = self.graph.query(cypher, params=params) _ = self.graph.query(cypher, params=params)
def _search(self, query): def _search(self, query, filters):
search_results = self.llm.generate_response( search_results = self.llm.generate_response(
messages=[ messages=[
{"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {self.user_id} as the source node. Extract the entities."}, {"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities."},
{"role": "user", "content": query}, {"role": "user", "content": query},
], ],
tools = [SEARCH_TOOL] tools = [SEARCH_TOOL]
@@ -142,7 +143,7 @@ class MemoryGraph:
cypher_query = """ cypher_query = """
MATCH (n) MATCH (n)
WHERE n.embedding IS NOT NULL WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
WITH n, WITH n,
round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) / round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) /
(sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) * (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) *
@@ -152,7 +153,7 @@ class MemoryGraph:
RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relation, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id, similarity RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relation, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id, similarity
UNION UNION
MATCH (n) MATCH (n)
WHERE n.embedding IS NOT NULL WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
WITH n, WITH n,
round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) / round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) /
(sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) * (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) *
@@ -162,14 +163,14 @@ class MemoryGraph:
RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relation, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relation, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity
ORDER BY similarity DESC ORDER BY similarity DESC
""" """
params = {"n_embedding": n_embedding, "threshold": self.threshold} params = {"n_embedding": n_embedding, "threshold": self.threshold, "user_id": filters["user_id"]}
ans = self.graph.query(cypher_query, params=params) ans = self.graph.query(cypher_query, params=params)
result_relations.extend(ans) result_relations.extend(ans)
return result_relations return result_relations
def search(self, query): def search(self, query, filters):
""" """
Search for memories and related graph data. Search for memories and related graph data.
@@ -182,7 +183,7 @@ class MemoryGraph:
- "entities": List of related graph data based on the query. - "entities": List of related graph data based on the query.
""" """
search_output = self._search(query) search_output = self._search(query, filters)
if not search_output: if not search_output:
return [] return []
@@ -204,15 +205,16 @@ class MemoryGraph:
return search_results return search_results
def delete_all(self): def delete_all(self, filters):
cypher = """ cypher = """
MATCH (n) MATCH (n {user_id: $user_id})
DETACH DELETE n DETACH DELETE n
""" """
self.graph.query(cypher) params = {"user_id": filters["user_id"]}
self.graph.query(cypher, params=params)
def get_all(self): def get_all(self, filters):
""" """
Retrieves all nodes and relationships from the graph database based on optional filtering criteria. Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
@@ -226,10 +228,10 @@ class MemoryGraph:
# return all nodes and relationships # return all nodes and relationships
query = """ query = """
MATCH (n)-[r]->(m) MATCH (n {user_id: $user_id})-[r]->(m {user_id: $user_id})
RETURN n.name AS source, type(r) AS relationship, m.name AS target RETURN n.name AS source, type(r) AS relationship, m.name AS target
""" """
results = self.graph.query(query) results = self.graph.query(query, params={"user_id": filters["user_id"]})
final_results = [] final_results = []
for result in results: for result in results:
@@ -242,7 +244,7 @@ class MemoryGraph:
return final_results return final_results
def _update_relationship(self, source, target, relationship): def _update_relationship(self, source, target, relationship, filters):
""" """
Update or create a relationship between two nodes in the graph. Update or create a relationship between two nodes in the graph.
@@ -258,25 +260,25 @@ class MemoryGraph:
# Check if nodes exist and create them if they don't # Check if nodes exist and create them if they don't
check_and_create_query = """ check_and_create_query = """
MERGE (n1 {name: $source}) MERGE (n1 {name: $source, user_id: $user_id})
MERGE (n2 {name: $target}) MERGE (n2 {name: $target, user_id: $user_id})
""" """
self.graph.query(check_and_create_query, params={"source": source, "target": target}) self.graph.query(check_and_create_query, params={"source": source, "target": target, "user_id": filters["user_id"]})
# Delete any existing relationship between the nodes # Delete any existing relationship between the nodes
delete_query = """ delete_query = """
MATCH (n1 {name: $source})-[r]->(n2 {name: $target}) MATCH (n1 {name: $source, user_id: $user_id})-[r]->(n2 {name: $target, user_id: $user_id})
DELETE r DELETE r
""" """
self.graph.query(delete_query, params={"source": source, "target": target}) self.graph.query(delete_query, params={"source": source, "target": target, "user_id": filters["user_id"]})
# Create the new relationship # Create the new relationship
create_query = f""" create_query = f"""
MATCH (n1 {{name: $source}}), (n2 {{name: $target}}) MATCH (n1 {{name: $source, user_id: $user_id}}), (n2 {{name: $target, user_id: $user_id}})
CREATE (n1)-[r:{relationship}]->(n2) CREATE (n1)-[r:{relationship}]->(n2)
RETURN n1, r, n2 RETURN n1, r, n2
""" """
result = self.graph.query(create_query, params={"source": source, "target": target}) result = self.graph.query(create_query, params={"source": source, "target": target, "user_id": filters["user_id"]})
if not result: if not result:
raise Exception(f"Failed to update or create relationship between {source} and {target}") raise Exception(f"Failed to update or create relationship between {source} and {target}")