Added user_id support for graph memory
This commit is contained in:
@@ -177,7 +177,7 @@ class Memory(MemoryBase):
|
||||
self.graph.user_id = user_id
|
||||
else:
|
||||
self.graph.user_id = "USER"
|
||||
added_entities = self.graph.add(data)
|
||||
added_entities = self.graph.add(data, filters)
|
||||
|
||||
return {"message": "ok"}
|
||||
|
||||
@@ -281,7 +281,7 @@ class Memory(MemoryBase):
|
||||
|
||||
if self.version == "v1.1":
|
||||
if self.enable_graph:
|
||||
graph_entities = self.graph.get_all()
|
||||
graph_entities = self.graph.get_all(filters)
|
||||
return {"memories": all_memories, "entities": graph_entities}
|
||||
else:
|
||||
return {"memories" : all_memories}
|
||||
@@ -374,7 +374,7 @@ class Memory(MemoryBase):
|
||||
|
||||
if self.version == "v1.1":
|
||||
if self.enable_graph:
|
||||
graph_entities = self.graph.search(query)
|
||||
graph_entities = self.graph.search(query, filters)
|
||||
return {"memories": original_memories, "entities": graph_entities}
|
||||
else:
|
||||
return {"memories" : original_memories}
|
||||
@@ -442,7 +442,7 @@ class Memory(MemoryBase):
|
||||
self._delete_memory_tool(memory.id)
|
||||
|
||||
if self.version == "v1.1" and self.enable_graph:
|
||||
self.graph.delete_all()
|
||||
self.graph.delete_all(filters)
|
||||
|
||||
return {'message': 'Memories deleted successfully!'}
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ class MemoryGraph:
|
||||
self.user_id = None
|
||||
self.threshold = 0.7
|
||||
|
||||
def add(self, data):
|
||||
def add(self, data, filters):
|
||||
"""
|
||||
Adds data to the graph.
|
||||
|
||||
@@ -38,7 +38,7 @@ class MemoryGraph:
|
||||
"""
|
||||
|
||||
# retrieve the search results
|
||||
search_output = self._search(data)
|
||||
search_output = self._search(data, filters)
|
||||
|
||||
if self.config.graph_store.custom_prompt:
|
||||
messages=[
|
||||
@@ -74,7 +74,7 @@ class MemoryGraph:
|
||||
if item['name'] == "add_graph_memory":
|
||||
to_be_added.append(item['arguments'])
|
||||
elif item['name'] == "update_graph_memory":
|
||||
self._update_relationship(item['arguments']['source'], item['arguments']['destination'], item['arguments']['relationship'])
|
||||
self._update_relationship(item['arguments']['source'], item['arguments']['destination'], item['arguments']['relationship'], filters)
|
||||
elif item['name'] == "noop":
|
||||
continue
|
||||
|
||||
@@ -91,10 +91,10 @@ class MemoryGraph:
|
||||
|
||||
# Updated Cypher query to include node types and embeddings
|
||||
cypher = f"""
|
||||
MERGE (n:{source_type} {{name: $source_name}})
|
||||
MERGE (n:{source_type} {{name: $source_name, user_id: $user_id}})
|
||||
ON CREATE SET n.created = timestamp(), n.embedding = $source_embedding
|
||||
ON MATCH SET n.embedding = $source_embedding
|
||||
MERGE (m:{destination_type} {{name: $dest_name}})
|
||||
MERGE (m:{destination_type} {{name: $dest_name, user_id: $user_id}})
|
||||
ON CREATE SET m.created = timestamp(), m.embedding = $dest_embedding
|
||||
ON MATCH SET m.embedding = $dest_embedding
|
||||
MERGE (n)-[rel:{relation}]->(m)
|
||||
@@ -106,16 +106,17 @@ class MemoryGraph:
|
||||
"source_name": source,
|
||||
"dest_name": destination,
|
||||
"source_embedding": source_embedding,
|
||||
"dest_embedding": dest_embedding
|
||||
"dest_embedding": dest_embedding,
|
||||
"user_id": filters["user_id"]
|
||||
}
|
||||
|
||||
_ = self.graph.query(cypher, params=params)
|
||||
|
||||
|
||||
def _search(self, query):
|
||||
def _search(self, query, filters):
|
||||
search_results = self.llm.generate_response(
|
||||
messages=[
|
||||
{"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {self.user_id} as the source node. Extract the entities."},
|
||||
{"role": "system", "content": f"You are a smart assistant who understands the entities, their types, and relations in a given text. If user message contains self reference such as 'I', 'me', 'my' etc. then use {filters['user_id']} as the source node. Extract the entities."},
|
||||
{"role": "user", "content": query},
|
||||
],
|
||||
tools = [SEARCH_TOOL]
|
||||
@@ -142,7 +143,7 @@ class MemoryGraph:
|
||||
|
||||
cypher_query = """
|
||||
MATCH (n)
|
||||
WHERE n.embedding IS NOT NULL
|
||||
WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
|
||||
WITH n,
|
||||
round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) /
|
||||
(sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) *
|
||||
@@ -152,7 +153,7 @@ class MemoryGraph:
|
||||
RETURN n.name AS source, elementId(n) AS source_id, type(r) AS relation, elementId(r) AS relation_id, m.name AS destination, elementId(m) AS destination_id, similarity
|
||||
UNION
|
||||
MATCH (n)
|
||||
WHERE n.embedding IS NOT NULL
|
||||
WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
|
||||
WITH n,
|
||||
round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) /
|
||||
(sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) *
|
||||
@@ -162,14 +163,14 @@ class MemoryGraph:
|
||||
RETURN m.name AS source, elementId(m) AS source_id, type(r) AS relation, elementId(r) AS relation_id, n.name AS destination, elementId(n) AS destination_id, similarity
|
||||
ORDER BY similarity DESC
|
||||
"""
|
||||
params = {"n_embedding": n_embedding, "threshold": self.threshold}
|
||||
params = {"n_embedding": n_embedding, "threshold": self.threshold, "user_id": filters["user_id"]}
|
||||
ans = self.graph.query(cypher_query, params=params)
|
||||
result_relations.extend(ans)
|
||||
|
||||
return result_relations
|
||||
|
||||
|
||||
def search(self, query):
|
||||
def search(self, query, filters):
|
||||
"""
|
||||
Search for memories and related graph data.
|
||||
|
||||
@@ -182,7 +183,7 @@ class MemoryGraph:
|
||||
- "entities": List of related graph data based on the query.
|
||||
"""
|
||||
|
||||
search_output = self._search(query)
|
||||
search_output = self._search(query, filters)
|
||||
|
||||
if not search_output:
|
||||
return []
|
||||
@@ -204,15 +205,16 @@ class MemoryGraph:
|
||||
return search_results
|
||||
|
||||
|
||||
def delete_all(self):
|
||||
def delete_all(self, filters):
|
||||
cypher = """
|
||||
MATCH (n)
|
||||
MATCH (n {user_id: $user_id})
|
||||
DETACH DELETE n
|
||||
"""
|
||||
self.graph.query(cypher)
|
||||
params = {"user_id": filters["user_id"]}
|
||||
self.graph.query(cypher, params=params)
|
||||
|
||||
|
||||
def get_all(self):
|
||||
def get_all(self, filters):
|
||||
"""
|
||||
Retrieves all nodes and relationships from the graph database based on optional filtering criteria.
|
||||
|
||||
@@ -226,10 +228,10 @@ class MemoryGraph:
|
||||
|
||||
# return all nodes and relationships
|
||||
query = """
|
||||
MATCH (n)-[r]->(m)
|
||||
MATCH (n {user_id: $user_id})-[r]->(m {user_id: $user_id})
|
||||
RETURN n.name AS source, type(r) AS relationship, m.name AS target
|
||||
"""
|
||||
results = self.graph.query(query)
|
||||
results = self.graph.query(query, params={"user_id": filters["user_id"]})
|
||||
|
||||
final_results = []
|
||||
for result in results:
|
||||
@@ -242,7 +244,7 @@ class MemoryGraph:
|
||||
return final_results
|
||||
|
||||
|
||||
def _update_relationship(self, source, target, relationship):
|
||||
def _update_relationship(self, source, target, relationship, filters):
|
||||
"""
|
||||
Update or create a relationship between two nodes in the graph.
|
||||
|
||||
@@ -258,25 +260,25 @@ class MemoryGraph:
|
||||
|
||||
# Check if nodes exist and create them if they don't
|
||||
check_and_create_query = """
|
||||
MERGE (n1 {name: $source})
|
||||
MERGE (n2 {name: $target})
|
||||
MERGE (n1 {name: $source, user_id: $user_id})
|
||||
MERGE (n2 {name: $target, user_id: $user_id})
|
||||
"""
|
||||
self.graph.query(check_and_create_query, params={"source": source, "target": target})
|
||||
self.graph.query(check_and_create_query, params={"source": source, "target": target, "user_id": filters["user_id"]})
|
||||
|
||||
# Delete any existing relationship between the nodes
|
||||
delete_query = """
|
||||
MATCH (n1 {name: $source})-[r]->(n2 {name: $target})
|
||||
MATCH (n1 {name: $source, user_id: $user_id})-[r]->(n2 {name: $target, user_id: $user_id})
|
||||
DELETE r
|
||||
"""
|
||||
self.graph.query(delete_query, params={"source": source, "target": target})
|
||||
self.graph.query(delete_query, params={"source": source, "target": target, "user_id": filters["user_id"]})
|
||||
|
||||
# Create the new relationship
|
||||
create_query = f"""
|
||||
MATCH (n1 {{name: $source}}), (n2 {{name: $target}})
|
||||
MATCH (n1 {{name: $source, user_id: $user_id}}), (n2 {{name: $target, user_id: $user_id}})
|
||||
CREATE (n1)-[r:{relationship}]->(n2)
|
||||
RETURN n1, r, n2
|
||||
"""
|
||||
result = self.graph.query(create_query, params={"source": source, "target": target})
|
||||
result = self.graph.query(create_query, params={"source": source, "target": target, "user_id": filters["user_id"]})
|
||||
|
||||
if not result:
|
||||
raise Exception(f"Failed to update or create relationship between {source} and {target}")
|
||||
|
||||
Reference in New Issue
Block a user