Add async support (#2492)
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
@@ -1,3 +1,4 @@
+import asyncio
 import concurrent
 import hashlib
 import json
@@ -196,8 +197,8 @@ class Memory(MemoryBase):
 
         parsed_messages = parse_messages(messages)
 
-        if self.custom_fact_extraction_prompt:
-            system_prompt = self.custom_fact_extraction_prompt
+        if self.config.custom_fact_extraction_prompt:
+            system_prompt = self.config.custom_fact_extraction_prompt
             user_prompt = f"Input:\n{parsed_messages}"
         else:
             system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages)
@@ -243,7 +244,7 @@ class Memory(MemoryBase):
             retrieved_old_memory[idx]["id"] = str(idx)
 
         function_calling_prompt = get_update_memory_messages(
-            retrieved_old_memory, new_retrieved_facts, self.custom_update_memory_prompt
+            retrieved_old_memory, new_retrieved_facts, self.config.custom_update_memory_prompt
         )
 
         try:
@@ -755,3 +756,773 @@ class Memory(MemoryBase):

    def chat(self, query):
        raise NotImplementedError("Chat function not implemented yet.")


class AsyncMemory(MemoryBase):
    def __init__(self, config: MemoryConfig = MemoryConfig()):
        self.config = config

        self.embedding_model = EmbedderFactory.create(
            self.config.embedder.provider,
            self.config.embedder.config,
            self.config.vector_store.config,
        )
        self.vector_store = VectorStoreFactory.create(
            self.config.vector_store.provider, self.config.vector_store.config
        )
        self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config)
        self.db = SQLiteManager(self.config.history_db_path)
        self.collection_name = self.config.vector_store.config.collection_name
        self.api_version = self.config.version

        self.enable_graph = False

        if self.config.graph_store.config:
            from mem0.memory.graph_memory import MemoryGraph

            self.graph = MemoryGraph(self.config)
            self.enable_graph = True

        capture_event("mem0.init", self)

    @classmethod
    async def from_config(cls, config_dict: Dict[str, Any]):
        try:
            config_dict = cls._process_config(config_dict)
            config = MemoryConfig(**config_dict)
        except ValidationError as e:
            logger.error(f"Configuration validation error: {e}")
            raise
        return cls(config)

    @staticmethod
    def _process_config(config_dict: Dict[str, Any]) -> Dict[str, Any]:
        if "graph_store" in config_dict:
            if "vector_store" not in config_dict and "embedder" in config_dict:
                config_dict["vector_store"] = {}
                config_dict["vector_store"]["config"] = {}
                config_dict["vector_store"]["config"]["embedding_model_dims"] = config_dict["embedder"]["config"][
                    "embedding_dims"
                ]
        return config_dict
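
Since `from_config` validates the dict before instantiating the class, the usual entry point looks like the sketch below (the provider name and model are illustrative placeholders, and it assumes AsyncMemory is exported from the package root):

    import asyncio
    from mem0 import AsyncMemory

    async def main():
        # Hypothetical config; any provider supported by the factories works here.
        config = {"llm": {"provider": "openai", "config": {"model": "gpt-4o-mini"}}}
        memory = await AsyncMemory.from_config(config)

    asyncio.run(main())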

    async def add(
        self,
        messages,
        user_id=None,
        agent_id=None,
        run_id=None,
        metadata=None,
        filters=None,
        infer=True,
        memory_type=None,
        prompt=None,
        llm=None,
    ):
        """
        Create a new memory asynchronously.

        Args:
            messages (str or List[Dict[str, str]]): Messages to store in the memory.
            user_id (str, optional): ID of the user creating the memory. Defaults to None.
            agent_id (str, optional): ID of the agent creating the memory. Defaults to None.
            run_id (str, optional): ID of the run creating the memory. Defaults to None.
            metadata (dict, optional): Metadata to store with the memory. Defaults to None.
            filters (dict, optional): Filters to apply to the search. Defaults to None.
            infer (bool, optional): Whether to infer the memories. Defaults to True.
            memory_type (str, optional): Type of memory to create. Defaults to None. By default, short-term and long-term (semantic and episodic) memories are created. Pass "procedural_memory" to create procedural memories.
            prompt (str, optional): Prompt to use for the memory creation. Defaults to None.
            llm (BaseChatModel, optional): LLM to use for generating procedural memories. Defaults to None. Useful when the user is using a LangChain ChatModel.

        Returns:
            dict: A dictionary containing the result of the memory addition operation.
            result: dict of affected events, with the following keys:
                'memories': affected memories
                'graph': affected graph memories

                'memories' and 'graph' are each dicts with the following subkeys:
                'add': added memory
                'update': updated memory
                'delete': deleted memory
        """
        if metadata is None:
            metadata = {}

        filters = filters or {}
        if user_id:
            filters["user_id"] = metadata["user_id"] = user_id
        if agent_id:
            filters["agent_id"] = metadata["agent_id"] = agent_id
        if run_id:
            filters["run_id"] = metadata["run_id"] = run_id

        if not any(key in filters for key in ("user_id", "agent_id", "run_id")):
            raise ValueError("One of the filters: user_id, agent_id or run_id is required!")

        if memory_type is not None and memory_type != MemoryType.PROCEDURAL.value:
            raise ValueError(
                f"Invalid 'memory_type'. Please pass {MemoryType.PROCEDURAL.value} to create procedural memories."
            )

        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]

        if agent_id is not None and memory_type == MemoryType.PROCEDURAL.value:
            results = await self._create_procedural_memory(messages, metadata=metadata, llm=llm, prompt=prompt)
            return results

        if self.config.llm.config.get("enable_vision"):
            messages = parse_vision_messages(messages, self.llm, self.config.llm.config.get("vision_details"))
        else:
            messages = parse_vision_messages(messages)

        # Run vector store and graph operations concurrently
        vector_store_task = asyncio.create_task(self._add_to_vector_store(messages, metadata, filters, infer))
        graph_task = asyncio.create_task(self._add_to_graph(messages, filters))

        vector_store_result, graph_result = await asyncio.gather(vector_store_task, graph_task)

        if self.api_version == "v1.0":
            warnings.warn(
                "The current add API output format is deprecated. "
                "To use the latest format, set `api_version='v1.1'`. "
                "The current format will be removed in mem0ai 1.1.0 and later versions.",
                category=DeprecationWarning,
                stacklevel=2,
            )
            return vector_store_result

        if self.enable_graph:
            return {
                "results": vector_store_result,
                "relations": graph_result,
            }

        return {"results": vector_store_result}

    async def _add_to_vector_store(self, messages, metadata, filters, infer):
        if not infer:
            returned_memories = []
            for message in messages:
                if message["role"] != "system":
                    message_embeddings = await asyncio.to_thread(self.embedding_model.embed, message["content"], "add")
                    memory_id = await self._create_memory(message["content"], message_embeddings, metadata)
                    returned_memories.append({"id": memory_id, "memory": message["content"], "event": "ADD"})
            return returned_memories

        parsed_messages = parse_messages(messages)

        if self.config.custom_fact_extraction_prompt:
            system_prompt = self.config.custom_fact_extraction_prompt
            user_prompt = f"Input:\n{parsed_messages}"
        else:
            system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages)

        response = await asyncio.to_thread(
            self.llm.generate_response,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            response_format={"type": "json_object"},
        )

        try:
            response = remove_code_blocks(response)
            new_retrieved_facts = json.loads(response)["facts"]
        except Exception as e:
            logger.error(f"Error in new_retrieved_facts: {e}")
            new_retrieved_facts = []

        retrieved_old_memory = []
        new_message_embeddings = {}

        # Embed each fact and search for related existing memories concurrently
        async def process_fact(new_mem):
            messages_embeddings = await asyncio.to_thread(self.embedding_model.embed, new_mem, "add")
            new_message_embeddings[new_mem] = messages_embeddings
            existing_memories = await asyncio.to_thread(
                self.vector_store.search,
                query=new_mem,
                vectors=messages_embeddings,
                limit=5,
                filters=filters,
            )
            return [(mem.id, mem.payload["data"]) for mem in existing_memories]

        fact_tasks = [process_fact(fact) for fact in new_retrieved_facts]
        fact_results = await asyncio.gather(*fact_tasks)

        # Flatten results and build retrieved_old_memory
        for result in fact_results:
            for mem_id, mem_data in result:
                retrieved_old_memory.append({"id": mem_id, "text": mem_data})

        unique_data = {}
        for item in retrieved_old_memory:
            unique_data[item["id"]] = item
        retrieved_old_memory = list(unique_data.values())
        logger.info(f"Total existing memories: {len(retrieved_old_memory)}")

        # Map UUIDs to temporary integer ids to guard against UUID hallucinations by the LLM
        temp_uuid_mapping = {}
        for idx, item in enumerate(retrieved_old_memory):
            temp_uuid_mapping[str(idx)] = item["id"]
            retrieved_old_memory[idx]["id"] = str(idx)

        function_calling_prompt = get_update_memory_messages(
            retrieved_old_memory, new_retrieved_facts, self.config.custom_update_memory_prompt
        )

        try:
            new_memories_with_actions = await asyncio.to_thread(
                self.llm.generate_response,
                messages=[{"role": "user", "content": function_calling_prompt}],
                response_format={"type": "json_object"},
            )
        except Exception as e:
            logger.error(f"Error in new_memories_with_actions: {e}")
            new_memories_with_actions = "{}"

        try:
            new_memories_with_actions = remove_code_blocks(new_memories_with_actions)
            new_memories_with_actions = json.loads(new_memories_with_actions)
        except Exception as e:
            logger.error(f"Invalid JSON response: {e}")
            new_memories_with_actions = {}

        returned_memories = []
        try:
            memory_tasks = []
            for resp in new_memories_with_actions.get("memory", []):
                logger.info(resp)
                try:
                    if not resp.get("text"):
                        logger.info("Skipping memory entry because of empty `text` field.")
                        continue
                    elif resp.get("event") == "ADD":
                        task = asyncio.create_task(
                            self._create_memory(
                                data=resp.get("text"),
                                existing_embeddings=new_message_embeddings,
                                metadata=metadata,
                            )
                        )
                        memory_tasks.append((task, resp, "ADD", None))
                    elif resp.get("event") == "UPDATE":
                        task = asyncio.create_task(
                            self._update_memory(
                                memory_id=temp_uuid_mapping[resp["id"]],
                                data=resp.get("text"),
                                existing_embeddings=new_message_embeddings,
                                metadata=metadata,
                            )
                        )
                        memory_tasks.append((task, resp, "UPDATE", temp_uuid_mapping[resp["id"]]))
                    elif resp.get("event") == "DELETE":
                        task = asyncio.create_task(self._delete_memory(memory_id=temp_uuid_mapping[resp["id"]]))
                        memory_tasks.append((task, resp, "DELETE", temp_uuid_mapping[resp["id"]]))
                    elif resp.get("event") == "NONE":
                        logger.info("NOOP for Memory.")
                except Exception as e:
                    logger.error(f"Error in new_memories_with_actions: {e}")

            # Wait for all memory operations to complete
            for task, resp, event_type, mem_id in memory_tasks:
                try:
                    result_id = await task
                    if event_type == "ADD":
                        returned_memories.append(
                            {
                                "id": result_id,
                                "memory": resp.get("text"),
                                "event": resp.get("event"),
                            }
                        )
                    elif event_type == "UPDATE":
                        returned_memories.append(
                            {
                                "id": mem_id,
                                "memory": resp.get("text"),
                                "event": resp.get("event"),
                                "previous_memory": resp.get("old_memory"),
                            }
                        )
                    elif event_type == "DELETE":
                        returned_memories.append(
                            {
                                "id": mem_id,
                                "memory": resp.get("text"),
                                "event": resp.get("event"),
                            }
                        )
                except Exception as e:
                    logger.error(f"Error processing memory task: {e}")

        except Exception as e:
            logger.error(f"Error in new_memories_with_actions: {e}")

        capture_event("mem0.add", self, {"version": self.api_version, "keys": list(filters.keys())})

        return returned_memories
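
The shape the method expects back from the second LLM call can be read off the parsing code above; an illustrative (not verbatim) response looks like this:

    {
        "memory": [
            {"id": "0", "text": "Is vegetarian", "event": "NONE"},
            {"id": "1", "text": "Allergic to nuts", "event": "UPDATE", "old_memory": "Allergic to peanuts"},
            {"text": "Lives in Berlin", "event": "ADD"}
        ]
    }

The integer-string ids are the temporary indices from `temp_uuid_mapping`, which keeps the real memory UUIDs out of the prompt; they are translated back before the actual vector-store calls.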

    async def _add_to_graph(self, messages, filters):
        added_entities = []
        if self.enable_graph:
            if filters.get("user_id") is None:
                filters["user_id"] = "user"

            data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"])
            added_entities = await asyncio.to_thread(self.graph.add, data, filters)

        return added_entities

    async def get(self, memory_id):
        """
        Retrieve a memory by ID asynchronously.

        Args:
            memory_id (str): ID of the memory to retrieve.

        Returns:
            dict: Retrieved memory.
        """
        capture_event("mem0.get", self, {"memory_id": memory_id})
        memory = await asyncio.to_thread(self.vector_store.get, vector_id=memory_id)
        if not memory:
            return None

        filters = {key: memory.payload[key] for key in ["user_id", "agent_id", "run_id"] if memory.payload.get(key)}

        # Prepare base memory item
        memory_item = MemoryItem(
            id=memory.id,
            memory=memory.payload["data"],
            hash=memory.payload.get("hash"),
            created_at=memory.payload.get("created_at"),
            updated_at=memory.payload.get("updated_at"),
        ).model_dump(exclude={"score"})

        # Add metadata if there are additional keys
        excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at", "id"}
        additional_metadata = {k: v for k, v in memory.payload.items() if k not in excluded_keys}
        if additional_metadata:
            memory_item["metadata"] = additional_metadata

        result = {**memory_item, **filters}

        return result

    async def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
        """
        List all memories asynchronously.

        Returns:
            list: List of all memories.
        """
        filters = {}
        if user_id:
            filters["user_id"] = user_id
        if agent_id:
            filters["agent_id"] = agent_id
        if run_id:
            filters["run_id"] = run_id

        capture_event("mem0.get_all", self, {"limit": limit, "keys": list(filters.keys())})

        # Run vector store and graph operations concurrently
        vector_store_task = asyncio.create_task(self._get_all_from_vector_store(filters, limit))

        if self.enable_graph:
            graph_task = asyncio.create_task(asyncio.to_thread(self.graph.get_all, filters, limit))
            all_memories, graph_entities = await asyncio.gather(vector_store_task, graph_task)
        else:
            all_memories = await vector_store_task
            graph_entities = None

        if self.enable_graph:
            return {"results": all_memories, "relations": graph_entities}

        if self.api_version == "v1.0":
            warnings.warn(
                "The current get_all API output format is deprecated. "
                "To use the latest format, set `api_version='v1.1'`. "
                "The current format will be removed in mem0ai 1.1.0 and later versions.",
                category=DeprecationWarning,
                stacklevel=2,
            )
            return all_memories
        else:
            return {"results": all_memories}

    async def _get_all_from_vector_store(self, filters, limit):
        memories = await asyncio.to_thread(self.vector_store.list, filters=filters, limit=limit)

        excluded_keys = {
            "user_id",
            "agent_id",
            "run_id",
            "hash",
            "data",
            "created_at",
            "updated_at",
            "id",
        }
        all_memories = [
            {
                **MemoryItem(
                    id=mem.id,
                    memory=mem.payload["data"],
                    hash=mem.payload.get("hash"),
                    created_at=mem.payload.get("created_at"),
                    updated_at=mem.payload.get("updated_at"),
                ).model_dump(exclude={"score"}),
                **{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
                **(
                    {"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
                    if any(k for k in mem.payload if k not in excluded_keys)
                    else {}
                ),
            }
            for mem in memories[0]
        ]
        return all_memories

    async def search(self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None):
        """
        Search for memories asynchronously.

        Args:
            query (str): Query to search for.
            user_id (str, optional): ID of the user to search for. Defaults to None.
            agent_id (str, optional): ID of the agent to search for. Defaults to None.
            run_id (str, optional): ID of the run to search for. Defaults to None.
            limit (int, optional): Limit the number of results. Defaults to 100.
            filters (dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            list: List of search results.
        """
        filters = filters or {}
        if user_id:
            filters["user_id"] = user_id
        if agent_id:
            filters["agent_id"] = agent_id
        if run_id:
            filters["run_id"] = run_id

        if not any(key in filters for key in ("user_id", "agent_id", "run_id")):
            raise ValueError("One of the filters: user_id, agent_id or run_id is required!")

        capture_event(
            "mem0.search",
            self,
            {"limit": limit, "version": self.api_version, "keys": list(filters.keys())},
        )

        # Run vector store and graph operations concurrently
        vector_store_task = asyncio.create_task(self._search_vector_store(query, filters, limit))

        if self.enable_graph:
            graph_task = asyncio.create_task(asyncio.to_thread(self.graph.search, query, filters, limit))
            original_memories, graph_entities = await asyncio.gather(vector_store_task, graph_task)
        else:
            original_memories = await vector_store_task
            graph_entities = None

        if self.enable_graph:
            return {"results": original_memories, "relations": graph_entities}

        if self.api_version == "v1.0":
            warnings.warn(
                "The current search API output format is deprecated. "
                "To use the latest format, set `api_version='v1.1'`. "
                "The current format will be removed in mem0ai 1.1.0 and later versions.",
                category=DeprecationWarning,
                stacklevel=2,
            )
            return original_memories
        else:
            return {"results": original_memories}

    async def _search_vector_store(self, query, filters, limit):
        embeddings = await asyncio.to_thread(self.embedding_model.embed, query, "search")
        memories = await asyncio.to_thread(
            self.vector_store.search,
            query=query,
            vectors=embeddings,
            limit=limit,
            filters=filters,
        )

        excluded_keys = {
            "user_id",
            "agent_id",
            "run_id",
            "hash",
            "data",
            "created_at",
            "updated_at",
            "id",
        }

        original_memories = [
            {
                **MemoryItem(
                    id=mem.id,
                    memory=mem.payload["data"],
                    hash=mem.payload.get("hash"),
                    created_at=mem.payload.get("created_at"),
                    updated_at=mem.payload.get("updated_at"),
                    score=mem.score,
                ).model_dump(),
                **{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
                **(
                    {"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
                    if any(k for k in mem.payload if k not in excluded_keys)
                    else {}
                ),
            }
            for mem in memories
        ]

        return original_memories

    async def update(self, memory_id, data):
        """
        Update a memory by ID asynchronously.

        Args:
            memory_id (str): ID of the memory to update.
            data (dict): Data to update the memory with.

        Returns:
            dict: Updated memory.
        """
        capture_event("mem0.update", self, {"memory_id": memory_id})

        embeddings = await asyncio.to_thread(self.embedding_model.embed, data, "update")
        existing_embeddings = {data: embeddings}

        await self._update_memory(memory_id, data, existing_embeddings)
        return {"message": "Memory updated successfully!"}

    async def delete(self, memory_id):
        """
        Delete a memory by ID asynchronously.

        Args:
            memory_id (str): ID of the memory to delete.
        """
        capture_event("mem0.delete", self, {"memory_id": memory_id})
        await self._delete_memory(memory_id)
        return {"message": "Memory deleted successfully!"}

    async def delete_all(self, user_id=None, agent_id=None, run_id=None):
        """
        Delete all memories asynchronously.

        Args:
            user_id (str, optional): ID of the user to delete memories for. Defaults to None.
            agent_id (str, optional): ID of the agent to delete memories for. Defaults to None.
            run_id (str, optional): ID of the run to delete memories for. Defaults to None.
        """
        filters = {}
        if user_id:
            filters["user_id"] = user_id
        if agent_id:
            filters["agent_id"] = agent_id
        if run_id:
            filters["run_id"] = run_id

        if not filters:
            raise ValueError(
                "At least one filter is required to delete all memories. If you want to delete all memories, use the `reset()` method."
            )

        capture_event("mem0.delete_all", self, {"keys": list(filters.keys())})
        memories = await asyncio.to_thread(self.vector_store.list, filters=filters)

        delete_tasks = []
        for memory in memories[0]:
            delete_tasks.append(self._delete_memory(memory.id))

        await asyncio.gather(*delete_tasks)

        logger.info(f"Deleted {len(memories[0])} memories")

        if self.enable_graph:
            await asyncio.to_thread(self.graph.delete_all, filters)

        return {"message": "Memories deleted successfully!"}

    async def history(self, memory_id):
        """
        Get the history of changes for a memory by ID asynchronously.

        Args:
            memory_id (str): ID of the memory to get history for.

        Returns:
            list: List of changes for the memory.
        """
        capture_event("mem0.history", self, {"memory_id": memory_id})
        return await asyncio.to_thread(self.db.get_history, memory_id)

    async def _create_memory(self, data, existing_embeddings, metadata=None):
        logger.debug(f"Creating memory with {data=}")
        if data in existing_embeddings:
            embeddings = existing_embeddings[data]
        else:
            embeddings = await asyncio.to_thread(self.embedding_model.embed, data, memory_action="add")

        memory_id = str(uuid.uuid4())
        metadata = metadata or {}
        metadata["data"] = data
        metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
        metadata["created_at"] = datetime.now(pytz.timezone("US/Pacific")).isoformat()

        await asyncio.to_thread(
            self.vector_store.insert,
            vectors=[embeddings],
            ids=[memory_id],
            payloads=[metadata],
        )

        await asyncio.to_thread(
            self.db.add_history,
            memory_id,
            None,
            data,
            "ADD",
            created_at=metadata["created_at"],
        )

        capture_event("mem0._create_memory", self, {"memory_id": memory_id})
        return memory_id

    async def _create_procedural_memory(self, messages, metadata=None, llm=None, prompt=None):
        """
        Create a procedural memory asynchronously.

        Args:
            messages (list): List of messages to create a procedural memory from.
            metadata (dict): Metadata to attach to the procedural memory.
            llm (BaseChatModel, optional): LLM to use for generating procedural memories. Defaults to None. Useful when the user is using a LangChain ChatModel.
            prompt (str, optional): Prompt to use for the procedural memory creation. Defaults to None.
        """
        try:
            from langchain_core.messages.utils import (
                convert_to_messages,  # type: ignore
            )
        except Exception:
            logger.error("Import error while loading langchain-core. Please install 'langchain-core' to use procedural memory.")
            raise

        logger.info("Creating procedural memory")

        parsed_messages = [
            {"role": "system", "content": prompt or PROCEDURAL_MEMORY_SYSTEM_PROMPT},
            *messages,
            {"role": "user", "content": "Create procedural memory of the above conversation."},
        ]

        try:
            if llm is not None:
                parsed_messages = convert_to_messages(parsed_messages)
                response = await asyncio.to_thread(llm.invoke, input=parsed_messages)
                procedural_memory = response.content
            else:
                procedural_memory = await asyncio.to_thread(
                    self.llm.generate_response,
                    messages=parsed_messages,
                )
        except Exception as e:
            logger.error(f"Error generating procedural memory summary: {e}")
            raise

        if metadata is None:
            raise ValueError("Metadata is required when creating procedural memory.")

        metadata["memory_type"] = MemoryType.PROCEDURAL.value
        # Generate embeddings for the summary
        embeddings = await asyncio.to_thread(self.embedding_model.embed, procedural_memory, memory_action="add")
        # Create the memory
        memory_id = await self._create_memory(procedural_memory, {procedural_memory: embeddings}, metadata=metadata)
        capture_event("mem0._create_procedural_memory", self, {"memory_id": memory_id})

        # Return results in the same format as add()
        result = {"results": [{"id": memory_id, "memory": procedural_memory, "event": "ADD"}]}

        return result
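
Procedural memories are reached through the public add() path, which requires an agent_id and the procedural memory type; a sketch with illustrative values:

    summary = await memory.add(
        messages=[{"role": "assistant", "content": "Opened the site, logged in, exported the report."}],
        agent_id="browser-agent",
        memory_type="procedural_memory",
    )
    # Returns {"results": [{"id": ..., "memory": <LLM-generated summary>, "event": "ADD"}]}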

    async def _update_memory(self, memory_id, data, existing_embeddings, metadata=None):
        logger.info(f"Updating memory with {data=}")

        try:
            existing_memory = await asyncio.to_thread(self.vector_store.get, vector_id=memory_id)
        except Exception:
            raise ValueError(f"Error getting memory with ID {memory_id}. Please provide a valid 'memory_id'")

        prev_value = existing_memory.payload.get("data")

        new_metadata = metadata or {}
        new_metadata["data"] = data
        new_metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
        new_metadata["created_at"] = existing_memory.payload.get("created_at")
        new_metadata["updated_at"] = datetime.now(pytz.timezone("US/Pacific")).isoformat()

        if "user_id" in existing_memory.payload:
            new_metadata["user_id"] = existing_memory.payload["user_id"]
        if "agent_id" in existing_memory.payload:
            new_metadata["agent_id"] = existing_memory.payload["agent_id"]
        if "run_id" in existing_memory.payload:
            new_metadata["run_id"] = existing_memory.payload["run_id"]

        if data in existing_embeddings:
            embeddings = existing_embeddings[data]
        else:
            embeddings = await asyncio.to_thread(self.embedding_model.embed, data, "update")

        await asyncio.to_thread(
            self.vector_store.update,
            vector_id=memory_id,
            vector=embeddings,
            payload=new_metadata,
        )

        logger.info(f"Updating memory with ID {memory_id=} with {data=}")

        await asyncio.to_thread(
            self.db.add_history,
            memory_id,
            prev_value,
            data,
            "UPDATE",
            created_at=new_metadata["created_at"],
            updated_at=new_metadata["updated_at"],
        )

        capture_event("mem0._update_memory", self, {"memory_id": memory_id})
        return memory_id

    async def _delete_memory(self, memory_id):
        logger.info(f"Deleting memory with {memory_id=}")
        existing_memory = await asyncio.to_thread(self.vector_store.get, vector_id=memory_id)
        prev_value = existing_memory.payload["data"]

        await asyncio.to_thread(self.vector_store.delete, vector_id=memory_id)
        await asyncio.to_thread(self.db.add_history, memory_id, prev_value, None, "DELETE", is_deleted=1)

        capture_event("mem0._delete_memory", self, {"memory_id": memory_id})
        return memory_id

    async def reset(self):
        """
        Reset the memory store asynchronously.
        """
        logger.warning("Resetting all memories")
        await asyncio.to_thread(self.vector_store.delete_col)
        self.vector_store = VectorStoreFactory.create(
            self.config.vector_store.provider, self.config.vector_store.config
        )
        await asyncio.to_thread(self.db.reset)
        capture_event("mem0.reset", self)

    async def chat(self, query):
        raise NotImplementedError("Chat function not implemented yet.")
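
The class gets its concurrency from a single pattern: every blocking embedder, vector-store, LLM, and SQLite call is pushed onto the default thread pool with asyncio.to_thread, so the event loop stays free to interleave them. The same pattern in isolation, with a stand-in blocking function:

    import asyncio
    import time

    def blocking_embed(text):
        # Stand-in for a synchronous SDK call (e.g. an embedding request)
        time.sleep(0.1)
        return [0.0] * 8

    async def main():
        # Both embeds run in worker threads and overlap, just as the
        # asyncio.gather calls inside AsyncMemory overlap their operations.
        a, b = await asyncio.gather(
            asyncio.to_thread(blocking_embed, "first"),
            asyncio.to_thread(blocking_embed, "second"),
        )

    asyncio.run(main())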