Added parallelization to memory method calls to reduce latency (#1803)
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -184,3 +184,4 @@ notebooks/*.yaml
|
|||||||
eval/
|
eval/
|
||||||
qdrant_storage/
|
qdrant_storage/
|
||||||
.crossnote
|
.crossnote
|
||||||
|
testing.ipynb
|
||||||
|
|||||||
@@ -15,6 +15,8 @@ from mem0.memory.utils import get_fact_retrieval_messages, parse_messages
|
|||||||
from mem0.configs.prompts import get_update_memory_messages
|
from mem0.configs.prompts import get_update_memory_messages
|
||||||
from mem0.utils.factory import LlmFactory, EmbedderFactory, VectorStoreFactory
|
from mem0.utils.factory import LlmFactory, EmbedderFactory, VectorStoreFactory
|
||||||
from mem0.configs.base import MemoryItem, MemoryConfig
|
from mem0.configs.base import MemoryItem, MemoryConfig
|
||||||
|
import threading
|
||||||
|
import concurrent
|
||||||
|
|
||||||
# Setup user config
|
# Setup user config
|
||||||
setup_config()
|
setup_config()
|
||||||
@@ -96,6 +98,18 @@ class Memory(MemoryBase):
|
|||||||
if isinstance(messages, str):
|
if isinstance(messages, str):
|
||||||
messages = [{"role": "user", "content": messages}]
|
messages = [{"role": "user", "content": messages}]
|
||||||
|
|
||||||
|
thread1 = threading.Thread(target=self._add_to_vector_store, args=(messages, metadata, filters))
|
||||||
|
thread2 = threading.Thread(target=self._add_to_graph, args=(messages, filters))
|
||||||
|
|
||||||
|
thread1.start()
|
||||||
|
thread2.start()
|
||||||
|
|
||||||
|
thread1.join()
|
||||||
|
thread2.join()
|
||||||
|
|
||||||
|
return {"message": "ok"}
|
||||||
|
|
||||||
|
def _add_to_vector_store(self, messages, metadata, filters):
|
||||||
parsed_messages = parse_messages(messages)
|
parsed_messages = parse_messages(messages)
|
||||||
|
|
||||||
system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages)
|
system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages)
|
||||||
@@ -152,16 +166,15 @@ class Memory(MemoryBase):
|
|||||||
|
|
||||||
capture_event("mem0.add", self)
|
capture_event("mem0.add", self)
|
||||||
|
|
||||||
|
def _add_to_graph(self, messages, filters):
|
||||||
if self.version == "v1.1" and self.enable_graph:
|
if self.version == "v1.1" and self.enable_graph:
|
||||||
if user_id:
|
if filters["user_id"]:
|
||||||
self.graph.user_id = user_id
|
self.graph.user_id = filters["user_id"]
|
||||||
else:
|
else:
|
||||||
self.graph.user_id = "USER"
|
self.graph.user_id = "USER"
|
||||||
data = "\n".join([msg["content"] for msg in messages if "content" in msg])
|
data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"])
|
||||||
added_entities = self.graph.add(data, filters)
|
added_entities = self.graph.add(data, filters)
|
||||||
|
|
||||||
return {"message": "ok"}
|
|
||||||
|
|
||||||
def get(self, memory_id):
|
def get(self, memory_id):
|
||||||
"""
|
"""
|
||||||
Retrieve a memory by ID.
|
Retrieve a memory by ID.
|
||||||
@@ -228,6 +241,30 @@ class Memory(MemoryBase):
|
|||||||
filters["run_id"] = run_id
|
filters["run_id"] = run_id
|
||||||
|
|
||||||
capture_event("mem0.get_all", self, {"filters": len(filters), "limit": limit})
|
capture_event("mem0.get_all", self, {"filters": len(filters), "limit": limit})
|
||||||
|
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
|
future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
|
||||||
|
future_graph_entities = executor.submit(self.graph.get_all, filters) if self.version == "v1.1" and self.enable_graph else None
|
||||||
|
|
||||||
|
all_memories = future_memories.result()
|
||||||
|
graph_entities = future_graph_entities.result() if future_graph_entities else None
|
||||||
|
|
||||||
|
if self.version == "v1.1":
|
||||||
|
if self.enable_graph:
|
||||||
|
return {"memories": all_memories, "entities": graph_entities}
|
||||||
|
else:
|
||||||
|
return {"memories": all_memories}
|
||||||
|
else:
|
||||||
|
warnings.warn(
|
||||||
|
"The current get_all API output format is deprecated. "
|
||||||
|
"To use the latest format, set `api_version='v1.1'`. "
|
||||||
|
"The current format will be removed in mem0ai 1.1.0 and later versions.",
|
||||||
|
category=DeprecationWarning,
|
||||||
|
stacklevel=2
|
||||||
|
)
|
||||||
|
return all_memories
|
||||||
|
|
||||||
|
def _get_all_from_vector_store(self, filters, limit):
|
||||||
memories = self.vector_store.list(filters=filters, limit=limit)
|
memories = self.vector_store.list(filters=filters, limit=limit)
|
||||||
|
|
||||||
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
|
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
|
||||||
@@ -259,21 +296,6 @@ class Memory(MemoryBase):
|
|||||||
}
|
}
|
||||||
for mem in memories[0]
|
for mem in memories[0]
|
||||||
]
|
]
|
||||||
|
|
||||||
if self.version == "v1.1":
|
|
||||||
if self.enable_graph:
|
|
||||||
graph_entities = self.graph.get_all(filters)
|
|
||||||
return {"memories": all_memories, "entities": graph_entities}
|
|
||||||
else:
|
|
||||||
return {"memories" : all_memories}
|
|
||||||
else:
|
|
||||||
warnings.warn(
|
|
||||||
"The current get_all API output format is deprecated. "
|
|
||||||
"To use the latest format, set `api_version='v1.1'`. "
|
|
||||||
"The current format will be removed in mem0ai 1.1.0 and later versions.",
|
|
||||||
category=DeprecationWarning,
|
|
||||||
stacklevel=2
|
|
||||||
)
|
|
||||||
return all_memories
|
return all_memories
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
@@ -307,6 +329,30 @@ class Memory(MemoryBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
capture_event("mem0.search", self, {"filters": len(filters), "limit": limit, "version": self.version})
|
capture_event("mem0.search", self, {"filters": len(filters), "limit": limit, "version": self.version})
|
||||||
|
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
|
future_memories = executor.submit(self._search_vector_store, query, filters, limit)
|
||||||
|
future_graph_entities = executor.submit(self.graph.search, query, filters) if self.version == "v1.1" and self.enable_graph else None
|
||||||
|
|
||||||
|
original_memories = future_memories.result()
|
||||||
|
graph_entities = future_graph_entities.result() if future_graph_entities else None
|
||||||
|
|
||||||
|
if self.version == "v1.1":
|
||||||
|
if self.enable_graph:
|
||||||
|
return {"memories": original_memories, "entities": graph_entities}
|
||||||
|
else:
|
||||||
|
return {"memories" : original_memories}
|
||||||
|
else:
|
||||||
|
warnings.warn(
|
||||||
|
"The current get_all API output format is deprecated. "
|
||||||
|
"To use the latest format, set `api_version='v1.1'`. "
|
||||||
|
"The current format will be removed in mem0ai 1.1.0 and later versions.",
|
||||||
|
category=DeprecationWarning,
|
||||||
|
stacklevel=2
|
||||||
|
)
|
||||||
|
return original_memories
|
||||||
|
|
||||||
|
def _search_vector_store(self, query, filters, limit):
|
||||||
embeddings = self.embedding_model.embed(query)
|
embeddings = self.embedding_model.embed(query)
|
||||||
memories = self.vector_store.search(
|
memories = self.vector_store.search(
|
||||||
query=embeddings, limit=limit, filters=filters
|
query=embeddings, limit=limit, filters=filters
|
||||||
@@ -352,20 +398,6 @@ class Memory(MemoryBase):
|
|||||||
for mem in memories
|
for mem in memories
|
||||||
]
|
]
|
||||||
|
|
||||||
if self.version == "v1.1":
|
|
||||||
if self.enable_graph:
|
|
||||||
graph_entities = self.graph.search(query, filters)
|
|
||||||
return {"memories": original_memories, "entities": graph_entities}
|
|
||||||
else:
|
|
||||||
return {"memories" : original_memories}
|
|
||||||
else:
|
|
||||||
warnings.warn(
|
|
||||||
"The current get_all API output format is deprecated. "
|
|
||||||
"To use the latest format, set `api_version='v1.1'`. "
|
|
||||||
"The current format will be removed in mem0ai 1.1.0 and later versions.",
|
|
||||||
category=DeprecationWarning,
|
|
||||||
stacklevel=2
|
|
||||||
)
|
|
||||||
return original_memories
|
return original_memories
|
||||||
|
|
||||||
def update(self, memory_id, data):
|
def update(self, memory_id, data):
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "mem0ai"
|
name = "mem0ai"
|
||||||
version = "0.1.8"
|
version = "0.1.9"
|
||||||
description = "Long-term memory for AI Agents"
|
description = "Long-term memory for AI Agents"
|
||||||
authors = ["Mem0 <founders@mem0.ai>"]
|
authors = ["Mem0 <founders@mem0.ai>"]
|
||||||
exclude = [
|
exclude = [
|
||||||
|
|||||||
Reference in New Issue
Block a user