Added parallelization to memory method calls to reduce latency (#1803)
This commit is contained in:
@@ -15,6 +15,8 @@ from mem0.memory.utils import get_fact_retrieval_messages, parse_messages
|
||||
from mem0.configs.prompts import get_update_memory_messages
|
||||
from mem0.utils.factory import LlmFactory, EmbedderFactory, VectorStoreFactory
|
||||
from mem0.configs.base import MemoryItem, MemoryConfig
|
||||
import threading
|
||||
import concurrent
|
||||
|
||||
# Setup user config
|
||||
setup_config()
|
||||
@@ -96,6 +98,18 @@ class Memory(MemoryBase):
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
|
||||
thread1 = threading.Thread(target=self._add_to_vector_store, args=(messages, metadata, filters))
|
||||
thread2 = threading.Thread(target=self._add_to_graph, args=(messages, filters))
|
||||
|
||||
thread1.start()
|
||||
thread2.start()
|
||||
|
||||
thread1.join()
|
||||
thread2.join()
|
||||
|
||||
return {"message": "ok"}
|
||||
|
||||
def _add_to_vector_store(self, messages, metadata, filters):
|
||||
parsed_messages = parse_messages(messages)
|
||||
|
||||
system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages)
|
||||
@@ -152,16 +166,15 @@ class Memory(MemoryBase):
|
||||
|
||||
capture_event("mem0.add", self)
|
||||
|
||||
def _add_to_graph(self, messages, filters):
|
||||
if self.version == "v1.1" and self.enable_graph:
|
||||
if user_id:
|
||||
self.graph.user_id = user_id
|
||||
if filters["user_id"]:
|
||||
self.graph.user_id = filters["user_id"]
|
||||
else:
|
||||
self.graph.user_id = "USER"
|
||||
data = "\n".join([msg["content"] for msg in messages if "content" in msg])
|
||||
data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"])
|
||||
added_entities = self.graph.add(data, filters)
|
||||
|
||||
return {"message": "ok"}
|
||||
|
||||
def get(self, memory_id):
|
||||
"""
|
||||
Retrieve a memory by ID.
|
||||
@@ -228,6 +241,30 @@ class Memory(MemoryBase):
|
||||
filters["run_id"] = run_id
|
||||
|
||||
capture_event("mem0.get_all", self, {"filters": len(filters), "limit": limit})
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
|
||||
future_graph_entities = executor.submit(self.graph.get_all, filters) if self.version == "v1.1" and self.enable_graph else None
|
||||
|
||||
all_memories = future_memories.result()
|
||||
graph_entities = future_graph_entities.result() if future_graph_entities else None
|
||||
|
||||
if self.version == "v1.1":
|
||||
if self.enable_graph:
|
||||
return {"memories": all_memories, "entities": graph_entities}
|
||||
else:
|
||||
return {"memories": all_memories}
|
||||
else:
|
||||
warnings.warn(
|
||||
"The current get_all API output format is deprecated. "
|
||||
"To use the latest format, set `api_version='v1.1'`. "
|
||||
"The current format will be removed in mem0ai 1.1.0 and later versions.",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
return all_memories
|
||||
|
||||
def _get_all_from_vector_store(self, filters, limit):
|
||||
memories = self.vector_store.list(filters=filters, limit=limit)
|
||||
|
||||
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
|
||||
@@ -259,22 +296,7 @@ class Memory(MemoryBase):
|
||||
}
|
||||
for mem in memories[0]
|
||||
]
|
||||
|
||||
if self.version == "v1.1":
|
||||
if self.enable_graph:
|
||||
graph_entities = self.graph.get_all(filters)
|
||||
return {"memories": all_memories, "entities": graph_entities}
|
||||
else:
|
||||
return {"memories" : all_memories}
|
||||
else:
|
||||
warnings.warn(
|
||||
"The current get_all API output format is deprecated. "
|
||||
"To use the latest format, set `api_version='v1.1'`. "
|
||||
"The current format will be removed in mem0ai 1.1.0 and later versions.",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
return all_memories
|
||||
return all_memories
|
||||
|
||||
def search(
|
||||
self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None
|
||||
@@ -307,6 +329,30 @@ class Memory(MemoryBase):
|
||||
)
|
||||
|
||||
capture_event("mem0.search", self, {"filters": len(filters), "limit": limit, "version": self.version})
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future_memories = executor.submit(self._search_vector_store, query, filters, limit)
|
||||
future_graph_entities = executor.submit(self.graph.search, query, filters) if self.version == "v1.1" and self.enable_graph else None
|
||||
|
||||
original_memories = future_memories.result()
|
||||
graph_entities = future_graph_entities.result() if future_graph_entities else None
|
||||
|
||||
if self.version == "v1.1":
|
||||
if self.enable_graph:
|
||||
return {"memories": original_memories, "entities": graph_entities}
|
||||
else:
|
||||
return {"memories" : original_memories}
|
||||
else:
|
||||
warnings.warn(
|
||||
"The current get_all API output format is deprecated. "
|
||||
"To use the latest format, set `api_version='v1.1'`. "
|
||||
"The current format will be removed in mem0ai 1.1.0 and later versions.",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
return original_memories
|
||||
|
||||
def _search_vector_store(self, query, filters, limit):
|
||||
embeddings = self.embedding_model.embed(query)
|
||||
memories = self.vector_store.search(
|
||||
query=embeddings, limit=limit, filters=filters
|
||||
@@ -352,21 +398,7 @@ class Memory(MemoryBase):
|
||||
for mem in memories
|
||||
]
|
||||
|
||||
if self.version == "v1.1":
|
||||
if self.enable_graph:
|
||||
graph_entities = self.graph.search(query, filters)
|
||||
return {"memories": original_memories, "entities": graph_entities}
|
||||
else:
|
||||
return {"memories" : original_memories}
|
||||
else:
|
||||
warnings.warn(
|
||||
"The current get_all API output format is deprecated. "
|
||||
"To use the latest format, set `api_version='v1.1'`. "
|
||||
"The current format will be removed in mem0ai 1.1.0 and later versions.",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
return original_memories
|
||||
return original_memories
|
||||
|
||||
def update(self, memory_id, data):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user