Added parallelization to memory method calls to reduce latency (#1803)

Prateek Chhikara authored 2024-09-03 18:50:16 -07:00, committed by GitHub
parent f21ca9b765, commit bf3ad37369
4 changed files with 70 additions and 37 deletions


@@ -15,6 +15,8 @@ from mem0.memory.utils import get_fact_retrieval_messages, parse_messages
 from mem0.configs.prompts import get_update_memory_messages
 from mem0.utils.factory import LlmFactory, EmbedderFactory, VectorStoreFactory
 from mem0.configs.base import MemoryItem, MemoryConfig
+import threading
+import concurrent

 # Setup user config
 setup_config()
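
One caveat in this hunk: `concurrent` is a package, and a bare `import concurrent` does not by itself guarantee that the `concurrent.futures` submodule used later in this diff is loaded; it typically works only because some other import has already pulled `futures` in. The conventional form is the explicit submodule import, sketched below:

    import concurrent.futures

    # Importing the submodule explicitly guarantees that
    # concurrent.futures.ThreadPoolExecutor is resolvable.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = executor.submit(sum, [1, 2, 3])
        print(future.result())  # prints 6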
@@ -96,6 +98,18 @@ class Memory(MemoryBase):
         if isinstance(messages, str):
             messages = [{"role": "user", "content": messages}]

+        thread1 = threading.Thread(target=self._add_to_vector_store, args=(messages, metadata, filters))
+        thread2 = threading.Thread(target=self._add_to_graph, args=(messages, filters))
+
+        thread1.start()
+        thread2.start()
+
+        thread1.join()
+        thread2.join()
+
+        return {"message": "ok"}
+
+    def _add_to_vector_store(self, messages, metadata, filters):
         parsed_messages = parse_messages(messages)
         system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages)
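
The reworked `add` fans out to the two backends with plain `threading.Thread` objects and blocks until both finish. Since both writes are I/O-bound (network round trips to the vector and graph stores), threads cut wall-clock latency even under the GIL. A minimal standalone sketch of the pattern, with stand-in worker functions rather than mem0 code:

    import threading
    import time

    def write_to_vector_store(messages):
        time.sleep(0.5)  # simulate an I/O-bound store call
        print(f"vector store received {len(messages)} message(s)")

    def write_to_graph(messages):
        time.sleep(0.5)  # simulate an I/O-bound graph call
        print(f"graph store received {len(messages)} message(s)")

    messages = [{"role": "user", "content": "hi"}]

    # Start both writes concurrently, then wait for both:
    # wall time is roughly max(0.5, 0.5), not the 1.0 s sum.
    t1 = threading.Thread(target=write_to_vector_store, args=(messages,))
    t2 = threading.Thread(target=write_to_graph, args=(messages,))
    t1.start()
    t2.start()
    t1.join()
    t2.join()

One trade-off of this fire-and-join design: `add` now returns a fixed `{"message": "ok"}` rather than the vector-store result, and an exception raised inside either worker is not re-raised in the calling thread; `Thread.join` does not propagate it.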
@@ -152,16 +166,15 @@ class Memory(MemoryBase):
         capture_event("mem0.add", self)

+    def _add_to_graph(self, messages, filters):
         if self.version == "v1.1" and self.enable_graph:
-            if user_id:
-                self.graph.user_id = user_id
+            if filters["user_id"]:
+                self.graph.user_id = filters["user_id"]
             else:
                 self.graph.user_id = "USER"
-            data = "\n".join([msg["content"] for msg in messages if "content" in msg])
+            data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"])
             added_entities = self.graph.add(data, filters)
-        return {"message": "ok"}

     def get(self, memory_id):
         """
         Retrieve a memory by ID.
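
Besides moving the graph write into its own method, this hunk tightens the text handed to `self.graph.add`: system messages are now dropped before the contents are joined. A quick illustration of the filter on sample data:

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "I live in Berlin."},
        {"role": "assistant", "content": "Noted!"},
    ]

    # Keep only non-system messages that actually carry content.
    data = "\n".join(
        msg["content"]
        for msg in messages
        if "content" in msg and msg["role"] != "system"
    )
    print(data)
    # I live in Berlin.
    # Noted!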
@@ -228,6 +241,30 @@ class Memory(MemoryBase):
filters["run_id"] = run_id
capture_event("mem0.get_all", self, {"filters": len(filters), "limit": limit})
with concurrent.futures.ThreadPoolExecutor() as executor:
future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
future_graph_entities = executor.submit(self.graph.get_all, filters) if self.version == "v1.1" and self.enable_graph else None
all_memories = future_memories.result()
graph_entities = future_graph_entities.result() if future_graph_entities else None
if self.version == "v1.1":
if self.enable_graph:
return {"memories": all_memories, "entities": graph_entities}
else:
return {"memories": all_memories}
else:
warnings.warn(
"The current get_all API output format is deprecated. "
"To use the latest format, set `api_version='v1.1'`. "
"The current format will be removed in mem0ai 1.1.0 and later versions.",
category=DeprecationWarning,
stacklevel=2
)
return all_memories
def _get_all_from_vector_store(self, filters, limit):
memories = self.vector_store.list(filters=filters, limit=limit)
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
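
`get_all` uses the executor flavor of the same idea: the vector-store listing is always submitted, while the graph query is submitted only when graph support is enabled, leaving the future as `None` otherwise. Unlike the bare threads in `add`, `Future.result()` re-raises any worker exception in the caller. A self-contained sketch of the conditional-future pattern; the fetch functions and the `enable_graph` flag are placeholders, not mem0 internals:

    import concurrent.futures
    import time

    enable_graph = True  # stands in for: version == "v1.1" and enable_graph

    def fetch_memories(filters, limit):
        time.sleep(0.3)  # simulate the vector-store round trip
        return [{"memory": "lives in Berlin"}][:limit]

    def fetch_graph_entities(filters):
        time.sleep(0.3)  # simulate the graph-store round trip
        return [{"source": "user", "relation": "lives_in", "target": "Berlin"}]

    filters = {"user_id": "alice"}

    with concurrent.futures.ThreadPoolExecutor() as executor:
        # submit() returns immediately; both calls run on pool threads.
        future_memories = executor.submit(fetch_memories, filters, 100)
        future_entities = executor.submit(fetch_graph_entities, filters) if enable_graph else None

        memories = future_memories.result()  # blocks; re-raises worker errors
        entities = future_entities.result() if future_entities else None

    print({"memories": memories, "entities": entities})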
@@ -259,22 +296,7 @@ class Memory(MemoryBase):
                 }
                 for mem in memories[0]
             ]

-        if self.version == "v1.1":
-            if self.enable_graph:
-                graph_entities = self.graph.get_all(filters)
-                return {"memories": all_memories, "entities": graph_entities}
-            else:
-                return {"memories" : all_memories}
-        else:
-            warnings.warn(
-                "The current get_all API output format is deprecated. "
-                "To use the latest format, set `api_version='v1.1'`. "
-                "The current format will be removed in mem0ai 1.1.0 and later versions.",
-                category=DeprecationWarning,
-                stacklevel=2
-            )
-            return all_memories
+        return all_memories

     def search(
         self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None
@@ -307,6 +329,30 @@ class Memory(MemoryBase):
             )

         capture_event("mem0.search", self, {"filters": len(filters), "limit": limit, "version": self.version})

+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            future_memories = executor.submit(self._search_vector_store, query, filters, limit)
+            future_graph_entities = executor.submit(self.graph.search, query, filters) if self.version == "v1.1" and self.enable_graph else None
+
+            original_memories = future_memories.result()
+            graph_entities = future_graph_entities.result() if future_graph_entities else None
+
+        if self.version == "v1.1":
+            if self.enable_graph:
+                return {"memories": original_memories, "entities": graph_entities}
+            else:
+                return {"memories" : original_memories}
+        else:
+            warnings.warn(
+                "The current get_all API output format is deprecated. "
+                "To use the latest format, set `api_version='v1.1'`. "
+                "The current format will be removed in mem0ai 1.1.0 and later versions.",
+                category=DeprecationWarning,
+                stacklevel=2
+            )
+            return original_memories
+
+    def _search_vector_store(self, query, filters, limit):
         embeddings = self.embedding_model.embed(query)
         memories = self.vector_store.search(
             query=embeddings, limit=limit, filters=filters
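
After this change `search`, like `get_all`, keeps the version-gated output contract: a dict with `memories` (and `entities` when the graph is enabled) under v1.1, a bare list under the deprecated legacy format. A hypothetical caller-side sketch; constructing `Memory()` with defaults is illustrative and not part of this diff:

    from mem0 import Memory

    m = Memory()  # assumes a default config (e.g. an OpenAI key in the environment)

    result = m.search("Where does the user live?", user_id="alice")
    if isinstance(result, dict):            # v1.1 format
        memories = result["memories"]
        entities = result.get("entities")   # only present when the graph is enabled
    else:                                   # legacy list format (deprecated)
        memories = result

Note that the warning text in `search` still says "get_all"; that string was carried over unchanged from `get_all` in this commit.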
@@ -352,21 +398,7 @@ class Memory(MemoryBase):
             for mem in memories
         ]

-        if self.version == "v1.1":
-            if self.enable_graph:
-                graph_entities = self.graph.search(query, filters)
-                return {"memories": original_memories, "entities": graph_entities}
-            else:
-                return {"memories" : original_memories}
-        else:
-            warnings.warn(
-                "The current get_all API output format is deprecated. "
-                "To use the latest format, set `api_version='v1.1'`. "
-                "The current format will be removed in mem0ai 1.1.0 and later versions.",
-                category=DeprecationWarning,
-                stacklevel=2
-            )
-            return original_memories
+        return original_memories

     def update(self, memory_id, data):
         """