From bf3ad37369bb91267ce33b90073e40abb46c3a95 Mon Sep 17 00:00:00 2001 From: Prateek Chhikara <46902268+prateekchhikara@users.noreply.github.com> Date: Tue, 3 Sep 2024 18:50:16 -0700 Subject: [PATCH] Added parallelization to memory method calls to reduce latency (#1803) --- .gitignore | 1 + .../memory/{main_graph.py => graph_memory.py} | 0 mem0/memory/main.py | 104 ++++++++++++------ pyproject.toml | 2 +- 4 files changed, 70 insertions(+), 37 deletions(-) rename mem0/memory/{main_graph.py => graph_memory.py} (100%) diff --git a/.gitignore b/.gitignore index bad8f40b..ca912601 100644 --- a/.gitignore +++ b/.gitignore @@ -184,3 +184,4 @@ notebooks/*.yaml eval/ qdrant_storage/ .crossnote +testing.ipynb diff --git a/mem0/memory/main_graph.py b/mem0/memory/graph_memory.py similarity index 100% rename from mem0/memory/main_graph.py rename to mem0/memory/graph_memory.py diff --git a/mem0/memory/main.py b/mem0/memory/main.py index 0a918ba1..a0fe56c7 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -15,6 +15,8 @@ from mem0.memory.utils import get_fact_retrieval_messages, parse_messages from mem0.configs.prompts import get_update_memory_messages from mem0.utils.factory import LlmFactory, EmbedderFactory, VectorStoreFactory from mem0.configs.base import MemoryItem, MemoryConfig +import threading +import concurrent.futures # Setup user config setup_config() @@ -96,6 +98,18 @@ class Memory(MemoryBase): if isinstance(messages, str): messages = [{"role": "user", "content": messages}] + thread1 = threading.Thread(target=self._add_to_vector_store, args=(messages, metadata, filters)) + thread2 = threading.Thread(target=self._add_to_graph, args=(messages, filters)) + + thread1.start() + thread2.start() + + thread1.join() + thread2.join() + + return {"message": "ok"} + + def _add_to_vector_store(self, messages, metadata, filters): parsed_messages = parse_messages(messages) system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages) @@ -152,16 +166,15 @@ class 
Memory(MemoryBase): capture_event("mem0.add", self) + def _add_to_graph(self, messages, filters): if self.version == "v1.1" and self.enable_graph: - if user_id: - self.graph.user_id = user_id + if filters.get("user_id"): + self.graph.user_id = filters["user_id"] else: self.graph.user_id = "USER" - data = "\n".join([msg["content"] for msg in messages if "content" in msg]) + data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg.get("role") != "system"]) added_entities = self.graph.add(data, filters) - return {"message": "ok"} - def get(self, memory_id): """ Retrieve a memory by ID. @@ -228,6 +241,30 @@ class Memory(MemoryBase): filters["run_id"] = run_id capture_event("mem0.get_all", self, {"filters": len(filters), "limit": limit}) + + with concurrent.futures.ThreadPoolExecutor() as executor: + future_memories = executor.submit(self._get_all_from_vector_store, filters, limit) + future_graph_entities = executor.submit(self.graph.get_all, filters) if self.version == "v1.1" and self.enable_graph else None + + all_memories = future_memories.result() + graph_entities = future_graph_entities.result() if future_graph_entities else None + + if self.version == "v1.1": + if self.enable_graph: + return {"memories": all_memories, "entities": graph_entities} + else: + return {"memories": all_memories} + else: + warnings.warn( + "The current get_all API output format is deprecated. " + "To use the latest format, set `api_version='v1.1'`. 
" + "The current format will be removed in mem0ai 1.1.0 and later versions.", + category=DeprecationWarning, + stacklevel=2 + ) + return all_memories + + def _get_all_from_vector_store(self, filters, limit): memories = self.vector_store.list(filters=filters, limit=limit) excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"} @@ -259,22 +296,7 @@ class Memory(MemoryBase): } for mem in memories[0] ] - - if self.version == "v1.1": - if self.enable_graph: - graph_entities = self.graph.get_all(filters) - return {"memories": all_memories, "entities": graph_entities} - else: - return {"memories" : all_memories} - else: - warnings.warn( - "The current get_all API output format is deprecated. " - "To use the latest format, set `api_version='v1.1'`. " - "The current format will be removed in mem0ai 1.1.0 and later versions.", - category=DeprecationWarning, - stacklevel=2 - ) - return all_memories + return all_memories def search( self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None @@ -307,6 +329,30 @@ class Memory(MemoryBase): ) capture_event("mem0.search", self, {"filters": len(filters), "limit": limit, "version": self.version}) + + with concurrent.futures.ThreadPoolExecutor() as executor: + future_memories = executor.submit(self._search_vector_store, query, filters, limit) + future_graph_entities = executor.submit(self.graph.search, query, filters) if self.version == "v1.1" and self.enable_graph else None + + original_memories = future_memories.result() + graph_entities = future_graph_entities.result() if future_graph_entities else None + + if self.version == "v1.1": + if self.enable_graph: + return {"memories": original_memories, "entities": graph_entities} + else: + return {"memories" : original_memories} + else: + warnings.warn( + "The current search API output format is deprecated. " + "To use the latest format, set `api_version='v1.1'`. 
" + "The current format will be removed in mem0ai 1.1.0 and later versions.", + category=DeprecationWarning, + stacklevel=2 + ) + return original_memories + + def _search_vector_store(self, query, filters, limit): embeddings = self.embedding_model.embed(query) memories = self.vector_store.search( query=embeddings, limit=limit, filters=filters @@ -352,21 +398,7 @@ class Memory(MemoryBase): for mem in memories ] - if self.version == "v1.1": - if self.enable_graph: - graph_entities = self.graph.search(query, filters) - return {"memories": original_memories, "entities": graph_entities} - else: - return {"memories" : original_memories} - else: - warnings.warn( - "The current get_all API output format is deprecated. " - "To use the latest format, set `api_version='v1.1'`. " - "The current format will be removed in mem0ai 1.1.0 and later versions.", - category=DeprecationWarning, - stacklevel=2 - ) - return original_memories + return original_memories def update(self, memory_id, data): """ diff --git a/pyproject.toml b/pyproject.toml index 43d9f93c..c70baed7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "mem0ai" -version = "0.1.8" +version = "0.1.9" description = "Long-term memory for AI Agents" authors = ["Mem0 "] exclude = [