From 0d45c61aa3584935092e59ea4c781a2d388c0619 Mon Sep 17 00:00:00 2001 From: Prateek Chhikara <46902268+prateekchhikara@users.noreply.github.com> Date: Mon, 30 Sep 2024 16:55:01 -0700 Subject: [PATCH] Graph memory bug fix (#1932) --- mem0/memory/graph_memory.py | 12 ++++++------ mem0/memory/main.py | 10 +++++++++- pyproject.toml | 2 +- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/mem0/memory/graph_memory.py b/mem0/memory/graph_memory.py index 2e09399e..f25d7552 100644 --- a/mem0/memory/graph_memory.py +++ b/mem0/memory/graph_memory.py @@ -201,9 +201,9 @@ class MemoryGraph: cypher_query = """ MATCH (n) WHERE n.embedding IS NOT NULL AND n.user_id = $user_id - WITH n, - round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) / - (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) * + WITH n, + round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) / + (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) * sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity WHERE similarity >= $threshold MATCH (n)-[r]->(m) @@ -211,9 +211,9 @@ class MemoryGraph: UNION MATCH (n) WHERE n.embedding IS NOT NULL AND n.user_id = $user_id - WITH n, - round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) / - (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) * + WITH n, + round(reduce(dot = 0.0, i IN range(0, size(n.embedding)-1) | dot + n.embedding[i] * $n_embedding[i]) / + (sqrt(reduce(l2 = 0.0, i IN range(0, size(n.embedding)-1) | l2 + n.embedding[i] * n.embedding[i])) * sqrt(reduce(l2 = 0.0, i IN range(0, size($n_embedding)-1) | l2 + $n_embedding[i] * $n_embedding[i]))), 4) AS similarity WHERE similarity >= $threshold MATCH (m)-[r]->(n) diff --git a/mem0/memory/main.py b/mem0/memory/main.py index 7d82c632..1aecb467 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -215,10 +215,14 @@ class Memory(MemoryBase): if self.version == "v1.1" and self.enable_graph: if filters["user_id"]: self.graph.user_id = filters["user_id"] + elif filters["agent_id"]: + self.graph.agent_id = filters["agent_id"] + elif filters["run_id"]: + self.graph.run_id = filters["run_id"] else: self.graph.user_id = "USER" data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"]) - self.graph.add(data, filters) + added_entities = self.graph.add(data, filters) return added_entities @@ -289,6 +293,8 @@ class Memory(MemoryBase): executor.submit(self.graph.get_all, filters, limit) if self.version == "v1.1" and self.enable_graph else None ) + concurrent.futures.wait([future_memories, future_graph_entities] if future_graph_entities else [future_memories]) + all_memories = future_memories.result() graph_entities = future_graph_entities.result() if future_graph_entities else None @@ -379,6 +385,8 @@ class Memory(MemoryBase): else None ) + concurrent.futures.wait([future_memories, future_graph_entities] if future_graph_entities else [future_memories]) + original_memories = future_memories.result() graph_entities = future_graph_entities.result() if future_graph_entities else None diff --git a/pyproject.toml b/pyproject.toml index fe1f44ee..4fb8f76d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "mem0ai" -version = "0.1.16" +version = "0.1.17" description = "Long-term memory for AI Agents" authors = ["Mem0 "] exclude = [