# t6_mem0/mem0/memory/main.py
import logging
import hashlib
import uuid
import pytz
from datetime import datetime
from typing import Any, Dict
import warnings
from pydantic import ValidationError
from mem0.llms.utils.tools import (
ADD_MEMORY_TOOL,
DELETE_MEMORY_TOOL,
UPDATE_MEMORY_TOOL,
)
from mem0.configs.prompts import MEMORY_DEDUCTION_PROMPT
from mem0.memory.base import MemoryBase
from mem0.memory.setup import setup_config
from mem0.memory.storage import SQLiteManager
from mem0.memory.telemetry import capture_event
from mem0.memory.utils import get_update_memory_messages
from mem0.utils.factory import LlmFactory, EmbedderFactory, VectorStoreFactory
from mem0.configs.base import MemoryItem, MemoryConfig

# Setup user config
setup_config()

class Memory(MemoryBase):
    def __init__(self, config: MemoryConfig = MemoryConfig()):
        self.config = config
        self.embedding_model = EmbedderFactory.create(
            self.config.embedder.provider, self.config.embedder.config
        )
        self.vector_store = VectorStoreFactory.create(
            self.config.vector_store.provider, self.config.vector_store.config
        )
        self.llm = LlmFactory.create(self.config.llm.provider, self.config.llm.config)
        self.db = SQLiteManager(self.config.history_db_path)
        self.collection_name = self.config.vector_store.config.collection_name
        self.version = self.config.version
        self.enable_graph = False

        if self.version == "v1.1" and self.config.graph_store.config:
            from mem0.memory.main_graph import MemoryGraph

            self.graph = MemoryGraph(self.config)
            self.enable_graph = True

        capture_event("mem0.init", self)

    @classmethod
    def from_config(cls, config_dict: Dict[str, Any]):
        try:
            config = MemoryConfig(**config_dict)
        except ValidationError as e:
            logging.error(f"Configuration validation error: {e}")
            raise
        return cls(config)

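    # A from_config sketch (illustrative only; the provider names and options
    # below are assumptions -- see MemoryConfig for the exact schema accepted
    # by the installed version):
    #
    #     config = {
    #         "llm": {"provider": "openai", "config": {"model": "gpt-4o-mini"}},
    #         "vector_store": {"provider": "qdrant", "config": {}},
    #     }
    #     m = Memory.from_config(config)
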
    def add(
        self,
        data,
        user_id=None,
        agent_id=None,
        run_id=None,
        metadata=None,
        filters=None,
        prompt=None,
    ):
        """
        Create a new memory.

        Args:
            data (str): Data to store in the memory.
            user_id (str, optional): ID of the user creating the memory. Defaults to None.
            agent_id (str, optional): ID of the agent creating the memory. Defaults to None.
            run_id (str, optional): ID of the run creating the memory. Defaults to None.
            metadata (dict, optional): Metadata to store with the memory. Defaults to None.
            filters (dict, optional): Filters to apply to the search. Defaults to None.
            prompt (str, optional): Prompt to use for memory deduction. Defaults to None.

        Returns:
            dict: Memory addition operation message.
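
        Example:
            A minimal usage sketch (assumes the default providers are
            configured, e.g. an OpenAI API key in the environment):

                m = Memory()
                m.add("I recently started playing tennis", user_id="alice")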
"""
if metadata is None:
metadata = {}
embeddings = self.embedding_model.embed(data)
filters = filters or {}
if user_id:
filters["user_id"] = metadata["user_id"] = user_id
if agent_id:
filters["agent_id"] = metadata["agent_id"] = agent_id
if run_id:
filters["run_id"] = metadata["run_id"] = run_id
if not any(key in filters for key in ("user_id", "agent_id", "run_id")):
raise ValueError(
"One of the filters: user_id, agent_id or run_id is required!"
)
if not prompt:
prompt = MEMORY_DEDUCTION_PROMPT.format(user_input=data, metadata=metadata)
extracted_memories = self.llm.generate_response(
messages=[
{
"role": "system",
"content": "You are an expert at deducing facts, preferences and memories from unstructured text.",
},
{"role": "user", "content": prompt},
]
)
existing_memories = self.vector_store.search(
query=embeddings,
limit=5,
filters=filters,
)
existing_memories = [
MemoryItem(
id=mem.id,
score=mem.score,
metadata=mem.payload,
memory=mem.payload["data"],
)
for mem in existing_memories
]
serialized_existing_memories = [
item.model_dump(include={"id", "memory", "score"})
for item in existing_memories
]
logging.info(f"Total existing memories: {len(existing_memories)}")
messages = get_update_memory_messages(
serialized_existing_memories, extracted_memories
)
# Add tools for noop, add, update, delete memory.
tools = [ADD_MEMORY_TOOL, UPDATE_MEMORY_TOOL, DELETE_MEMORY_TOOL]
response = self.llm.generate_response(messages=messages, tools=tools)
tool_calls = response["tool_calls"]
response = []
if tool_calls:
# Create a new memory
available_functions = {
"add_memory": self._create_memory_tool,
"update_memory": self._update_memory_tool,
"delete_memory": self._delete_memory_tool,
}
for tool_call in tool_calls:
function_name = tool_call["name"]
function_to_call = available_functions[function_name]
function_args = tool_call["arguments"]
logging.info(
f"[openai_func] func: {function_name}, args: {function_args}"
)
# Pass metadata to the function if it requires it
if function_name in ["add_memory", "update_memory"]:
function_args["metadata"] = metadata
function_result = function_to_call(**function_args)
# Fetch the memory_id from the response
response.append(
{
"id": function_result,
"event": function_name.replace("_memory", ""),
"data": function_args.get("data"),
}
)
capture_event(
"mem0.add.function_call",
self,
{"memory_id": function_result, "function_name": function_name},
)
capture_event("mem0.add", self)
if self.version == "v1.1" and self.enable_graph:
if user_id:
self.graph.user_id = user_id
else:
self.graph.user_id = "USER"
added_entities = self.graph.add(data, filters)
return {"message": "ok"}
    def get(self, memory_id):
        """
        Retrieve a memory by ID.

        Args:
            memory_id (str): ID of the memory to retrieve.

        Returns:
            dict: Retrieved memory.
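
        Example:
            A minimal sketch, assuming `m` is a configured Memory instance and
            `memory_id` refers to an existing memory:

                item = m.get(memory_id)
                if item is not None:
                    print(item["memory"])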
"""
capture_event("mem0.get", self, {"memory_id": memory_id})
memory = self.vector_store.get(vector_id=memory_id)
if not memory:
return None
filters = {
key: memory.payload[key]
for key in ["user_id", "agent_id", "run_id"]
if memory.payload.get(key)
}
# Prepare base memory item
memory_item = MemoryItem(
id=memory.id,
memory=memory.payload["data"],
hash=memory.payload.get("hash"),
created_at=memory.payload.get("created_at"),
updated_at=memory.payload.get("updated_at"),
).model_dump(exclude={"score"})
# Add metadata if there are additional keys
excluded_keys = {
"user_id",
"agent_id",
"run_id",
"hash",
"data",
"created_at",
"updated_at",
}
additional_metadata = {
k: v for k, v in memory.payload.items() if k not in excluded_keys
}
if additional_metadata:
memory_item["metadata"] = additional_metadata
result = {**memory_item, **filters}
return result
    def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
        """
        List all memories, optionally scoped by user_id, agent_id, or run_id.

        Returns:
            list: List of all memories. With version "v1.1", a dict of the form
            {"memories": [...]} is returned instead (plus "entities" when the
            graph store is enabled).
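
        Example:
            A minimal sketch, assuming `m` is a configured Memory instance with
            memories stored for user "alice":

                results = m.get_all(user_id="alice", limit=10)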
"""
filters = {}
if user_id:
filters["user_id"] = user_id
if agent_id:
filters["agent_id"] = agent_id
if run_id:
filters["run_id"] = run_id
capture_event("mem0.get_all", self, {"filters": len(filters), "limit": limit})
memories = self.vector_store.list(filters=filters, limit=limit)
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at"}
all_memories = [
{
**MemoryItem(
id=mem.id,
memory=mem.payload["data"],
hash=mem.payload.get("hash"),
created_at=mem.payload.get("created_at"),
updated_at=mem.payload.get("updated_at"),
).model_dump(exclude={"score"}),
**{
key: mem.payload[key]
for key in ["user_id", "agent_id", "run_id"]
if key in mem.payload
},
**(
{
"metadata": {
k: v
for k, v in mem.payload.items()
if k not in excluded_keys
}
}
if any(k for k in mem.payload if k not in excluded_keys)
else {}
),
}
for mem in memories[0]
]
if self.version == "v1.1":
if self.enable_graph:
graph_entities = self.graph.get_all(filters)
return {"memories": all_memories, "entities": graph_entities}
else:
return {"memories" : all_memories}
else:
warnings.warn(
"The current get_all API output format is deprecated. "
"To use the latest format, set `api_version='v1.1'`. "
"The current format will be removed in mem0ai 1.1.0 and later versions.",
category=DeprecationWarning,
stacklevel=2
)
return all_memories
    def search(
        self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None
    ):
        """
        Search for memories.

        Args:
            query (str): Query to search for.
            user_id (str, optional): ID of the user to search for. Defaults to None.
            agent_id (str, optional): ID of the agent to search for. Defaults to None.
            run_id (str, optional): ID of the run to search for. Defaults to None.
            limit (int, optional): Limit the number of results. Defaults to 100.
            filters (dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            list: List of search results.
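
        Example:
            A minimal sketch, assuming `m` is a configured Memory instance
            (scores come from the underlying vector store):

                hits = m.search("what sports do I play?", user_id="alice", limit=5)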
"""
filters = filters or {}
if user_id:
filters["user_id"] = user_id
if agent_id:
filters["agent_id"] = agent_id
if run_id:
filters["run_id"] = run_id
if not any(key in filters for key in ("user_id", "agent_id", "run_id")):
raise ValueError(
"One of the filters: user_id, agent_id or run_id is required!"
)
capture_event("mem0.search", self, {"filters": len(filters), "limit": limit, "version": self.version})
embeddings = self.embedding_model.embed(query)
memories = self.vector_store.search(
query=embeddings, limit=limit, filters=filters
)
excluded_keys = {
"user_id",
"agent_id",
"run_id",
"hash",
"data",
"created_at",
"updated_at",
}
original_memories = [
{
**MemoryItem(
id=mem.id,
memory=mem.payload["data"],
hash=mem.payload.get("hash"),
created_at=mem.payload.get("created_at"),
updated_at=mem.payload.get("updated_at"),
score=mem.score,
).model_dump(),
**{
key: mem.payload[key]
for key in ["user_id", "agent_id", "run_id"]
if key in mem.payload
},
**(
{
"metadata": {
k: v
for k, v in mem.payload.items()
if k not in excluded_keys
}
}
if any(k for k in mem.payload if k not in excluded_keys)
else {}
),
}
for mem in memories
]
if self.version == "v1.1":
if self.enable_graph:
graph_entities = self.graph.search(query, filters)
return {"memories": original_memories, "entities": graph_entities}
else:
return {"memories" : original_memories}
else:
warnings.warn(
"The current get_all API output format is deprecated. "
"To use the latest format, set `api_version='v1.1'`. "
"The current format will be removed in mem0ai 1.1.0 and later versions.",
category=DeprecationWarning,
stacklevel=2
)
return original_memories
    def update(self, memory_id, data):
        """
        Update a memory by ID.

        Args:
            memory_id (str): ID of the memory to update.
            data (str): New data to store in the memory.

        Returns:
            dict: Update operation message.
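
        Example:
            A minimal sketch, assuming `m` is a configured Memory instance and
            `memory_id` refers to an existing memory:

                m.update(memory_id, data="Now plays tennis twice a week")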
"""
capture_event("mem0.update", self, {"memory_id": memory_id})
self._update_memory_tool(memory_id, data)
return {"message": "Memory updated successfully!"}
    def delete(self, memory_id):
        """
        Delete a memory by ID.

        Args:
            memory_id (str): ID of the memory to delete.

        Returns:
            dict: Delete operation message.
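
        Example:
            A minimal sketch, assuming `m` is a configured Memory instance:

                m.delete(memory_id)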
"""
capture_event("mem0.delete", self, {"memory_id": memory_id})
self._delete_memory_tool(memory_id)
return {"message": "Memory deleted successfully!"}
    def delete_all(self, user_id=None, agent_id=None, run_id=None):
        """
        Delete all memories.

        Args:
            user_id (str, optional): ID of the user to delete memories for. Defaults to None.
            agent_id (str, optional): ID of the agent to delete memories for. Defaults to None.
            run_id (str, optional): ID of the run to delete memories for. Defaults to None.
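
        Example:
            A minimal sketch; at least one of the ID filters is required:

                m.delete_all(user_id="alice")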
"""
filters = {}
if user_id:
filters["user_id"] = user_id
if agent_id:
filters["agent_id"] = agent_id
if run_id:
filters["run_id"] = run_id
if not filters:
raise ValueError(
"At least one filter is required to delete all memories. If you want to delete all memories, use the `reset()` method."
)
capture_event("mem0.delete_all", self, {"filters": len(filters)})
memories = self.vector_store.list(filters=filters)[0]
for memory in memories:
self._delete_memory_tool(memory.id)
if self.version == "v1.1" and self.enable_graph:
self.graph.delete_all(filters)
return {'message': 'Memories deleted successfully!'}
    def history(self, memory_id):
        """
        Get the history of changes for a memory by ID.

        Args:
            memory_id (str): ID of the memory to get history for.

        Returns:
            list: List of changes for the memory.
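
        Example:
            A minimal sketch; each entry records an ADD, UPDATE, or DELETE event:

                changes = m.history(memory_id)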
"""
capture_event("mem0.history", self, {"memory_id": memory_id})
return self.db.get_history(memory_id)
    def _create_memory_tool(self, data, metadata=None):
        logging.info(f"Creating memory with {data=}")
        embeddings = self.embedding_model.embed(data)
        memory_id = str(uuid.uuid4())
        metadata = metadata or {}
        metadata["data"] = data
        metadata["hash"] = hashlib.md5(data.encode()).hexdigest()
        metadata["created_at"] = datetime.now(pytz.timezone("US/Pacific")).isoformat()

        self.vector_store.insert(
            vectors=[embeddings],
            ids=[memory_id],
            payloads=[metadata],
        )
        self.db.add_history(
            memory_id, None, data, "ADD", created_at=metadata["created_at"]
        )
        return memory_id

    def _update_memory_tool(self, memory_id, data, metadata=None):
        existing_memory = self.vector_store.get(vector_id=memory_id)
        prev_value = existing_memory.payload.get("data")

        new_metadata = metadata or {}
        new_metadata["data"] = data
        new_metadata["hash"] = existing_memory.payload.get("hash")
        new_metadata["created_at"] = existing_memory.payload.get("created_at")
        new_metadata["updated_at"] = datetime.now(
            pytz.timezone("US/Pacific")
        ).isoformat()

        if "user_id" in existing_memory.payload:
            new_metadata["user_id"] = existing_memory.payload["user_id"]
        if "agent_id" in existing_memory.payload:
            new_metadata["agent_id"] = existing_memory.payload["agent_id"]
        if "run_id" in existing_memory.payload:
            new_metadata["run_id"] = existing_memory.payload["run_id"]

        embeddings = self.embedding_model.embed(data)
        self.vector_store.update(
            vector_id=memory_id,
            vector=embeddings,
            payload=new_metadata,
        )
        logging.info(f"Updating memory with ID {memory_id=} with {data=}")
        self.db.add_history(
            memory_id,
            prev_value,
            data,
            "UPDATE",
            created_at=new_metadata["created_at"],
            updated_at=new_metadata["updated_at"],
        )

    def _delete_memory_tool(self, memory_id):
        logging.info(f"Deleting memory with {memory_id=}")
        existing_memory = self.vector_store.get(vector_id=memory_id)
        prev_value = existing_memory.payload["data"]
        self.vector_store.delete(vector_id=memory_id)
        self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1)

    def reset(self):
        """
        Reset the memory store.
        """
        self.vector_store.delete_col()
        self.db.reset()
        capture_event("mem0.reset", self)

    def chat(self, query):
        raise NotImplementedError("Chat function not implemented yet.")
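
# A minimal end-to-end sketch (illustrative only): running it requires the
# default providers to be configured (e.g. a valid OpenAI API key in the
# environment) and it will make real LLM and embedding calls.
if __name__ == "__main__":
    m = Memory()
    m.add("I recently started playing tennis on weekends", user_id="alice")

    results = m.search("what sports do I play?", user_id="alice", limit=3)
    # Depending on the configured API version, search returns either a plain
    # list or a dict with a "memories" key.
    hits = results["memories"] if isinstance(results, dict) else results
    for hit in hits:
        print(hit["memory"], hit.get("score"))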