Add support for procedural memory (#2460)
This commit is contained in:
@@ -11,17 +11,15 @@ import pytz
|
||||
from pydantic import ValidationError
|
||||
|
||||
from mem0.configs.base import MemoryConfig, MemoryItem
|
||||
from mem0.configs.prompts import get_update_memory_messages
|
||||
from mem0.configs.enums import MemoryType
|
||||
from mem0.configs.prompts import (PROCEDURAL_MEMORY_SYSTEM_PROMPT,
|
||||
get_update_memory_messages)
|
||||
from mem0.memory.base import MemoryBase
|
||||
from mem0.memory.setup import setup_config
|
||||
from mem0.memory.storage import SQLiteManager
|
||||
from mem0.memory.telemetry import capture_event
|
||||
from mem0.memory.utils import (
|
||||
get_fact_retrieval_messages,
|
||||
parse_messages,
|
||||
parse_vision_messages,
|
||||
remove_code_blocks,
|
||||
)
|
||||
from mem0.memory.utils import (get_fact_retrieval_messages, parse_messages,
|
||||
parse_vision_messages, remove_code_blocks)
|
||||
from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
|
||||
|
||||
# Setup user config
|
||||
@@ -89,6 +87,7 @@ class Memory(MemoryBase):
|
||||
metadata=None,
|
||||
filters=None,
|
||||
infer=True,
|
||||
memory_type=None,
|
||||
prompt=None,
|
||||
):
|
||||
"""
|
||||
@@ -102,8 +101,8 @@ class Memory(MemoryBase):
|
||||
metadata (dict, optional): Metadata to store with the memory. Defaults to None.
|
||||
filters (dict, optional): Filters to apply to the search. Defaults to None.
|
||||
infer (bool, optional): Whether to infer the memories. Defaults to True.
|
||||
prompt (str, optional): Prompt to use for memory deduction. Defaults to None.
|
||||
|
||||
memory_type (str, optional): Type of memory to create. Defaults to None. By default, it creates the short term memories and long term (semantic and episodic) memories. Pass "procedural_memory" to create procedural memories.
|
||||
prompt (str, optional): Prompt to use for the memory creation. Defaults to None.
|
||||
Returns:
|
||||
dict: A dictionary containing the result of the memory addition operation.
|
||||
result: dict of affected events with each dict has the following key:
|
||||
@@ -131,9 +130,18 @@ class Memory(MemoryBase):
|
||||
if not any(key in filters for key in ("user_id", "agent_id", "run_id")):
|
||||
raise ValueError("One of the filters: user_id, agent_id or run_id is required!")
|
||||
|
||||
if memory_type is not None and memory_type != MemoryType.PROCEDURAL.value:
|
||||
raise ValueError(
|
||||
f"Invalid 'memory_type'. Please pass {MemoryType.PROCEDURAL.value} to create procedural memories."
|
||||
)
|
||||
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
|
||||
if agent_id is not None and memory_type == MemoryType.PROCEDURAL.value:
|
||||
results = self._create_procedural_memory(messages, metadata, prompt)
|
||||
return results
|
||||
|
||||
if self.config.llm.config.get("enable_vision"):
|
||||
messages = parse_vision_messages(messages, self.llm, self.config.llm.config.get("vision_details"))
|
||||
else:
|
||||
@@ -595,11 +603,11 @@ class Memory(MemoryBase):
|
||||
return self.db.get_history(memory_id)
|
||||
|
||||
def _create_memory(self, data, existing_embeddings, metadata=None):
|
||||
logging.info(f"Creating memory with {data=}")
|
||||
logging.debug(f"Creating memory with {data=}")
|
||||
if data in existing_embeddings:
|
||||
embeddings = existing_embeddings[data]
|
||||
else:
|
||||
embeddings = self.embedding_model.embed(data, "add")
|
||||
embeddings = self.embedding_model.embed(data, memory_action="add")
|
||||
memory_id = str(uuid.uuid4())
|
||||
metadata = metadata or {}
|
||||
metadata["data"] = data
|
||||
@@ -615,6 +623,50 @@ class Memory(MemoryBase):
|
||||
capture_event("mem0._create_memory", self, {"memory_id": memory_id})
|
||||
return memory_id
|
||||
|
||||
def _create_procedural_memory(self, messages, metadata, llm=None, prompt=None):
|
||||
"""
|
||||
Create a procedural memory
|
||||
"""
|
||||
try:
|
||||
from langchain_core.messages.utils import convert_to_messages # type: ignore
|
||||
except Exception:
|
||||
logger.error("Import error while loading langchain-core. Please install 'langchain-core' to use procedural memory.")
|
||||
raise
|
||||
|
||||
logger.info("Creating procedural memory")
|
||||
|
||||
parsed_messages = [
|
||||
{"role": "system", "content": prompt or PROCEDURAL_MEMORY_SYSTEM_PROMPT},
|
||||
*messages,
|
||||
{"role": "user", "content": "Create procedural memory of the above conversation."},
|
||||
]
|
||||
|
||||
try:
|
||||
if llm is not None:
|
||||
parsed_messages = convert_to_messages(parsed_messages)
|
||||
response = llm.invoke(messages=parsed_messages)
|
||||
procedural_memory = response.content
|
||||
else:
|
||||
procedural_memory = self.llm.generate_response(messages=parsed_messages)
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating procedural memory summary: {e}")
|
||||
raise
|
||||
|
||||
if metadata is None:
|
||||
raise ValueError("Metadata cannot be done for procedural memory.")
|
||||
|
||||
metadata["memory_type"] = MemoryType.PROCEDURAL.value
|
||||
# Generate embeddings for the summary
|
||||
embeddings = self.embedding_model.embed(procedural_memory, memory_action="add")
|
||||
# Create the memory
|
||||
memory_id = self._create_memory(procedural_memory, {procedural_memory: embeddings}, metadata=metadata)
|
||||
capture_event("mem0._create_procedural_memory", self, {"memory_id": memory_id})
|
||||
|
||||
# Return results in the same format as add()
|
||||
result = {"results": [{"id": memory_id, "memory": procedural_memory, "event": "ADD"}]}
|
||||
|
||||
return result
|
||||
|
||||
def _update_memory(self, memory_id, data, existing_embeddings, metadata=None):
|
||||
logger.info(f"Updating memory with {data=}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user