Added user_id support for graph memory
This commit is contained in:
@@ -177,7 +177,7 @@ class Memory(MemoryBase):
|
|||||||
self.graph.user_id = user_id
|
self.graph.user_id = user_id
|
||||||
else:
|
else:
|
||||||
self.graph.user_id = "USER"
|
self.graph.user_id = "USER"
|
||||||
added_entities = self.graph.add(data)
|
added_entities = self.graph.add(data, filters)
|
||||||
|
|
||||||
return {"message": "ok"}
|
return {"message": "ok"}
|
||||||
|
|
||||||
@@ -281,7 +281,7 @@ class Memory(MemoryBase):
|
|||||||
|
|
||||||
if self.version == "v1.1":
|
if self.version == "v1.1":
|
||||||
if self.enable_graph:
|
if self.enable_graph:
|
||||||
graph_entities = self.graph.get_all()
|
graph_entities = self.graph.get_all(filters)
|
||||||
return {"memories": all_memories, "entities": graph_entities}
|
return {"memories": all_memories, "entities": graph_entities}
|
||||||
else:
|
else:
|
||||||
return {"memories" : all_memories}
|
return {"memories" : all_memories}
|
||||||
@@ -374,7 +374,7 @@ class Memory(MemoryBase):
|
|||||||
|
|
||||||
if self.version == "v1.1":
|
if self.version == "v1.1":
|
||||||
if self.enable_graph:
|
if self.enable_graph:
|
||||||
graph_entities = self.graph.search(query)
|
graph_entities = self.graph.search(query, filters)
|
||||||
return {"memories": original_memories, "entities": graph_entities}
|
return {"memories": original_memories, "entities": graph_entities}
|
||||||
else:
|
else:
|
||||||
return {"memories" : original_memories}
|
return {"memories" : original_memories}
|
||||||
@@ -442,7 +442,7 @@ class Memory(MemoryBase):
|
|||||||
self._delete_memory_tool(memory.id)
|
self._delete_memory_tool(memory.id)
|
||||||
|
|
||||||
if self.version == "v1.1" and self.enable_graph:
|
if self.version == "v1.1" and self.enable_graph:
|
||||||
self.graph.delete_all()
|
self.graph.delete_all(filters)
|
||||||
|
|
||||||
return {'message': 'Memories deleted successfully!'}
|
return {'message': 'Memories deleted successfully!'}
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ class MemoryGraph:
|
|||||||
self.user_id = None
|
self.user_id = None
|
||||||
self.threshold = 0.7
|
self.threshold = 0.7
|
||||||
|
|
||||||
def add(self, data):
|
def add(self, data, filters):
|
||||||
"""
|
"""
|
||||||
Adds data to the graph.
|
Adds data to the graph.
|
||||||
|
|
||||||
@@ -38,7 +38,7 @@ class MemoryGraph:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# retrieve the search results
|
# retrieve the search results
|
||||||
search_output = self._search(data)
|
search_output = self._search(data, filters)
|
||||||
|
|
||||||
if self.config.graph_store.custom_prompt:
|
if self.config.graph_store.custom_prompt:
|
||||||
messages=[
|
messages=[
|
||||||
@@ -74,7 +74,7 @@ class MemoryGraph:
|
|||||||
if item['name'] == "add_graph_memory":
|
if item['name'] == "add_graph_memory":
|
||||||
to_be_added.append(item['arguments'])
|
to_be_added.append(item['arguments'])
|
||||||
elif item['name'] == "update_graph_memory":
|
elif item['name'] == "update_graph_memory":
|
||||||
self._update_relationship(item['arguments']['source'], item['arguments']['destination'], item['arguments']['relationship'])
|
self._update_relationship(item['arguments']['source'], item['arguments']['destination'], item['arguments']['relationship'], filters)
|
||||||
elif item['name'] == "noop":
|
elif item['name'] == "noop":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -91,10 +91,10 @@ class MemoryGraph:
|
|||||||
|
|
||||||
# Updated Cypher query to include node types and embeddings
|
# Updated Cypher query to include node types and embeddings
|
||||||
cypher = f"""
|
cypher = f"""
|
||||||
MERGE (n:{source_type} {{name: $source_name}})
|
MERGE (n:{source_type} {{name: $source_name, user_id: $user_id}})
|
||||||
ON CREATE SET n.created = timestamp(), n.embedding = $source_embedding
|
ON CREATE SET n.created = timestamp(), n.embedding = $source_embedding
|
||||||
ON MATCH SET n.embedding = $source_embedding
|
ON MATCH SET n.embedding = $source_embedding
|
||||||
MERGE (m:{destination_type} {{name: $dest_name}})
|
MERGE (m:{destination_type} {{name: $dest_name, user_id: $user_id}})
|
||||||
ON CREATE SET m.created = timestamp(), m.embedding = $dest_embedding
|
ON CREATE SET m.created = timestamp(), m.embedding = $dest_embedding
|
||||||
ON MATCH SET m.embedding = $dest_embedding
|
ON MATCH SET m.embedding = $dest_embedding
|
||||||
MERGE (n)-[rel:{relation}]->(m)
|
MERGE (n)-[rel:{relation}]->(m)
|
||||||
@@ -106,16 +106,17 @@ class MemoryGraph:
|
|||||||
"source_name": source,
|
"source_name": source,
|
||||||
"dest_name": destination,
|
"dest_name": destination,
|
||||||
"source_embedding": source_embedding,
|
"source_embedding": source_embedding,
|
||||||
"dest_embedding": dest_embedding
|
"dest_embedding": dest_embedding,
|
||||||
|
"user_id": filters["user_id"]
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = self.graph.query(cypher, params=params)
|
_ = self.graph.query(cypher, params=params)
|
||||||
|
|
||||||
|
|
||||||
def _search(self, query):
|
def _search(self, query, filters):
|
||||||
search_results = self.llm.generate_response(
|
search_results = self.llm.generate_response(
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {self.user_id} as the source node. Extract the entities."},
|
{"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities."},
|
||||||
{"role": "user", "content": query},
|
{"role": "user", "content": query},
|
||||||
],
|
],
|
||||||
tools = [SEARCH_TOOL]
|
tools = [SEARCH_TOOL]
|
||||||
@@ -142,7 +143,7 @@ class MemoryGraph:
|
|||||||
|
|
||||||
cypher_query = """
|
cypher_query = """
|
||||||
MATCH (n)
|
MATCH (n)
|
||||||
WHERE n.embedding IS NOT NULL
|
WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
|
||||||
WITH n,
|
WITH n,
|
||||||
round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) /
|
round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) /
|
||||||
(sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) *
|
(sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) *
|
||||||
@@ -152,7 +153,7 @@ class MemoryGraph:
|
|||||||
RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relation, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id, similarity
|
RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relation, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id, similarity
|
||||||
UNION
|
UNION
|
||||||
MATCH (n)
|
MATCH (n)
|
||||||
WHERE n.embedding IS NOT NULL
|
WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
|
||||||
WITH n,
|
WITH n,
|
||||||
round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) /
|
round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) /
|
||||||
(sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) *
|
(sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) *
|
||||||
@@ -162,14 +163,14 @@ class MemoryGraph:
|
|||||||
RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relation, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity
|
RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relation, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity
|
||||||
ORDER BY similarity DESC
|
ORDER BY similarity DESC
|
||||||
"""
|
"""
|
||||||
params = {"n_embedding": n_embedding, "threshold": self.threshold}
|
params = {"n_embedding": n_embedding, "threshold": self.threshold, "user_id": filters["user_id"]}
|
||||||
ans = self.graph.query(cypher_query, params=params)
|
ans = self.graph.query(cypher_query, params=params)
|
||||||
result_relations.extend(ans)
|
result_relations.extend(ans)
|
||||||
|
|
||||||
return result_relations
|
return result_relations
|
||||||
|
|
||||||
|
|
||||||
def search(self, query):
|
def search(self, query, filters):
|
||||||
"""
|
"""
|
||||||
Search for memories and related graph data.
|
Search for memories and related graph data.
|
||||||
|
|
||||||
@@ -182,7 +183,7 @@ class MemoryGraph:
|
|||||||
- "entities": List of related graph data based on the query.
|
- "entities": List of related graph data based on the query.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
search_output = self._search(query)
|
search_output = self._search(query, filters)
|
||||||
|
|
||||||
if not search_output:
|
if not search_output:
|
||||||
return []
|
return []
|
||||||
@@ -204,15 +205,16 @@ class MemoryGraph:
|
|||||||
return search_results
|
return search_results
|
||||||
|
|
||||||
|
|
||||||
def delete_all(self):
|
def delete_all(self, filters):
|
||||||
cypher = """
|
cypher = """
|
||||||
MATCH (n)
|
MATCH (n {user_id: $user_id})
|
||||||
DETACH DELETE n
|
DETACH DELETE n
|
||||||
"""
|
"""
|
||||||
self.graph.query(cypher)
|
params = {"user_id": filters["user_id"]}
|
||||||
|
self.graph.query(cypher, params=params)
|
||||||
|
|
||||||
|
|
||||||
def get_all(self):
|
def get_all(self, filters):
|
||||||
"""
|
"""
|
||||||
Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
|
Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
|
||||||
|
|
||||||
@@ -226,10 +228,10 @@ class MemoryGraph:
|
|||||||
|
|
||||||
# return all nodes and relationships
|
# return all nodes and relationships
|
||||||
query = """
|
query = """
|
||||||
MATCH (n)-[r]->(m)
|
MATCH (n {user_id: $user_id})-[r]->(m {user_id: $user_id})
|
||||||
RETURN n.name AS source, type(r) AS relationship, m.name AS target
|
RETURN n.name AS source, type(r) AS relationship, m.name AS target
|
||||||
"""
|
"""
|
||||||
results = self.graph.query(query)
|
results = self.graph.query(query, params={"user_id": filters["user_id"]})
|
||||||
|
|
||||||
final_results = []
|
final_results = []
|
||||||
for result in results:
|
for result in results:
|
||||||
@@ -242,7 +244,7 @@ class MemoryGraph:
|
|||||||
return final_results
|
return final_results
|
||||||
|
|
||||||
|
|
||||||
def _update_relationship(self, source, target, relationship):
|
def _update_relationship(self, source, target, relationship, filters):
|
||||||
"""
|
"""
|
||||||
Update or create a relationship between two nodes in the graph.
|
Update or create a relationship between two nodes in the graph.
|
||||||
|
|
||||||
@@ -258,25 +260,25 @@ class MemoryGraph:
|
|||||||
|
|
||||||
# Check if nodes exist and create them if they don't
|
# Check if nodes exist and create them if they don't
|
||||||
check_and_create_query = """
|
check_and_create_query = """
|
||||||
MERGE (n1 {name: $source})
|
MERGE (n1 {name: $source, user_id: $user_id})
|
||||||
MERGE (n2 {name: $target})
|
MERGE (n2 {name: $target, user_id: $user_id})
|
||||||
"""
|
"""
|
||||||
self.graph.query(check_and_create_query, params={"source": source, "target": target})
|
self.graph.query(check_and_create_query, params={"source": source, "target": target, "user_id": filters["user_id"]})
|
||||||
|
|
||||||
# Delete any existing relationship between the nodes
|
# Delete any existing relationship between the nodes
|
||||||
delete_query = """
|
delete_query = """
|
||||||
MATCH (n1 {name: $source})-[r]->(n2 {name: $target})
|
MATCH (n1 {name: $source, user_id: $user_id})-[r]->(n2 {name: $target, user_id: $user_id})
|
||||||
DELETE r
|
DELETE r
|
||||||
"""
|
"""
|
||||||
self.graph.query(delete_query, params={"source": source, "target": target})
|
self.graph.query(delete_query, params={"source": source, "target": target, "user_id": filters["user_id"]})
|
||||||
|
|
||||||
# Create the new relationship
|
# Create the new relationship
|
||||||
create_query = f"""
|
create_query = f"""
|
||||||
MATCH (n1 {{name: $source}}), (n2 {{name: $target}})
|
MATCH (n1 {{name: $source, user_id: $user_id}}), (n2 {{name: $target, user_id: $user_id}})
|
||||||
CREATE (n1)-[r:{relationship}]->(n2)
|
CREATE (n1)-[r:{relationship}]->(n2)
|
||||||
RETURN n1, r, n2
|
RETURN n1, r, n2
|
||||||
"""
|
"""
|
||||||
result = self.graph.query(create_query, params={"source": source, "target": target})
|
result = self.graph.query(create_query, params={"source": source, "target": target, "user_id": filters["user_id"]})
|
||||||
|
|
||||||
if not result:
|
if not result:
|
||||||
raise Exception(f"Failed to update or create relationship between {source} and {target}")
|
raise Exception(f"Failed to update or create relationship between {source} and {target}")
|
||||||
|
|||||||
Reference in New Issue
Block a user