import asyncio
import concurrent.futures
import hashlib
import json
import logging
import uuid
import warnings
from datetime import datetime
from typing import Any, Dict
import pytz
from pydantic import ValidationError
from mem0.configs.base import MemoryConfig, MemoryItem
from mem0.configs.enums import MemoryType
from mem0.configs.prompts import (
PROCEDURAL_MEMORY_SYSTEM_PROMPT,
get_update_memory_messages,
)
from mem0.memory.base import MemoryBase
from mem0.memory.setup import setup_config
from mem0.memory.storage import SQLiteManager
from mem0.memory.telemetry import capture_event
from mem0.memory.utils import (
get_fact_retrieval_messages,
parse_messages,
parse_vision_messages,
remove_code_blocks,
)
from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
# Setup user config
setup_config()
logger = logging.getLogger(__name__)
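
# Quick-start sketch (illustrative, not part of the original module; assumes
# provider credentials such as OPENAI_API_KEY are available in the
# environment):
#
#     from mem0 import Memory
#
#     m = Memory()
#     m.add("My name is Ada and I live in Berlin.", user_id="ada")
#     print(m.search("Where does the user live?", user_id="ada"))
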
class Memory(MemoryBase):
def __init__(self, config: MemoryConfig = MemoryConfig()):
self.config = config
self.custom_fact_extraction_prompt = self.config.custom_fact_extraction_prompt
self.custom_update_memory_prompt = self.config.custom_update_memory_prompt
self.embedding_model = EmbedderFactory.create(
self.config.embedder.provider,
self.config.embedder.config,
self.config.vector_store.config,
)
self.vector_store = VectorStoreFactory.create(
self.config.vector_store.provider, self.config.vector_store.config
)
self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config)
self.db = SQLiteManager(self.config.history_db_path)
self.collection_name = self.config.vector_store.config.collection_name
self.api_version = self.config.version
self.enable_graph = False
if self.config.graph_store.config:
from mem0.memory.graph_memory import MemoryGraph
self.graph = MemoryGraph(self.config)
self.enable_graph = True
capture_event("mem0.init", self)
@classmethod
def from_config(cls, config_dict: Dict[str, Any]):
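        """
        Create a Memory instance from a plain configuration dictionary.

        Example (illustrative; the providers shown are common choices, not a
        required or exhaustive set):

            config = {
                "llm": {"provider": "openai", "config": {"model": "gpt-4o-mini"}},
                "vector_store": {"provider": "qdrant", "config": {"collection_name": "mem0"}},
            }
            m = Memory.from_config(config)
        """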
try:
            config_dict = cls._process_config(config_dict)
            config = MemoryConfig(**config_dict)
except ValidationError as e:
logger.error(f"Configuration validation error: {e}")
raise
return cls(config)
@staticmethod
def _process_config(config_dict: Dict[str, Any]) -> Dict[str, Any]:
if "graph_store" in config_dict:
if "vector_store" not in config_dict and "embedder" in config_dict:
config_dict["vector_store"] = {}
config_dict["vector_store"]["config"] = {}
config_dict["vector_store"]["config"]["embedding_model_dims"] = config_dict["embedder"]["config"][
"embedding_dims"
]
        return config_dict
def add(
self,
messages,
user_id=None,
agent_id=None,
run_id=None,
metadata=None,
filters=None,
infer=True,
memory_type=None,
prompt=None,
):
"""
Create a new memory.
Args:
messages (str or List[Dict[str, str]]): Messages to store in the memory.
user_id (str, optional): ID of the user creating the memory. Defaults to None.
agent_id (str, optional): ID of the agent creating the memory. Defaults to None.
run_id (str, optional): ID of the run creating the memory. Defaults to None.
metadata (dict, optional): Metadata to store with the memory. Defaults to None.
filters (dict, optional): Filters to apply to the search. Defaults to None.
infer (bool, optional): Whether to infer the memories. Defaults to True.
            memory_type (str, optional): Type of memory to create. Defaults to None, which creates short-term and long-term (semantic and episodic) memories. Pass "procedural_memory" to create procedural memories instead (requires agent_id).
prompt (str, optional): Prompt to use for the memory creation. Defaults to None.
Returns:
dict: A dictionary containing the result of the memory addition operation.
                When graph memory is disabled, the result has a single
                "results" key holding a list of affected memories, each with
                an "id", the "memory" text, and an "event" of "ADD",
                "UPDATE", or "DELETE" (updates also include
                "previous_memory"). When graph memory is enabled, an
                additional "relations" key holds the affected graph
                memories.
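
        Example (illustrative; assumes a configured LLM and embedder, e.g.
        via OPENAI_API_KEY):

            m = Memory()
            m.add("I prefer vegetarian food.", user_id="alice")
            m.add(
                [
                    {"role": "user", "content": "I just moved to Berlin."},
                    {"role": "assistant", "content": "Noted! How are you finding it?"},
                ],
                user_id="alice",
                metadata={"source": "chat"},
            )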
"""
if metadata is None:
metadata = {}
filters = filters or {}
if user_id:
filters["user_id"] = metadata["user_id"] = user_id
if agent_id:
filters["agent_id"] = metadata["agent_id"] = agent_id
if run_id:
filters["run_id"] = metadata["run_id"] = run_id
if not any(key in filters for key in ("user_id", "agent_id", "run_id")):
raise ValueError("One of the filters: user_id, agent_id or run_id is required!")
if memory_type is not None and memory_type != MemoryType.PROCEDURAL.value:
raise ValueError(
f"Invalid 'memory_type'. Please pass {MemoryType.PROCEDURAL.value} to create procedural memories."
)
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
if agent_id is not None and memory_type == MemoryType.PROCEDURAL.value:
results = self._create_procedural_memory(messages, metadata=metadata, prompt=prompt)
return results
if self.config.llm.config.get("enable_vision"):
messages = parse_vision_messages(messages, self.llm, self.config.llm.config.get("vision_details"))
else:
messages = parse_vision_messages(messages)
with concurrent.futures.ThreadPoolExecutor() as executor:
future1 = executor.submit(self._add_to_vector_store, messages, metadata, filters, infer)
future2 = executor.submit(self._add_to_graph, messages, filters)
concurrent.futures.wait([future1, future2])
vector_store_result = future1.result()
graph_result = future2.result()
if self.api_version == "v1.0":
warnings.warn(
"The current add API output format is deprecated. "
"To use the latest format, set `api_version='v1.1'`. "
"The current format will be removed in mem0ai 1.1.0 and later versions.",
category=DeprecationWarning,
stacklevel=2,
)
return vector_store_result
if self.enable_graph:
return {
"results": vector_store_result,
"relations": graph_result,
}
return {"results": vector_store_result}
def _add_to_vector_store(self, messages, metadata, filters, infer):
if not infer:
returned_memories = []
for message in messages:
if message["role"] != "system":
message_embeddings = self.embedding_model.embed(message["content"], "add")
memory_id = self._create_memory(message["content"], message_embeddings, metadata)
returned_memories.append({"id": memory_id, "memory": message["content"], "event": "ADD"})
return returned_memories
parsed_messages = parse_messages(messages)
if self.config.custom_fact_extraction_prompt:
system_prompt = self.config.custom_fact_extraction_prompt
user_prompt = f"Input:\n{parsed_messages}"
else:
system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages)
response = self.llm.generate_response(
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
response_format={"type": "json_object"},
)
try:
response = remove_code_blocks(response)
new_retrieved_facts = json.loads(response)["facts"]
except Exception as e:
logging.error(f"Error in new_retrieved_facts: {e}")
new_retrieved_facts = []
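        # For each newly extracted fact, embed it and retrieve up to five
        # similar existing memories so the LLM can decide whether to ADD,
        # UPDATE, DELETE, or leave each memory untouched.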
retrieved_old_memory = []
new_message_embeddings = {}
for new_mem in new_retrieved_facts:
messages_embeddings = self.embedding_model.embed(new_mem, "add")
new_message_embeddings[new_mem] = messages_embeddings
existing_memories = self.vector_store.search(
query=new_mem,
vectors=messages_embeddings,
limit=5,
filters=filters,
)
for mem in existing_memories:
retrieved_old_memory.append({"id": mem.id, "text": mem.payload["data"]})
unique_data = {}
for item in retrieved_old_memory:
unique_data[item["id"]] = item
retrieved_old_memory = list(unique_data.values())
logging.info(f"Total existing memories: {len(retrieved_old_memory)}")
# mapping UUIDs with integers for handling UUID hallucinations
temp_uuid_mapping = {}
for idx, item in enumerate(retrieved_old_memory):
temp_uuid_mapping[str(idx)] = item["id"]
retrieved_old_memory[idx]["id"] = str(idx)
function_calling_prompt = get_update_memory_messages(
retrieved_old_memory, new_retrieved_facts, self.config.custom_update_memory_prompt
)
try:
new_memories_with_actions = self.llm.generate_response(
messages=[{"role": "user", "content": function_calling_prompt}],
response_format={"type": "json_object"},
)
except Exception as e:
logging.error(f"Error in new_memories_with_actions: {e}")
            new_memories_with_actions = {}
try:
new_memories_with_actions = remove_code_blocks(new_memories_with_actions)
new_memories_with_actions = json.loads(new_memories_with_actions)
except Exception as e:
logging.error(f"Invalid JSON response: {e}")
            new_memories_with_actions = {}
returned_memories = []
try:
for resp in new_memories_with_actions.get("memory", []):
logging.info(resp)
try:
if not resp.get("text"):
logging.info("Skipping memory entry because of empty `text` field.")
continue
elif resp.get("event") == "ADD":
memory_id = self._create_memory(
data=resp.get("text"),
existing_embeddings=new_message_embeddings,
metadata=metadata,
)
returned_memories.append(
{
"id": memory_id,
"memory": resp.get("text"),
"event": resp.get("event"),
}
)
elif resp.get("event") == "UPDATE":
self._update_memory(
memory_id=temp_uuid_mapping[resp["id"]],
data=resp.get("text"),
existing_embeddings=new_message_embeddings,
metadata=metadata,
)
returned_memories.append(
{
"id": temp_uuid_mapping[resp.get("id")],
"memory": resp.get("text"),
"event": resp.get("event"),
"previous_memory": resp.get("old_memory"),
}
)
elif resp.get("event") == "DELETE":
self._delete_memory(memory_id=temp_uuid_mapping[resp.get("id")])
returned_memories.append(
{
"id": temp_uuid_mapping[resp.get("id")],
"memory": resp.get("text"),
"event": resp.get("event"),
}
)
elif resp.get("event") == "NONE":
logging.info("NOOP for Memory.")
except Exception as e:
logging.error(f"Error in new_memories_with_actions: {e}")
except Exception as e:
logging.error(f"Error in new_memories_with_actions: {e}")
capture_event(
"mem0.add",
self,
{"version": self.api_version, "keys": list(filters.keys())},
)
return returned_memories
def _add_to_graph(self, messages, filters):
added_entities = []
if self.enable_graph:
if filters.get("user_id") is None:
filters["user_id"] = "user"
data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"])
added_entities = self.graph.add(data, filters)
return added_entities
def get(self, memory_id):
"""
Retrieve a memory by ID.
Args:
memory_id (str): ID of the memory to retrieve.
Returns:
dict: Retrieved memory.
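
        Example (illustrative; the id comes from a prior add() or search()):

            mem = m.get(memory_id)
            if mem is not None:
                print(mem["memory"], mem.get("created_at"))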
"""
capture_event("mem0.get", self, {"memory_id": memory_id})
memory = self.vector_store.get(vector_id=memory_id)
if not memory:
return None
filters = {key: memory.payload[key] for key in ["user_id", "agent_id", "run_id"] if memory.payload.get(key)}
# Prepare base memory item
memory_item = MemoryItem(
id=memory.id,
memory=memory.payload["data"],
hash=memory.payload.get("hash"),
created_at=memory.payload.get("created_at"),
updated_at=memory.payload.get("updated_at"),
).model_dump(exclude={"score"})
# Add metadata if there are additional keys
excluded_keys = {
"user_id",
"agent_id",
"run_id",
"hash",
"data",
"created_at",
"updated_at",
"id",
}
additional_metadata = {k: v for k, v in memory.payload.items() if k not in excluded_keys}
if additional_metadata:
memory_item["metadata"] = additional_metadata
result = {**memory_item, **filters}
return result
def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
"""
List all memories.
        Returns:
            dict: Dictionary with a "results" key containing all memories, and a "relations" key when graph memory is enabled. (With api_version "v1.0", a bare list is returned; that format is deprecated.)
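
        Example (illustrative):

            page = m.get_all(user_id="alice", limit=20)
            for item in page["results"]:
                print(item["id"], item["memory"])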
"""
filters = {}
if user_id:
filters["user_id"] = user_id
if agent_id:
filters["agent_id"] = agent_id
if run_id:
filters["run_id"] = run_id
capture_event("mem0.get_all", self, {"limit": limit, "keys": list(filters.keys())})
with concurrent.futures.ThreadPoolExecutor() as executor:
future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
future_graph_entities = executor.submit(self.graph.get_all, filters, limit) if self.enable_graph else None
concurrent.futures.wait(
[future_memories, future_graph_entities] if future_graph_entities else [future_memories]
)
all_memories = future_memories.result()
graph_entities = future_graph_entities.result() if future_graph_entities else None
if self.enable_graph:
return {"results": all_memories, "relations": graph_entities}
if self.api_version == "v1.0":
warnings.warn(
"The current get_all API output format is deprecated. "
"To use the latest format, set `api_version='v1.1'`. "
"The current format will be removed in mem0ai 1.1.0 and later versions.",
category=DeprecationWarning,
stacklevel=2,
)
return all_memories
else:
return {"results": all_memories}
def _get_all_from_vector_store(self, filters, limit):
memories = self.vector_store.list(filters=filters, limit=limit)
excluded_keys = {
"user_id",
"agent_id",
"run_id",
"hash",
"data",
"created_at",
"updated_at",
"id",
}
all_memories = [
{
**MemoryItem(
id=mem.id,
memory=mem.payload["data"],
hash=mem.payload.get("hash"),
created_at=mem.payload.get("created_at"),
updated_at=mem.payload.get("updated_at"),
).model_dump(exclude={"score"}),
**{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
**(
{"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
if any(k for k in mem.payload if k not in excluded_keys)
else {}
),
}
for mem in memories[0]
]
return all_memories
def search(self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None):
"""
Search for memories.
Args:
query (str): Query to search for.
user_id (str, optional): ID of the user to search for. Defaults to None.
agent_id (str, optional): ID of the agent to search for. Defaults to None.
run_id (str, optional): ID of the run to search for. Defaults to None.
limit (int, optional): Limit the number of results. Defaults to 100.
filters (dict, optional): Filters to apply to the search. Defaults to None.
        Returns:
            dict: Dictionary with a "results" key containing the matching memories (each with a relevance "score"), and a "relations" key when graph memory is enabled. (With api_version "v1.0", a bare list is returned; that format is deprecated.)
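
        Example (illustrative):

            hits = m.search("food preferences", user_id="alice", limit=5)
            for hit in hits["results"]:
                print(hit["score"], hit["memory"])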
"""
filters = filters or {}
if user_id:
filters["user_id"] = user_id
if agent_id:
filters["agent_id"] = agent_id
if run_id:
filters["run_id"] = run_id
if not any(key in filters for key in ("user_id", "agent_id", "run_id")):
raise ValueError("One of the filters: user_id, agent_id or run_id is required!")
capture_event(
"mem0.search",
self,
{"limit": limit, "version": self.api_version, "keys": list(filters.keys())},
)
with concurrent.futures.ThreadPoolExecutor() as executor:
future_memories = executor.submit(self._search_vector_store, query, filters, limit)
future_graph_entities = (
executor.submit(self.graph.search, query, filters, limit) if self.enable_graph else None
)
concurrent.futures.wait(
[future_memories, future_graph_entities] if future_graph_entities else [future_memories]
)
original_memories = future_memories.result()
graph_entities = future_graph_entities.result() if future_graph_entities else None
if self.enable_graph:
return {"results": original_memories, "relations": graph_entities}
if self.api_version == "v1.0":
warnings.warn(
"The current get_all API output format is deprecated. "
"To use the latest format, set `api_version='v1.1'`. "
"The current format will be removed in mem0ai 1.1.0 and later versions.",
category=DeprecationWarning,
stacklevel=2,
)
return original_memories
else:
return {"results": original_memories}
def _search_vector_store(self, query, filters, limit):
embeddings = self.embedding_model.embed(query, "search")
memories = self.vector_store.search(query=query, vectors=embeddings, limit=limit, filters=filters)
excluded_keys = {
"user_id",
"agent_id",
"run_id",
"hash",
"data",
"created_at",
"updated_at",
"id",
}
original_memories = [
{
**MemoryItem(
id=mem.id,
memory=mem.payload["data"],
hash=mem.payload.get("hash"),
created_at=mem.payload.get("created_at"),
updated_at=mem.payload.get("updated_at"),
score=mem.score,
).model_dump(),
**{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
**(
{"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
if any(k for k in mem.payload if k not in excluded_keys)
else {}
),
}
for mem in memories
]
return original_memories
def update(self, memory_id, data):
"""
Update a memory by ID.
Args:
memory_id (str): ID of the memory to update.
            data (str): New memory text to store in place of the existing memory.
        Returns:
            dict: A confirmation message.
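
        Example (illustrative; memory_id comes from a prior add() or search()):

            m.update(memory_id, "Lives in Munich now, not Berlin.")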
"""
capture_event("mem0.update", self, {"memory_id": memory_id})
existing_embeddings = {data: self.embedding_model.embed(data, "update")}
self._update_memory(memory_id, data, existing_embeddings)
return {"message": "Memory updated successfully!"}
def delete(self, memory_id):
"""
Delete a memory by ID.
Args:
memory_id (str): ID of the memory to delete.
"""
capture_event("mem0.delete", self, {"memory_id": memory_id})
self._delete_memory(memory_id)
return {"message": "Memory deleted successfully!"}
def delete_all(self, user_id=None, agent_id=None, run_id=None):
"""
Delete all memories.
Args:
user_id (str, optional): ID of the user to delete memories for. Defaults to None.
agent_id (str, optional): ID of the agent to delete memories for. Defaults to None.
run_id (str, optional): ID of the run to delete memories for. Defaults to None.
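
        Example (illustrative; at least one id filter is required, otherwise
        a ValueError is raised; use reset() to wipe the whole store):

            m.delete_all(user_id="alice")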
"""
filters = {}
if user_id:
filters["user_id"] = user_id
if agent_id:
filters["agent_id"] = agent_id
if run_id:
filters["run_id"] = run_id
if not filters:
raise ValueError(
"At least one filter is required to delete all memories. If you want to delete all memories, use the `reset()` method."
)
capture_event("mem0.delete_all", self, {"keys": list(filters.keys())})
memories = self.vector_store.list(filters=filters)[0]
for memory in memories:
self._delete_memory(memory.id)
logger.info(f"Deleted {len(memories)} memories")
if self.enable_graph:
self.graph.delete_all(filters)
return {"message": "Memories deleted successfully!"}
def history(self, memory_id):
"""
Get the history of changes for a memory by ID.
Args:
memory_id (str): ID of the memory to get history for.
Returns:
list: List of changes for the memory.
"""
capture_event("mem0.history", self, {"memory_id": memory_id})
return self.db.get_history(memory_id)
def _create_memory(self, data, existing_embeddings, metadata=None):
logging.debug(f"Creating memory with {data=}")
if data in existing_embeddings:
embeddings = existing_embeddings[data]
else:
embeddings = self.embedding_model.embed(data, memory_action="add")
memory_id = str(uuid.uuid4())
metadata = metadata or {}
metadata["data"] = data
metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
metadata["created_at"] = datetime.now(pytz.timezone("US/Pacific")).isoformat()
self.vector_store.insert(
vectors=[embeddings],
ids=[memory_id],
payloads=[metadata],
)
self.db.add_history(memory_id, None, data, "ADD", created_at=metadata["created_at"])
capture_event("mem0._create_memory", self, {"memory_id": memory_id})
return memory_id
def _create_procedural_memory(self, messages, metadata=None, prompt=None):
"""
Create a procedural memory
Args:
messages (list): List of messages to create a procedural memory from.
            metadata (dict): Metadata to attach to the procedural memory.
prompt (str, optional): Prompt to use for the procedural memory creation. Defaults to None.
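
        Example (illustrative; normally reached through add() rather than
        called directly):

            m.add(
                [{"role": "user", "content": "First clone the repo, then run the tests."}],
                agent_id="coder",
                memory_type="procedural_memory",
            )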
"""
logger.info("Creating procedural memory")
parsed_messages = [
{"role": "system", "content": prompt or PROCEDURAL_MEMORY_SYSTEM_PROMPT},
*messages,
{
"role": "user",
"content": "Create procedural memory of the above conversation.",
},
]
try:
procedural_memory = self.llm.generate_response(messages=parsed_messages)
except Exception as e:
logger.error(f"Error generating procedural memory summary: {e}")
raise
if metadata is None:
raise ValueError("Metadata cannot be done for procedural memory.")
metadata["memory_type"] = MemoryType.PROCEDURAL.value
# Generate embeddings for the summary
embeddings = self.embedding_model.embed(procedural_memory, memory_action="add")
# Create the memory
memory_id = self._create_memory(procedural_memory, {procedural_memory: embeddings}, metadata=metadata)
capture_event("mem0._create_procedural_memory", self, {"memory_id": memory_id})
# Return results in the same format as add()
result = {"results": [{"id": memory_id, "memory": procedural_memory, "event": "ADD"}]}
return result
def _update_memory(self, memory_id, data, existing_embeddings, metadata=None):
logger.info(f"Updating memory with {data=}")
try:
existing_memory = self.vector_store.get(vector_id=memory_id)
except Exception:
raise ValueError(f"Error getting memory with ID {memory_id}. Please provide a valid 'memory_id'")
prev_value = existing_memory.payload.get("data")
new_metadata = metadata or {}
new_metadata["data"] = data
new_metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
new_metadata["created_at"] = existing_memory.payload.get("created_at")
new_metadata["updated_at"] = datetime.now(pytz.timezone("US/Pacific")).isoformat()
if "user_id" in existing_memory.payload:
new_metadata["user_id"] = existing_memory.payload["user_id"]
if "agent_id" in existing_memory.payload:
new_metadata["agent_id"] = existing_memory.payload["agent_id"]
if "run_id" in existing_memory.payload:
new_metadata["run_id"] = existing_memory.payload["run_id"]
if data in existing_embeddings:
embeddings = existing_embeddings[data]
else:
embeddings = self.embedding_model.embed(data, "update")
self.vector_store.update(
vector_id=memory_id,
vector=embeddings,
payload=new_metadata,
)
logger.info(f"Updating memory with ID {memory_id=} with {data=}")
self.db.add_history(
memory_id,
prev_value,
data,
"UPDATE",
created_at=new_metadata["created_at"],
updated_at=new_metadata["updated_at"],
)
capture_event("mem0._update_memory", self, {"memory_id": memory_id})
return memory_id
def _delete_memory(self, memory_id):
logging.info(f"Deleting memory with {memory_id=}")
existing_memory = self.vector_store.get(vector_id=memory_id)
prev_value = existing_memory.payload["data"]
self.vector_store.delete(vector_id=memory_id)
self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1)
capture_event("mem0._delete_memory", self, {"memory_id": memory_id})
return memory_id
def reset(self):
"""
Reset the memory store.
"""
logger.warning("Resetting all memories")
self.vector_store.delete_col()
self.vector_store = VectorStoreFactory.create(
self.config.vector_store.provider, self.config.vector_store.config
)
self.db.reset()
capture_event("mem0.reset", self)
def chat(self, query):
raise NotImplementedError("Chat function not implemented yet.")
class AsyncMemory(MemoryBase):
def __init__(self, config: MemoryConfig = MemoryConfig()):
self.config = config
self.embedding_model = EmbedderFactory.create(
self.config.embedder.provider,
self.config.embedder.config,
self.config.vector_store.config,
)
self.vector_store = VectorStoreFactory.create(
self.config.vector_store.provider, self.config.vector_store.config
)
self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config)
self.db = SQLiteManager(self.config.history_db_path)
self.collection_name = self.config.vector_store.config.collection_name
self.api_version = self.config.version
self.enable_graph = False
if self.config.graph_store.config:
from mem0.memory.graph_memory import MemoryGraph
self.graph = MemoryGraph(self.config)
self.enable_graph = True
capture_event("async_mem0.init", self)
@classmethod
async def from_config(cls, config_dict: Dict[str, Any]):
try:
            config_dict = cls._process_config(config_dict)
            config = MemoryConfig(**config_dict)
except ValidationError as e:
logger.error(f"Configuration validation error: {e}")
raise
return cls(config)
@staticmethod
def _process_config(config_dict: Dict[str, Any]) -> Dict[str, Any]:
if "graph_store" in config_dict:
if "vector_store" not in config_dict and "embedder" in config_dict:
config_dict["vector_store"] = {}
config_dict["vector_store"]["config"] = {}
config_dict["vector_store"]["config"]["embedding_model_dims"] = config_dict["embedder"]["config"][
"embedding_dims"
]
        return config_dict
async def add(
self,
messages,
user_id=None,
agent_id=None,
run_id=None,
metadata=None,
filters=None,
infer=True,
memory_type=None,
prompt=None,
llm=None,
):
"""
Create a new memory asynchronously.
Args:
messages (str or List[Dict[str, str]]): Messages to store in the memory.
user_id (str, optional): ID of the user creating the memory. Defaults to None.
agent_id (str, optional): ID of the agent creating the memory. Defaults to None.
run_id (str, optional): ID of the run creating the memory. Defaults to None.
metadata (dict, optional): Metadata to store with the memory. Defaults to None.
filters (dict, optional): Filters to apply to the search. Defaults to None.
infer (bool, optional): Whether to infer the memories. Defaults to True.
            memory_type (str, optional): Type of memory to create. Defaults to None, which creates short-term and long-term (semantic and episodic) memories. Pass "procedural_memory" to create procedural memories instead (requires agent_id).
prompt (str, optional): Prompt to use for the memory creation. Defaults to None.
            llm (BaseChatModel, optional): LLM instance to use for generating procedural memories. Defaults to None. Useful when passing a LangChain chat model.
Returns:
dict: A dictionary containing the result of the memory addition operation.
                When graph memory is disabled, the result has a single
                "results" key holding a list of affected memories, each with
                an "id", the "memory" text, and an "event" of "ADD",
                "UPDATE", or "DELETE" (updates also include
                "previous_memory"). When graph memory is enabled, an
                additional "relations" key holds the affected graph
                memories.
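
        Example (illustrative; run inside an event loop with configured
        providers):

            m = AsyncMemory()
            result = await m.add("I prefer vegetarian food.", user_id="alice")
            print(result["results"])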
"""
if metadata is None:
metadata = {}
filters = filters or {}
if user_id:
filters["user_id"] = metadata["user_id"] = user_id
if agent_id:
filters["agent_id"] = metadata["agent_id"] = agent_id
if run_id:
filters["run_id"] = metadata["run_id"] = run_id
if not any(key in filters for key in ("user_id", "agent_id", "run_id")):
raise ValueError("One of the filters: user_id, agent_id or run_id is required!")
if memory_type is not None and memory_type != MemoryType.PROCEDURAL.value:
raise ValueError(
f"Invalid 'memory_type'. Please pass {MemoryType.PROCEDURAL.value} to create procedural memories."
)
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
if agent_id is not None and memory_type == MemoryType.PROCEDURAL.value:
results = await self._create_procedural_memory(messages, metadata=metadata, llm=llm, prompt=prompt)
return results
if self.config.llm.config.get("enable_vision"):
messages = parse_vision_messages(messages, self.llm, self.config.llm.config.get("vision_details"))
else:
messages = parse_vision_messages(messages)
# Run vector store and graph operations concurrently
vector_store_task = asyncio.create_task(self._add_to_vector_store(messages, metadata, filters, infer))
graph_task = asyncio.create_task(self._add_to_graph(messages, filters))
vector_store_result, graph_result = await asyncio.gather(vector_store_task, graph_task)
if self.api_version == "v1.0":
warnings.warn(
"The current add API output format is deprecated. "
"To use the latest format, set `api_version='v1.1'`. "
"The current format will be removed in mem0ai 1.1.0 and later versions.",
category=DeprecationWarning,
stacklevel=2,
)
return vector_store_result
if self.enable_graph:
return {
"results": vector_store_result,
"relations": graph_result,
}
return {"results": vector_store_result}
async def _add_to_vector_store(self, messages, metadata, filters, infer):
if not infer:
returned_memories = []
for message in messages:
if message["role"] != "system":
message_embeddings = await asyncio.to_thread(self.embedding_model.embed, message["content"], "add")
memory_id = await self._create_memory(message["content"], message_embeddings, metadata)
returned_memories.append({"id": memory_id, "memory": message["content"], "event": "ADD"})
return returned_memories
parsed_messages = parse_messages(messages)
if self.config.custom_fact_extraction_prompt:
system_prompt = self.config.custom_fact_extraction_prompt
user_prompt = f"Input:\n{parsed_messages}"
else:
system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages)
response = await asyncio.to_thread(
self.llm.generate_response,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
response_format={"type": "json_object"},
)
try:
response = remove_code_blocks(response)
new_retrieved_facts = json.loads(response)["facts"]
except Exception as e:
logging.error(f"Error in new_retrieved_facts: {e}")
new_retrieved_facts = []
retrieved_old_memory = []
new_message_embeddings = {}
# Process all facts concurrently
async def process_fact(new_mem):
messages_embeddings = await asyncio.to_thread(self.embedding_model.embed, new_mem, "add")
new_message_embeddings[new_mem] = messages_embeddings
existing_memories = await asyncio.to_thread(
self.vector_store.search,
query=new_mem,
vectors=messages_embeddings,
limit=5,
filters=filters,
)
return [(mem.id, mem.payload["data"]) for mem in existing_memories]
fact_tasks = [process_fact(fact) for fact in new_retrieved_facts]
fact_results = await asyncio.gather(*fact_tasks)
# Flatten results and build retrieved_old_memory
for result in fact_results:
for mem_id, mem_data in result:
retrieved_old_memory.append({"id": mem_id, "text": mem_data})
unique_data = {}
for item in retrieved_old_memory:
unique_data[item["id"]] = item
retrieved_old_memory = list(unique_data.values())
logging.info(f"Total existing memories: {len(retrieved_old_memory)}")
# mapping UUIDs with integers for handling UUID hallucinations
temp_uuid_mapping = {}
for idx, item in enumerate(retrieved_old_memory):
temp_uuid_mapping[str(idx)] = item["id"]
retrieved_old_memory[idx]["id"] = str(idx)
function_calling_prompt = get_update_memory_messages(
retrieved_old_memory, new_retrieved_facts, self.config.custom_update_memory_prompt
)
try:
new_memories_with_actions = await asyncio.to_thread(
self.llm.generate_response,
messages=[{"role": "user", "content": function_calling_prompt}],
response_format={"type": "json_object"},
)
except Exception as e:
logging.error(f"Error in new_memories_with_actions: {e}")
            new_memories_with_actions = {}
try:
new_memories_with_actions = remove_code_blocks(new_memories_with_actions)
new_memories_with_actions = json.loads(new_memories_with_actions)
except Exception as e:
logging.error(f"Invalid JSON response: {e}")
            new_memories_with_actions = {}
returned_memories = []
try:
memory_tasks = []
for resp in new_memories_with_actions.get("memory", []):
logging.info(resp)
try:
if not resp.get("text"):
logging.info("Skipping memory entry because of empty `text` field.")
continue
elif resp.get("event") == "ADD":
task = asyncio.create_task(
self._create_memory(
data=resp.get("text"), existing_embeddings=new_message_embeddings, metadata=metadata
)
)
memory_tasks.append((task, resp, "ADD", None))
elif resp.get("event") == "UPDATE":
task = asyncio.create_task(
self._update_memory(
memory_id=temp_uuid_mapping[resp["id"]],
data=resp.get("text"),
existing_embeddings=new_message_embeddings,
metadata=metadata,
)
)
memory_tasks.append((task, resp, "UPDATE", temp_uuid_mapping[resp["id"]]))
elif resp.get("event") == "DELETE":
task = asyncio.create_task(self._delete_memory(memory_id=temp_uuid_mapping[resp.get("id")]))
memory_tasks.append((task, resp, "DELETE", temp_uuid_mapping[resp["id"]]))
elif resp.get("event") == "NONE":
logging.info("NOOP for Memory.")
except Exception as e:
logging.error(f"Error in new_memories_with_actions: {e}")
# Wait for all memory operations to complete
for task, resp, event_type, mem_id in memory_tasks:
try:
result_id = await task
if event_type == "ADD":
returned_memories.append(
{
"id": result_id,
"memory": resp.get("text"),
"event": resp.get("event"),
}
)
elif event_type == "UPDATE":
returned_memories.append(
{
"id": mem_id,
"memory": resp.get("text"),
"event": resp.get("event"),
"previous_memory": resp.get("old_memory"),
}
)
elif event_type == "DELETE":
returned_memories.append(
{
"id": mem_id,
"memory": resp.get("text"),
"event": resp.get("event"),
}
)
except Exception as e:
logging.error(f"Error processing memory task: {e}")
except Exception as e:
logging.error(f"Error in new_memories_with_actions: {e}")
capture_event("async_mem0.add", self, {"version": self.api_version, "keys": list(filters.keys())})
return returned_memories
async def _add_to_graph(self, messages, filters):
added_entities = []
if self.enable_graph:
if filters.get("user_id") is None:
filters["user_id"] = "user"
data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"])
added_entities = await asyncio.to_thread(self.graph.add, data, filters)
return added_entities
async def get(self, memory_id):
"""
Retrieve a memory by ID asynchronously.
Args:
memory_id (str): ID of the memory to retrieve.
Returns:
dict: Retrieved memory.
"""
capture_event("async_mem0.get", self, {"memory_id": memory_id})
memory = await asyncio.to_thread(self.vector_store.get, vector_id=memory_id)
if not memory:
return None
filters = {key: memory.payload[key] for key in ["user_id", "agent_id", "run_id"] if memory.payload.get(key)}
# Prepare base memory item
memory_item = MemoryItem(
id=memory.id,
memory=memory.payload["data"],
hash=memory.payload.get("hash"),
created_at=memory.payload.get("created_at"),
updated_at=memory.payload.get("updated_at"),
).model_dump(exclude={"score"})
# Add metadata if there are additional keys
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at", "id"}
additional_metadata = {k: v for k, v in memory.payload.items() if k not in excluded_keys}
if additional_metadata:
memory_item["metadata"] = additional_metadata
result = {**memory_item, **filters}
return result
async def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
"""
List all memories asynchronously.
        Returns:
            dict: Dictionary with a "results" key containing all memories, and a "relations" key when graph memory is enabled. (With api_version "v1.0", a bare list is returned; that format is deprecated.)
"""
filters = {}
if user_id:
filters["user_id"] = user_id
if agent_id:
filters["agent_id"] = agent_id
if run_id:
filters["run_id"] = run_id
capture_event("async_mem0.get_all", self, {"limit": limit, "keys": list(filters.keys())})
# Run vector store and graph operations concurrently
vector_store_task = asyncio.create_task(self._get_all_from_vector_store(filters, limit))
if self.enable_graph:
graph_task = asyncio.create_task(asyncio.to_thread(self.graph.get_all, filters, limit))
all_memories, graph_entities = await asyncio.gather(vector_store_task, graph_task)
else:
all_memories = await vector_store_task
graph_entities = None
if self.enable_graph:
return {"results": all_memories, "relations": graph_entities}
if self.api_version == "v1.0":
warnings.warn(
"The current get_all API output format is deprecated. "
"To use the latest format, set `api_version='v1.1'`. "
"The current format will be removed in mem0ai 1.1.0 and later versions.",
category=DeprecationWarning,
stacklevel=2,
)
return all_memories
else:
return {"results": all_memories}
async def _get_all_from_vector_store(self, filters, limit):
memories = await asyncio.to_thread(self.vector_store.list, filters=filters, limit=limit)
excluded_keys = {
"user_id",
"agent_id",
"run_id",
"hash",
"data",
"created_at",
"updated_at",
"id",
}
all_memories = [
{
**MemoryItem(
id=mem.id,
memory=mem.payload["data"],
hash=mem.payload.get("hash"),
created_at=mem.payload.get("created_at"),
updated_at=mem.payload.get("updated_at"),
).model_dump(exclude={"score"}),
**{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
**(
{"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
if any(k for k in mem.payload if k not in excluded_keys)
else {}
),
}
for mem in memories[0]
]
return all_memories
async def search(self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None):
"""
Search for memories asynchronously.
Args:
query (str): Query to search for.
user_id (str, optional): ID of the user to search for. Defaults to None.
agent_id (str, optional): ID of the agent to search for. Defaults to None.
run_id (str, optional): ID of the run to search for. Defaults to None.
limit (int, optional): Limit the number of results. Defaults to 100.
filters (dict, optional): Filters to apply to the search. Defaults to None.
        Returns:
            dict: Dictionary with a "results" key containing the matching memories (each with a relevance "score"), and a "relations" key when graph memory is enabled. (With api_version "v1.0", a bare list is returned; that format is deprecated.)
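
        Example (illustrative; run inside an event loop):

            hits = await m.search("food preferences", user_id="alice", limit=5)
            for hit in hits["results"]:
                print(hit["score"], hit["memory"])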
"""
filters = filters or {}
if user_id:
filters["user_id"] = user_id
if agent_id:
filters["agent_id"] = agent_id
if run_id:
filters["run_id"] = run_id
if not any(key in filters for key in ("user_id", "agent_id", "run_id")):
raise ValueError("One of the filters: user_id, agent_id or run_id is required!")
capture_event(
"async_mem0.search",
self,
{"limit": limit, "version": self.api_version, "keys": list(filters.keys())},
)
# Run vector store and graph operations concurrently
vector_store_task = asyncio.create_task(self._search_vector_store(query, filters, limit))
if self.enable_graph:
graph_task = asyncio.create_task(asyncio.to_thread(self.graph.search, query, filters, limit))
original_memories, graph_entities = await asyncio.gather(vector_store_task, graph_task)
else:
original_memories = await vector_store_task
graph_entities = None
if self.enable_graph:
return {"results": original_memories, "relations": graph_entities}
if self.api_version == "v1.0":
warnings.warn(
"The current get_all API output format is deprecated. "
"To use the latest format, set `api_version='v1.1'`. "
"The current format will be removed in mem0ai 1.1.0 and later versions.",
category=DeprecationWarning,
stacklevel=2,
)
return original_memories
else:
return {"results": original_memories}
async def _search_vector_store(self, query, filters, limit):
embeddings = await asyncio.to_thread(self.embedding_model.embed, query, "search")
memories = await asyncio.to_thread(
self.vector_store.search, query=query, vectors=embeddings, limit=limit, filters=filters
)
excluded_keys = {
"user_id",
"agent_id",
"run_id",
"hash",
"data",
"created_at",
"updated_at",
"id",
}
original_memories = [
{
**MemoryItem(
id=mem.id,
memory=mem.payload["data"],
hash=mem.payload.get("hash"),
created_at=mem.payload.get("created_at"),
updated_at=mem.payload.get("updated_at"),
score=mem.score,
).model_dump(),
**{key: mem.payload[key] for key in ["user_id", "agent_id", "run_id"] if key in mem.payload},
**(
{"metadata": {k: v for k, v in mem.payload.items() if k not in excluded_keys}}
if any(k for k in mem.payload if k not in excluded_keys)
else {}
),
}
for mem in memories
]
return original_memories
async def update(self, memory_id, data):
"""
Update a memory by ID asynchronously.
Args:
memory_id (str): ID of the memory to update.
            data (str): New memory text to store in place of the existing memory.
        Returns:
            dict: A confirmation message.
"""
capture_event("async_mem0.update", self, {"memory_id": memory_id})
embeddings = await asyncio.to_thread(self.embedding_model.embed, data, "update")
existing_embeddings = {data: embeddings}
await self._update_memory(memory_id, data, existing_embeddings)
return {"message": "Memory updated successfully!"}
async def delete(self, memory_id):
"""
Delete a memory by ID asynchronously.
Args:
memory_id (str): ID of the memory to delete.
"""
capture_event("async_mem0.delete", self, {"memory_id": memory_id})
await self._delete_memory(memory_id)
return {"message": "Memory deleted successfully!"}
async def delete_all(self, user_id=None, agent_id=None, run_id=None):
"""
Delete all memories asynchronously.
Args:
user_id (str, optional): ID of the user to delete memories for. Defaults to None.
agent_id (str, optional): ID of the agent to delete memories for. Defaults to None.
run_id (str, optional): ID of the run to delete memories for. Defaults to None.
"""
filters = {}
if user_id:
filters["user_id"] = user_id
if agent_id:
filters["agent_id"] = agent_id
if run_id:
filters["run_id"] = run_id
if not filters:
raise ValueError(
"At least one filter is required to delete all memories. If you want to delete all memories, use the `reset()` method."
)
capture_event("async_mem0.delete_all", self, {"keys": list(filters.keys())})
memories = await asyncio.to_thread(self.vector_store.list, filters=filters)
delete_tasks = []
for memory in memories[0]:
delete_tasks.append(self._delete_memory(memory.id))
await asyncio.gather(*delete_tasks)
logger.info(f"Deleted {len(memories[0])} memories")
if self.enable_graph:
await asyncio.to_thread(self.graph.delete_all, filters)
return {"message": "Memories deleted successfully!"}
async def history(self, memory_id):
"""
Get the history of changes for a memory by ID asynchronously.
Args:
memory_id (str): ID of the memory to get history for.
Returns:
list: List of changes for the memory.
"""
capture_event("async_mem0.history", self, {"memory_id": memory_id})
return await asyncio.to_thread(self.db.get_history, memory_id)
async def _create_memory(self, data, existing_embeddings, metadata=None):
logging.debug(f"Creating memory with {data=}")
if data in existing_embeddings:
embeddings = existing_embeddings[data]
else:
embeddings = await asyncio.to_thread(self.embedding_model.embed, data, memory_action="add")
memory_id = str(uuid.uuid4())
metadata = metadata or {}
metadata["data"] = data
metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
metadata["created_at"] = datetime.now(pytz.timezone("US/Pacific")).isoformat()
await asyncio.to_thread(
self.vector_store.insert,
vectors=[embeddings],
ids=[memory_id],
payloads=[metadata],
)
await asyncio.to_thread(self.db.add_history, memory_id, None, data, "ADD", created_at=metadata["created_at"])
capture_event("async_mem0._create_memory", self, {"memory_id": memory_id})
return memory_id
async def _create_procedural_memory(self, messages, metadata=None, llm=None, prompt=None):
"""
Create a procedural memory asynchronously
Args:
messages (list): List of messages to create a procedural memory from.
            metadata (dict): Metadata to attach to the procedural memory.
            llm (BaseChatModel, optional): LLM instance to use for generating procedural memories. Defaults to None. Useful when passing a LangChain chat model.
prompt (str, optional): Prompt to use for the procedural memory creation. Defaults to None.
"""
try:
from langchain_core.messages.utils import (
convert_to_messages, # type: ignore
)
except Exception:
logger.error(
"Import error while loading langchain-core. Please install 'langchain-core' to use procedural memory."
)
raise
logger.info("Creating procedural memory")
parsed_messages = [
{"role": "system", "content": prompt or PROCEDURAL_MEMORY_SYSTEM_PROMPT},
*messages,
{"role": "user", "content": "Create procedural memory of the above conversation."},
]
try:
if llm is not None:
parsed_messages = convert_to_messages(parsed_messages)
response = await asyncio.to_thread(llm.invoke, input=parsed_messages)
procedural_memory = response.content
else:
procedural_memory = await asyncio.to_thread(self.llm.generate_response, messages=parsed_messages)
except Exception as e:
logger.error(f"Error generating procedural memory summary: {e}")
raise
if metadata is None:
raise ValueError("Metadata cannot be done for procedural memory.")
metadata["memory_type"] = MemoryType.PROCEDURAL.value
# Generate embeddings for the summary
embeddings = await asyncio.to_thread(self.embedding_model.embed, procedural_memory, memory_action="add")
# Create the memory
memory_id = await self._create_memory(procedural_memory, {procedural_memory: embeddings}, metadata=metadata)
capture_event("async_mem0._create_procedural_memory", self, {"memory_id": memory_id})
# Return results in the same format as add()
result = {"results": [{"id": memory_id, "memory": procedural_memory, "event": "ADD"}]}
return result
async def _update_memory(self, memory_id, data, existing_embeddings, metadata=None):
logger.info(f"Updating memory with {data=}")
try:
existing_memory = await asyncio.to_thread(self.vector_store.get, vector_id=memory_id)
except Exception:
raise ValueError(f"Error getting memory with ID {memory_id}. Please provide a valid 'memory_id'")
prev_value = existing_memory.payload.get("data")
new_metadata = metadata or {}
new_metadata["data"] = data
new_metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
new_metadata["created_at"] = existing_memory.payload.get("created_at")
new_metadata["updated_at"] = datetime.now(pytz.timezone("US/Pacific")).isoformat()
if "user_id" in existing_memory.payload:
new_metadata["user_id"] = existing_memory.payload["user_id"]
if "agent_id" in existing_memory.payload:
new_metadata["agent_id"] = existing_memory.payload["agent_id"]
if "run_id" in existing_memory.payload:
new_metadata["run_id"] = existing_memory.payload["run_id"]
if data in existing_embeddings:
embeddings = existing_embeddings[data]
else:
embeddings = await asyncio.to_thread(self.embedding_model.embed, data, "update")
await asyncio.to_thread(
self.vector_store.update,
vector_id=memory_id,
vector=embeddings,
payload=new_metadata,
)
logger.info(f"Updating memory with ID {memory_id=} with {data=}")
await asyncio.to_thread(
self.db.add_history,
memory_id,
prev_value,
data,
"UPDATE",
created_at=new_metadata["created_at"],
updated_at=new_metadata["updated_at"],
)
capture_event("async_mem0._update_memory", self, {"memory_id": memory_id})
return memory_id
async def _delete_memory(self, memory_id):
logging.info(f"Deleting memory with {memory_id=}")
existing_memory = await asyncio.to_thread(self.vector_store.get, vector_id=memory_id)
prev_value = existing_memory.payload["data"]
await asyncio.to_thread(self.vector_store.delete, vector_id=memory_id)
await asyncio.to_thread(self.db.add_history, memory_id, prev_value, None, "DELETE", is_deleted=1)
capture_event("async_mem0._delete_memory", self, {"memory_id": memory_id})
return memory_id
async def reset(self):
"""
Reset the memory store asynchronously.
"""
logger.warning("Resetting all memories")
await asyncio.to_thread(self.vector_store.delete_col)
self.vector_store = VectorStoreFactory.create(
self.config.vector_store.provider, self.config.vector_store.config
)
await asyncio.to_thread(self.db.reset)
capture_event("async_mem0.reset", self)
async def chat(self, query):
raise NotImplementedError("Chat function not implemented yet.")