Graph memory bug fix (#1932)

This commit is contained in:
Prateek Chhikara
2024-09-30 16:55:01 -07:00
committed by GitHub
parent f324462cc3
commit 0d45c61aa3
3 changed files with 16 additions and 8 deletions

View File

@@ -201,9 +201,9 @@ class MemoryGraph:
cypher_query = """ cypher_query = """
MATCH (n) MATCH (n)
WHERE n.embedding IS NOT NULL AND n.user_id = $user_id WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
WITH n, WITH n,
round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) / round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) /
(sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) * (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) *
sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity
WHERE similarity >= $threshold WHERE similarity >= $threshold
MATCH (n)-[r]->(m) MATCH (n)-[r]->(m)
@@ -211,9 +211,9 @@ class MemoryGraph:
UNION UNION
MATCH (n) MATCH (n)
WHERE n.embedding IS NOT NULL AND n.user_id = $user_id WHERE n.embedding IS NOT NULL AND n.user_id = $user_id
WITH n, WITH n,
round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) / round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) /
(sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) * (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) *
sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity
WHERE similarity >= $threshold WHERE similarity >= $threshold
MATCH (m)-[r]->(n) MATCH (m)-[r]->(n)

View File

@@ -215,10 +215,14 @@ class Memory(MemoryBase):
if self.version == "v1.1" and self.enable_graph: if self.version == "v1.1" and self.enable_graph:
if filters["user_id"]: if filters["user_id"]:
self.graph.user_id = filters["user_id"] self.graph.user_id = filters["user_id"]
elif filters["agent_id"]:
self.graph.agent_id = filters["agent_id"]
elif filters["run_id"]:
self.graph.run_id = filters["run_id"]
else: else:
self.graph.user_id = "USER" self.graph.user_id = "USER"
data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"]) data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"])
self.graph.add(data, filters) added_entities = self.graph.add(data, filters)
return added_entities return added_entities
@@ -289,6 +293,8 @@ class Memory(MemoryBase):
executor.submit(self.graph.get_all, filters, limit) if self.version == "v1.1" and self.enable_graph else None executor.submit(self.graph.get_all, filters, limit) if self.version == "v1.1" and self.enable_graph else None
) )
concurrent.futures.wait([future_memories, future_graph_entities] if future_graph_entities else [future_memories])
all_memories = future_memories.result() all_memories = future_memories.result()
graph_entities = future_graph_entities.result() if future_graph_entities else None graph_entities = future_graph_entities.result() if future_graph_entities else None
@@ -379,6 +385,8 @@ class Memory(MemoryBase):
else None else None
) )
concurrent.futures.wait([future_memories, future_graph_entities] if future_graph_entities else [future_memories])
original_memories = future_memories.result() original_memories = future_memories.result()
graph_entities = future_graph_entities.result() if future_graph_entities else None graph_entities = future_graph_entities.result() if future_graph_entities else None

View File

@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "mem0ai" name = "mem0ai"
version = "0.1.16" version = "0.1.17"
description = "Long-term memory for AI Agents" description = "Long-term memory for AI Agents"
authors = ["Mem0 <founders@mem0.ai>"] authors = ["Mem0 <founders@mem0.ai>"]
exclude = [ exclude = [