Files
t6_mem0/mem0/memory/main.py
2025-04-07 11:31:16 +05:30

744 lines
29 KiB
Python

import concurrent
import concurrent.futures
import hashlib
import json
import logging
import uuid
import warnings
from datetime import datetime
from typing import Any, Dict, Optional

import pytz
from pydantic import ValidationError

from mem0.configs.base import MemoryConfig, MemoryItem
from mem0.configs.enums import MemoryType
from mem0.configs.prompts import PROCEDURAL_MEMORY_SYSTEM_PROMPT, get_update_memory_messages
from mem0.memory.base import MemoryBase
from mem0.memory.setup import setup_config
from mem0.memory.storage import SQLiteManager
from mem0.memory.telemetry import capture_event
from mem0.memory.utils import get_fact_retrieval_messages, parse_messages, parse_vision_messages, remove_code_blocks
from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)

# Run mem0's user-config setup (mem0.memory.setup.setup_config) once, at import time.
setup_config()
class Memory(MemoryBase):
    """Memory layer combining an LLM, an embedder, a vector store, an SQLite
    history log and, when configured, a graph store.

    Memories are scoped by ``user_id`` / ``agent_id`` / ``run_id``, and every
    ADD / UPDATE / DELETE applied to the vector store is also recorded in the
    history database.
    """

    # Session identifiers that scope memories; surfaced as top-level keys in
    # the memory dicts returned to callers.
    _SESSION_KEYS = ("user_id", "agent_id", "run_id")
    # Payload keys that are surfaced as top-level fields or used internally,
    # and therefore never repeated under "metadata".
    _RESERVED_PAYLOAD_KEYS = frozenset(
        {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at", "id"}
    )

    def __init__(self, config: Optional[MemoryConfig] = None):
        """Create the embedder, vector store, LLM, history DB and optional graph store.

        Args:
            config (MemoryConfig, optional): Memory configuration. When omitted, a
                fresh ``MemoryConfig()`` is built for this instance. (A default of
                ``MemoryConfig()`` in the signature would be evaluated once at
                import time and shared by every ``Memory``.)
        """
        self.config = config if config is not None else MemoryConfig()

        self.custom_fact_extraction_prompt = self.config.custom_fact_extraction_prompt
        self.custom_update_memory_prompt = self.config.custom_update_memory_prompt
        self.embedding_model = EmbedderFactory.create(self.config.embedder.provider, self.config.embedder.config)
        self.vector_store = VectorStoreFactory.create(
            self.config.vector_store.provider, self.config.vector_store.config
        )
        self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config)
        self.db = SQLiteManager(self.config.history_db_path)
        self.collection_name = self.config.vector_store.config.collection_name
        self.api_version = self.config.version

        self.enable_graph = False
        if self.config.graph_store.config:
            # Imported only when a graph store is configured (presumably to keep
            # the graph dependencies optional).
            from mem0.memory.graph_memory import MemoryGraph

            self.graph = MemoryGraph(self.config)
            self.enable_graph = True

        capture_event("mem0.init", self)

    @classmethod
    def from_config(cls, config_dict: Dict[str, Any]):
        """Build a ``Memory`` from a plain configuration dict.

        Args:
            config_dict (dict): Raw configuration, preprocessed by ``_process_config``
                and then validated as a ``MemoryConfig``. It is not mutated.

        Returns:
            Memory: The configured instance.

        Raises:
            ValidationError: If the configuration does not validate.
        """
        try:
            processed = cls._process_config(config_dict)
            config = MemoryConfig(**processed)
        except ValidationError as e:
            logger.error(f"Configuration validation error: {e}")
            raise
        return cls(config)

    @staticmethod
    def _process_config(config_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``config_dict`` with derived defaults filled in.

        When a graph store and an embedder are configured but no vector store is,
        a vector-store config is derived whose ``embedding_model_dims`` matches the
        embedder's ``embedding_dims``. If the embedder config does not declare
        ``embedding_dims``, nothing is derived and the vector-store defaults apply.
        """
        processed = dict(config_dict)
        if "graph_store" in processed and "vector_store" not in processed and "embedder" in processed:
            embedder_config = (processed["embedder"] or {}).get("config") or {}
            embedding_dims = embedder_config.get("embedding_dims")
            if embedding_dims is not None:
                processed["vector_store"] = {"config": {"embedding_model_dims": embedding_dims}}
        return processed

    def add(
        self,
        messages,
        user_id=None,
        agent_id=None,
        run_id=None,
        metadata=None,
        filters=None,
        infer=True,
        memory_type=None,
        prompt=None,
        llm=None,
    ):
        """
        Create a new memory.

        Args:
            messages (str or List[Dict[str, str]]): Messages to store in the memory.
            user_id (str, optional): ID of the user creating the memory. Defaults to None.
            agent_id (str, optional): ID of the agent creating the memory. Defaults to None.
            run_id (str, optional): ID of the run creating the memory. Defaults to None.
            metadata (dict, optional): Metadata to store with the memory. Not mutated. Defaults to None.
            filters (dict, optional): Filters to apply to the search. Not mutated. Defaults to None.
            infer (bool, optional): Whether to infer the memories. Defaults to True.
            memory_type (str, optional): Type of memory to create. Defaults to None. By default, it creates the
                short term memories and long term (semantic and episodic) memories. Pass "procedural_memory" to
                create procedural memories (requires ``agent_id``; without it the call proceeds as a regular add).
            prompt (str, optional): Prompt to use for the memory creation. Defaults to None.
            llm (BaseChatModel, optional): LLM class to use for generating procedural memories. Defaults to None.
                Useful when user is using LangChain ChatModel.

        Returns:
            dict: A dictionary containing the result of the memory addition operation.
            result: dict of affected events with each dict has the following key:
              'memories': affected memories
              'graph': affected graph memories

              'memories' and 'graph' is a dict, each with following subkeys:
                'add': added memory
                'update': updated memory
                'delete': deleted memory

        Raises:
            ValueError: If no session id is given, or ``memory_type`` is invalid.
        """
        # Work on private copies: the caller's dicts must not be mutated, and the
        # vector-store and graph workers below run concurrently on these values.
        metadata = dict(metadata) if metadata else {}
        filters = dict(filters) if filters else {}

        if user_id:
            filters["user_id"] = metadata["user_id"] = user_id
        if agent_id:
            filters["agent_id"] = metadata["agent_id"] = agent_id
        if run_id:
            filters["run_id"] = metadata["run_id"] = run_id

        if not any(key in filters for key in self._SESSION_KEYS):
            raise ValueError("One of the filters: user_id, agent_id or run_id is required!")

        if memory_type is not None and memory_type != MemoryType.PROCEDURAL.value:
            raise ValueError(
                f"Invalid 'memory_type'. Please pass {MemoryType.PROCEDURAL.value} to create procedural memories."
            )

        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]

        if agent_id is not None and memory_type == MemoryType.PROCEDURAL.value:
            return self._create_procedural_memory(messages, metadata=metadata, llm=llm, prompt=prompt)

        if self.config.llm.config.get("enable_vision"):
            messages = parse_vision_messages(messages, self.llm, self.config.llm.config.get("vision_details"))
        else:
            messages = parse_vision_messages(messages)

        with concurrent.futures.ThreadPoolExecutor() as executor:
            vector_future = executor.submit(self._add_to_vector_store, messages, metadata, filters, infer)
            graph_future = executor.submit(self._add_to_graph, messages, filters)
            # result() blocks until each worker finishes and re-raises its error.
            vector_store_result = vector_future.result()
            graph_result = graph_future.result()

        if self.api_version == "v1.0":
            warnings.warn(
                "The current add API output format is deprecated. "
                "To use the latest format, set `api_version='v1.1'`. "
                "The current format will be removed in mem0ai 1.1.0 and later versions.",
                category=DeprecationWarning,
                stacklevel=2,
            )
            return vector_store_result

        if self.enable_graph:
            return {
                "results": vector_store_result,
                "relations": graph_result,
            }

        return {"results": vector_store_result}

    def _add_to_vector_store(self, messages, metadata, filters, infer):
        """Store ``messages`` in the vector store and return the applied events.

        With ``infer=False`` every non-system message is stored verbatim as an ADD.
        Otherwise the LLM extracts facts from the conversation, up to 5 existing
        memories per fact are retrieved within ``filters``, and the LLM decides per
        entry whether to ADD, UPDATE, DELETE or do nothing (NONE).

        Args:
            messages (List[Dict[str, str]]): Chat messages (already vision-parsed).
            metadata (dict): Base payload for new/updated memories; not mutated.
            filters (dict): Session filters used when searching existing memories.
            infer (bool): Whether to run LLM fact extraction and reconciliation.

        Returns:
            list: One dict per applied event with ``id``, ``memory`` and ``event``
            keys; UPDATE entries also carry ``previous_memory``.
        """
        if not infer:
            returned_memories = []
            for message in messages:
                if message["role"] == "system":
                    continue
                content = message["content"]
                embedding = self.embedding_model.embed(content, "add")
                # _create_memory expects a {text: embedding} mapping; passing the bare
                # vector made its lookup miss, so the text was embedded a second time.
                memory_id = self._create_memory(content, {content: embedding}, metadata)
                returned_memories.append({"id": memory_id, "memory": content, "event": "ADD"})
            return returned_memories

        parsed_messages = parse_messages(messages)

        if self.custom_fact_extraction_prompt:
            system_prompt = self.custom_fact_extraction_prompt
            user_prompt = f"Input:\n{parsed_messages}"
        else:
            system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages)

        response = self.llm.generate_response(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            response_format={"type": "json_object"},
        )

        try:
            response = remove_code_blocks(response)
            new_retrieved_facts = json.loads(response)["facts"]
        except Exception as e:
            logger.error(f"Error in new_retrieved_facts: {e}")
            new_retrieved_facts = []

        # Existing memories keyed by id: de-duplicated, first-seen order, last value wins.
        existing_by_id = {}
        new_message_embeddings = {}
        for new_mem in new_retrieved_facts:
            messages_embeddings = self.embedding_model.embed(new_mem, "add")
            new_message_embeddings[new_mem] = messages_embeddings
            existing_memories = self.vector_store.search(
                query=new_mem,
                vectors=messages_embeddings,
                limit=5,
                filters=filters,
            )
            for mem in existing_memories:
                existing_by_id[mem.id] = {"id": mem.id, "text": mem.payload["data"]}

        retrieved_old_memory = list(existing_by_id.values())
        logger.info(f"Total existing memories: {len(retrieved_old_memory)}")

        # Show the LLM small integer ids instead of UUIDs to guard against UUID
        # hallucinations; ids it invents simply miss this mapping.
        temp_uuid_mapping = {}
        for idx, item in enumerate(retrieved_old_memory):
            temp_uuid_mapping[str(idx)] = item["id"]
            item["id"] = str(idx)

        function_calling_prompt = get_update_memory_messages(
            retrieved_old_memory, new_retrieved_facts, self.custom_update_memory_prompt
        )
        new_memories_with_actions = self._generate_memory_actions(function_calling_prompt)

        memory_actions = new_memories_with_actions.get("memory") or []
        if not isinstance(memory_actions, list):
            logger.error(f"Error in new_memories_with_actions: expected a list, got {type(memory_actions).__name__}")
            memory_actions = []

        returned_memories = []
        for resp in memory_actions:
            logger.info(resp)
            # Each action is applied independently: one bad entry (e.g. an unknown
            # id) is logged and skipped without aborting the remaining actions.
            try:
                text = resp.get("text")
                event = resp.get("event")
                if not text:
                    logger.info("Skipping memory entry because of empty `text` field.")
                    continue
                if event == "ADD":
                    memory_id = self._create_memory(
                        data=text, existing_embeddings=new_message_embeddings, metadata=metadata
                    )
                    returned_memories.append({"id": memory_id, "memory": text, "event": event})
                elif event == "UPDATE":
                    memory_id = temp_uuid_mapping[resp.get("id")]
                    self._update_memory(
                        memory_id=memory_id,
                        data=text,
                        existing_embeddings=new_message_embeddings,
                        metadata=metadata,
                    )
                    returned_memories.append(
                        {
                            "id": memory_id,
                            "memory": text,
                            "event": event,
                            "previous_memory": resp.get("old_memory"),
                        }
                    )
                elif event == "DELETE":
                    memory_id = temp_uuid_mapping[resp.get("id")]
                    self._delete_memory(memory_id=memory_id)
                    returned_memories.append({"id": memory_id, "memory": text, "event": event})
                elif event == "NONE":
                    logger.info("NOOP for Memory.")
            except Exception as e:
                logger.error(f"Error in new_memories_with_actions: {e}")

        capture_event("mem0.add", self, {"version": self.api_version, "keys": list(filters.keys())})
        return returned_memories

    def _generate_memory_actions(self, prompt):
        """Ask the LLM for memory actions; return the parsed JSON object, or ``{}`` on any failure."""
        try:
            response = self.llm.generate_response(
                messages=[{"role": "user", "content": prompt}],
                response_format={"type": "json_object"},
            )
        except Exception as e:
            logger.error(f"Error in new_memories_with_actions: {e}")
            return {}

        try:
            actions = json.loads(remove_code_blocks(response))
        except Exception as e:
            logger.error(f"Invalid JSON response: {e}")
            return {}

        if not isinstance(actions, dict):
            logger.error(f"Invalid JSON response: expected an object, got {type(actions).__name__}")
            return {}
        return actions

    def _add_to_graph(self, messages, filters):
        """Add the non-system message contents to the graph store, if enabled.

        Args:
            messages (List[Dict[str, str]]): Chat messages.
            filters (dict): Session filters; copied, never mutated (it is shared with
                the concurrently running vector-store worker).

        Returns:
            The graph store's result, or ``[]`` when the graph store is disabled.
        """
        if not self.enable_graph:
            return []

        graph_filters = dict(filters)
        if graph_filters.get("user_id") is None:
            graph_filters["user_id"] = "user"

        data = "\n".join(msg["content"] for msg in messages if "content" in msg and msg["role"] != "system")
        return self.graph.add(data, graph_filters)

    def get(self, memory_id):
        """
        Retrieve a memory by ID.

        Args:
            memory_id (str): ID of the memory to retrieve.

        Returns:
            dict: Retrieved memory, or None if the vector store returns nothing.
        """
        capture_event("mem0.get", self, {"memory_id": memory_id})
        memory = self.vector_store.get(vector_id=memory_id)
        if not memory:
            return None

        payload = memory.payload
        filters = {key: payload[key] for key in self._SESSION_KEYS if payload.get(key)}

        memory_item = MemoryItem(
            id=memory.id,
            memory=payload["data"],
            hash=payload.get("hash"),
            created_at=payload.get("created_at"),
            updated_at=payload.get("updated_at"),
        ).model_dump(exclude={"score"})

        # Any payload keys beyond the reserved ones are exposed as "metadata".
        additional_metadata = {k: v for k, v in payload.items() if k not in self._RESERVED_PAYLOAD_KEYS}
        if additional_metadata:
            memory_item["metadata"] = additional_metadata

        return {**memory_item, **filters}

    def _format_memory_record(self, mem, include_score=False):
        """Convert a vector-store record into the public memory dict used by get_all/search."""
        payload = mem.payload
        fields = dict(
            id=mem.id,
            memory=payload["data"],
            hash=payload.get("hash"),
            created_at=payload.get("created_at"),
            updated_at=payload.get("updated_at"),
        )
        if include_score:
            item = MemoryItem(**fields, score=mem.score).model_dump()
        else:
            item = MemoryItem(**fields).model_dump(exclude={"score"})

        item.update({key: payload[key] for key in self._SESSION_KEYS if key in payload})
        extra = {k: v for k, v in payload.items() if k not in self._RESERVED_PAYLOAD_KEYS}
        if extra:
            item["metadata"] = extra
        return item

    def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
        """
        List all memories.

        Args:
            user_id (str, optional): Restrict to this user. Defaults to None.
            agent_id (str, optional): Restrict to this agent. Defaults to None.
            run_id (str, optional): Restrict to this run. Defaults to None.
            limit (int, optional): Maximum number of memories to list. Defaults to 100.

        Returns:
            list: List of all memories.
        """
        filters = {}
        if user_id:
            filters["user_id"] = user_id
        if agent_id:
            filters["agent_id"] = agent_id
        if run_id:
            filters["run_id"] = run_id

        capture_event("mem0.get_all", self, {"limit": limit, "keys": list(filters.keys())})

        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_memories = executor.submit(self._get_all_from_vector_store, filters, limit)
            future_graph_entities = executor.submit(self.graph.get_all, filters, limit) if self.enable_graph else None
            all_memories = future_memories.result()
            graph_entities = future_graph_entities.result() if future_graph_entities else None

        if self.enable_graph:
            return {"results": all_memories, "relations": graph_entities}

        if self.api_version == "v1.0":
            warnings.warn(
                "The current get_all API output format is deprecated. "
                "To use the latest format, set `api_version='v1.1'`. "
                "The current format will be removed in mem0ai 1.1.0 and later versions.",
                category=DeprecationWarning,
                stacklevel=2,
            )
            return all_memories
        return {"results": all_memories}

    def _get_all_from_vector_store(self, filters, limit):
        """List up to ``limit`` memories matching ``filters`` as public memory dicts."""
        memories = self.vector_store.list(filters=filters, limit=limit)
        # list() returns a sequence whose first element holds the records.
        return [self._format_memory_record(mem) for mem in memories[0]]

    def search(self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None):
        """
        Search for memories.

        Args:
            query (str): Query to search for.
            user_id (str, optional): ID of the user to search for. Defaults to None.
            agent_id (str, optional): ID of the agent to search for. Defaults to None.
            run_id (str, optional): ID of the run to search for. Defaults to None.
            limit (int, optional): Limit the number of results. Defaults to 100.
            filters (dict, optional): Filters to apply to the search. Not mutated. Defaults to None.

        Returns:
            list: List of search results.

        Raises:
            ValueError: If none of user_id, agent_id or run_id is provided.
        """
        filters = dict(filters) if filters else {}
        if user_id:
            filters["user_id"] = user_id
        if agent_id:
            filters["agent_id"] = agent_id
        if run_id:
            filters["run_id"] = run_id

        if not any(key in filters for key in self._SESSION_KEYS):
            raise ValueError("One of the filters: user_id, agent_id or run_id is required!")

        capture_event(
            "mem0.search",
            self,
            {"limit": limit, "version": self.api_version, "keys": list(filters.keys())},
        )

        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_memories = executor.submit(self._search_vector_store, query, filters, limit)
            future_graph_entities = (
                executor.submit(self.graph.search, query, filters, limit) if self.enable_graph else None
            )
            original_memories = future_memories.result()
            graph_entities = future_graph_entities.result() if future_graph_entities else None

        if self.enable_graph:
            return {"results": original_memories, "relations": graph_entities}

        if self.api_version == "v1.0":
            warnings.warn(
                "The current search API output format is deprecated. "
                "To use the latest format, set `api_version='v1.1'`. "
                "The current format will be removed in mem0ai 1.1.0 and later versions.",
                category=DeprecationWarning,
                stacklevel=2,
            )
            return original_memories
        return {"results": original_memories}

    def _search_vector_store(self, query, filters, limit):
        """Embed ``query`` and return up to ``limit`` scored matches as public memory dicts."""
        embeddings = self.embedding_model.embed(query, "search")
        memories = self.vector_store.search(query=query, vectors=embeddings, limit=limit, filters=filters)
        return [self._format_memory_record(mem, include_score=True) for mem in memories]

    def update(self, memory_id, data):
        """
        Update a memory by ID.

        Args:
            memory_id (str): ID of the memory to update.
            data (str): New text of the memory.

        Returns:
            dict: Success message.

        Raises:
            ValueError: If the memory cannot be fetched.
        """
        capture_event("mem0.update", self, {"memory_id": memory_id})
        existing_embeddings = {data: self.embedding_model.embed(data, "update")}
        self._update_memory(memory_id, data, existing_embeddings)
        return {"message": "Memory updated successfully!"}

    def delete(self, memory_id):
        """
        Delete a memory by ID.

        Args:
            memory_id (str): ID of the memory to delete.

        Returns:
            dict: Success message.
        """
        capture_event("mem0.delete", self, {"memory_id": memory_id})
        self._delete_memory(memory_id)
        return {"message": "Memory deleted successfully!"}

    def delete_all(self, user_id=None, agent_id=None, run_id=None):
        """
        Delete all memories matching the given session ids.

        Args:
            user_id (str, optional): ID of the user to delete memories for. Defaults to None.
            agent_id (str, optional): ID of the agent to delete memories for. Defaults to None.
            run_id (str, optional): ID of the run to delete memories for. Defaults to None.

        Returns:
            dict: Success message.

        Raises:
            ValueError: If no id is given (use ``reset()`` to wipe everything).
        """
        filters = {}
        if user_id:
            filters["user_id"] = user_id
        if agent_id:
            filters["agent_id"] = agent_id
        if run_id:
            filters["run_id"] = run_id

        if not filters:
            raise ValueError(
                "At least one filter is required to delete all memories. If you want to delete all memories, use the `reset()` method."
            )

        capture_event("mem0.delete_all", self, {"keys": list(filters.keys())})
        memories = self.vector_store.list(filters=filters)[0]
        for memory in memories:
            self._delete_memory(memory.id)

        logger.info(f"Deleted {len(memories)} memories")

        if self.enable_graph:
            self.graph.delete_all(filters)

        return {"message": "Memories deleted successfully!"}

    def history(self, memory_id):
        """
        Get the history of changes for a memory by ID.

        Args:
            memory_id (str): ID of the memory to get history for.

        Returns:
            list: List of changes for the memory.
        """
        capture_event("mem0.history", self, {"memory_id": memory_id})
        return self.db.get_history(memory_id)

    def _create_memory(self, data, existing_embeddings, metadata=None):
        """Insert a new memory and record an ADD in the history DB.

        Args:
            data (str): Memory text.
            existing_embeddings (dict): ``{text: embedding}`` cache; ``data`` is only
                embedded if it is missing here.
            metadata (dict, optional): Base payload; copied, never mutated, so one dict
                can be reused across several creates without leaking keys.

        Returns:
            str: The new memory's UUID.
        """
        logger.debug(f"Creating memory with {data=}")
        if data in existing_embeddings:
            embeddings = existing_embeddings[data]
        else:
            embeddings = self.embedding_model.embed(data, memory_action="add")

        memory_id = str(uuid.uuid4())
        payload = dict(metadata) if metadata else {}
        payload["data"] = data
        payload["hash"] = hashlib.md5(data.encode()).hexdigest()
        payload["created_at"] = datetime.now(pytz.timezone("US/Pacific")).isoformat()

        self.vector_store.insert(
            vectors=[embeddings],
            ids=[memory_id],
            payloads=[payload],
        )
        self.db.add_history(memory_id, None, data, "ADD", created_at=payload["created_at"])
        capture_event("mem0._create_memory", self, {"memory_id": memory_id})
        return memory_id

    def _create_procedural_memory(self, messages, metadata=None, llm=None, prompt=None):
        """
        Create a procedural memory.

        Args:
            messages (list): List of messages to create a procedural memory from.
            metadata (dict): Metadata to store with the procedural memory; required, not mutated.
            llm (BaseChatModel, optional): LLM class to use for generating procedural memories. Defaults to None.
                Useful when user is using LangChain ChatModel.
            prompt (str, optional): Prompt to use for the procedural memory creation. Defaults to None.

        Returns:
            dict: ``{"results": [{"id", "memory", "event": "ADD"}]}``, matching ``add()``.

        Raises:
            ValueError: If ``metadata`` is None (checked before any LLM call).
        """
        try:
            from langchain_core.messages.utils import convert_to_messages  # type: ignore
        except Exception:
            logger.error(
                "Import error while loading langchain-core. Please install 'langchain-core' to use procedural memory."
            )
            raise

        # Validate up front so a bad call does not spend an LLM request first.
        if metadata is None:
            raise ValueError("Metadata cannot be None for procedural memory.")

        logger.info("Creating procedural memory")

        parsed_messages = [
            {"role": "system", "content": prompt or PROCEDURAL_MEMORY_SYSTEM_PROMPT},
            *messages,
            {"role": "user", "content": "Create procedural memory of the above conversation."},
        ]

        try:
            if llm is not None:
                response = llm.invoke(input=convert_to_messages(parsed_messages))
                procedural_memory = response.content
            else:
                procedural_memory = self.llm.generate_response(messages=parsed_messages)
        except Exception as e:
            logger.error(f"Error generating procedural memory summary: {e}")
            raise

        procedural_metadata = dict(metadata)
        procedural_metadata["memory_type"] = MemoryType.PROCEDURAL.value

        embeddings = self.embedding_model.embed(procedural_memory, memory_action="add")
        memory_id = self._create_memory(
            procedural_memory, {procedural_memory: embeddings}, metadata=procedural_metadata
        )
        capture_event("mem0._create_procedural_memory", self, {"memory_id": memory_id})

        return {"results": [{"id": memory_id, "memory": procedural_memory, "event": "ADD"}]}

    def _update_memory(self, memory_id, data, existing_embeddings, metadata=None):
        """Replace a memory's text and payload, and record an UPDATE in the history DB.

        The new payload is built from ``metadata`` (copied, never mutated) plus the
        new text/hash, the original ``created_at``, a fresh ``updated_at`` and the
        existing session ids. Other keys of the old payload are not carried over.

        Returns:
            str: ``memory_id``.

        Raises:
            ValueError: If the memory cannot be fetched or does not exist.
        """
        logger.info(f"Updating memory with {data=}")

        try:
            existing_memory = self.vector_store.get(vector_id=memory_id)
        except Exception as e:
            raise ValueError(f"Error getting memory with ID {memory_id}. Please provide a valid 'memory_id'") from e
        if not existing_memory:
            raise ValueError(f"Error getting memory with ID {memory_id}. Please provide a valid 'memory_id'")

        old_payload = existing_memory.payload
        prev_value = old_payload.get("data")

        new_metadata = dict(metadata) if metadata else {}
        new_metadata["data"] = data
        new_metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
        new_metadata["created_at"] = old_payload.get("created_at")
        new_metadata["updated_at"] = datetime.now(pytz.timezone("US/Pacific")).isoformat()
        for key in self._SESSION_KEYS:
            if key in old_payload:
                new_metadata[key] = old_payload[key]

        if data in existing_embeddings:
            embeddings = existing_embeddings[data]
        else:
            embeddings = self.embedding_model.embed(data, "update")

        self.vector_store.update(
            vector_id=memory_id,
            vector=embeddings,
            payload=new_metadata,
        )
        logger.info(f"Updating memory with ID {memory_id=} with {data=}")
        self.db.add_history(
            memory_id,
            prev_value,
            data,
            "UPDATE",
            created_at=new_metadata["created_at"],
            updated_at=new_metadata["updated_at"],
        )
        capture_event("mem0._update_memory", self, {"memory_id": memory_id})
        return memory_id

    def _delete_memory(self, memory_id):
        """Delete a memory from the vector store and record a DELETE in the history DB.

        Returns:
            str: ``memory_id``.

        Raises:
            ValueError: If the vector store returns no memory for ``memory_id``.
        """
        logger.info(f"Deleting memory with {memory_id=}")
        existing_memory = self.vector_store.get(vector_id=memory_id)
        if not existing_memory:
            raise ValueError(f"Memory with ID {memory_id} not found. Please provide a valid 'memory_id'")
        prev_value = existing_memory.payload["data"]
        self.vector_store.delete(vector_id=memory_id)
        self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1)
        capture_event("mem0._delete_memory", self, {"memory_id": memory_id})
        return memory_id

    def reset(self):
        """
        Reset the memory store.

        Deletes the vector-store collection, recreates the vector store from config
        and resets the history DB. The graph store, if any, is not touched here.
        """
        logger.warning("Resetting all memories")
        self.vector_store.delete_col()
        self.vector_store = VectorStoreFactory.create(
            self.config.vector_store.provider, self.config.vector_store.config
        )
        self.db.reset()
        capture_event("mem0.reset", self)

    def chat(self, query):
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError("Chat function not implemented yet.")