Update add method and prompts (#1775)
This commit is contained in:
@@ -2,21 +2,17 @@ import logging
|
||||
import hashlib
|
||||
import uuid
|
||||
import pytz
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
import warnings
|
||||
from pydantic import ValidationError
|
||||
from mem0.llms.utils.tools import (
|
||||
ADD_MEMORY_TOOL,
|
||||
DELETE_MEMORY_TOOL,
|
||||
UPDATE_MEMORY_TOOL,
|
||||
)
|
||||
from mem0.configs.prompts import MEMORY_DEDUCTION_PROMPT
|
||||
from mem0.memory.base import MemoryBase
|
||||
from mem0.memory.setup import setup_config
|
||||
from mem0.memory.storage import SQLiteManager
|
||||
from mem0.memory.telemetry import capture_event
|
||||
from mem0.memory.utils import get_update_memory_messages
|
||||
from mem0.memory.utils import get_fact_retrieval_messages, parse_messages
|
||||
from mem0.configs.prompts import get_update_memory_messages
|
||||
from mem0.utils.factory import LlmFactory, EmbedderFactory, VectorStoreFactory
|
||||
from mem0.configs.base import MemoryItem, MemoryConfig
|
||||
|
||||
@@ -44,7 +40,7 @@ class Memory(MemoryBase):
|
||||
from mem0.memory.main_graph import MemoryGraph
|
||||
self.graph = MemoryGraph(self.config)
|
||||
self.enable_graph = True
|
||||
|
||||
|
||||
capture_event("mem0.init", self)
|
||||
|
||||
@classmethod
|
||||
@@ -58,7 +54,7 @@ class Memory(MemoryBase):
|
||||
|
||||
def add(
|
||||
self,
|
||||
data,
|
||||
messages,
|
||||
user_id=None,
|
||||
agent_id=None,
|
||||
run_id=None,
|
||||
@@ -70,7 +66,7 @@ class Memory(MemoryBase):
|
||||
Create a new memory.
|
||||
|
||||
Args:
|
||||
data (str): Data to store in the memory.
|
||||
messages (str or List[Dict[str, str]]): Messages to store in the memory.
|
||||
user_id (str, optional): ID of the user creating the memory. Defaults to None.
|
||||
agent_id (str, optional): ID of the agent creating the memory. Defaults to None.
|
||||
run_id (str, optional): ID of the run creating the memory. Defaults to None.
|
||||
@@ -83,7 +79,6 @@ class Memory(MemoryBase):
|
||||
"""
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
embeddings = self.embedding_model.embed(data)
|
||||
|
||||
filters = filters or {}
|
||||
if user_id:
|
||||
@@ -98,78 +93,63 @@ class Memory(MemoryBase):
|
||||
"One of the filters: user_id, agent_id or run_id is required!"
|
||||
)
|
||||
|
||||
if not prompt:
|
||||
prompt = MEMORY_DEDUCTION_PROMPT.format(user_input=data, metadata=metadata)
|
||||
extracted_memories = self.llm.generate_response(
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are an expert at deducing facts, preferences and memories from unstructured text.",
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
|
||||
parsed_messages = parse_messages(messages)
|
||||
|
||||
system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages)
|
||||
|
||||
response = self.llm.generate_response(
|
||||
messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
|
||||
try:
|
||||
new_retrieved_facts = json.loads(response)[
|
||||
"facts"
|
||||
]
|
||||
)
|
||||
existing_memories = self.vector_store.search(
|
||||
query=embeddings,
|
||||
limit=5,
|
||||
filters=filters,
|
||||
)
|
||||
existing_memories = [
|
||||
MemoryItem(
|
||||
id=mem.id,
|
||||
score=mem.score,
|
||||
metadata=mem.payload,
|
||||
memory=mem.payload["data"],
|
||||
except Exception as e:
|
||||
logging.error(f"Error in new_retrieved_facts: {e}")
|
||||
new_retrieved_facts = []
|
||||
|
||||
retrieved_old_memory = []
|
||||
for new_mem in new_retrieved_facts:
|
||||
messages_embeddings = self.embedding_model.embed(new_mem)
|
||||
existing_memories = self.vector_store.search(
|
||||
query=messages_embeddings,
|
||||
limit=5,
|
||||
filters=filters,
|
||||
)
|
||||
for mem in existing_memories
|
||||
]
|
||||
serialized_existing_memories = [
|
||||
item.model_dump(include={"id", "memory", "score"})
|
||||
for item in existing_memories
|
||||
]
|
||||
logging.info(f"Total existing memories: {len(existing_memories)}")
|
||||
messages = get_update_memory_messages(
|
||||
serialized_existing_memories, extracted_memories
|
||||
for mem in existing_memories:
|
||||
retrieved_old_memory.append({"id": mem.id, "text": mem.payload["data"]})
|
||||
|
||||
logging.info(f"Total existing memories: {len(retrieved_old_memory)}")
|
||||
|
||||
function_calling_prompt = get_update_memory_messages(retrieved_old_memory, new_retrieved_facts)
|
||||
new_memories_with_actions = self.llm.generate_response(
|
||||
messages=[{"role": "user", "content": function_calling_prompt}],
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
# Add tools for noop, add, update, delete memory.
|
||||
tools = [ADD_MEMORY_TOOL, UPDATE_MEMORY_TOOL, DELETE_MEMORY_TOOL]
|
||||
response = self.llm.generate_response(messages=messages, tools=tools)
|
||||
tool_calls = response["tool_calls"]
|
||||
new_memories_with_actions = json.loads(new_memories_with_actions)
|
||||
|
||||
response = []
|
||||
if tool_calls:
|
||||
# Create a new memory
|
||||
available_functions = {
|
||||
"add_memory": self._create_memory_tool,
|
||||
"update_memory": self._update_memory_tool,
|
||||
"delete_memory": self._delete_memory_tool,
|
||||
}
|
||||
for tool_call in tool_calls:
|
||||
function_name = tool_call["name"]
|
||||
function_to_call = available_functions[function_name]
|
||||
function_args = tool_call["arguments"]
|
||||
logging.info(
|
||||
f"[openai_func] func: {function_name}, args: {function_args}"
|
||||
)
|
||||
try:
|
||||
for resp in new_memories_with_actions["memory"]:
|
||||
logging.info(resp)
|
||||
try:
|
||||
if resp["event"] == "ADD":
|
||||
memory_id = self._create_memory(data=resp["text"], metadata=metadata)
|
||||
elif resp["event"] == "UPDATE":
|
||||
self._update_memory(memory_id=resp["id"], data=resp["text"], metadata=metadata)
|
||||
elif resp["event"] == "DELETE":
|
||||
self._delete_memory(memory_id=resp["id"])
|
||||
elif resp["event"] == "NONE":
|
||||
logging.info("NOOP for Memory.")
|
||||
except Exception as e:
|
||||
logging.error(f"Error in new_memories_with_actions: {e}")
|
||||
except Exception as e:
|
||||
logging.error(f"Error in new_memories_with_actions: {e}")
|
||||
|
||||
# Pass metadata to the function if it requires it
|
||||
if function_name in ["add_memory", "update_memory"]:
|
||||
function_args["metadata"] = metadata
|
||||
|
||||
function_result = function_to_call(**function_args)
|
||||
# Fetch the memory_id from the response
|
||||
response.append(
|
||||
{
|
||||
"id": function_result,
|
||||
"event": function_name.replace("_memory", ""),
|
||||
"data": function_args.get("data"),
|
||||
}
|
||||
)
|
||||
capture_event(
|
||||
"mem0.add.function_call",
|
||||
self,
|
||||
{"memory_id": function_result, "function_name": function_name},
|
||||
)
|
||||
capture_event("mem0.add", self)
|
||||
|
||||
if self.version == "v1.1" and self.enable_graph:
|
||||
@@ -177,6 +157,7 @@ class Memory(MemoryBase):
|
||||
self.graph.user_id = user_id
|
||||
else:
|
||||
self.graph.user_id = "USER"
|
||||
data = "\n".join([msg["content"] for msg in messages if "content" in msg])
|
||||
added_entities = self.graph.add(data, filters)
|
||||
|
||||
return {"message": "ok"}
|
||||
@@ -278,7 +259,7 @@ class Memory(MemoryBase):
|
||||
}
|
||||
for mem in memories[0]
|
||||
]
|
||||
|
||||
|
||||
if self.version == "v1.1":
|
||||
if self.enable_graph:
|
||||
graph_entities = self.graph.get_all(filters)
|
||||
@@ -294,7 +275,6 @@ class Memory(MemoryBase):
|
||||
stacklevel=2
|
||||
)
|
||||
return all_memories
|
||||
|
||||
|
||||
def search(
|
||||
self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None
|
||||
@@ -400,7 +380,7 @@ class Memory(MemoryBase):
|
||||
dict: Updated memory.
|
||||
"""
|
||||
capture_event("mem0.update", self, {"memory_id": memory_id})
|
||||
self._update_memory_tool(memory_id, data)
|
||||
self._update_memory(memory_id, data)
|
||||
return {"message": "Memory updated successfully!"}
|
||||
|
||||
def delete(self, memory_id):
|
||||
@@ -411,7 +391,7 @@ class Memory(MemoryBase):
|
||||
memory_id (str): ID of the memory to delete.
|
||||
"""
|
||||
capture_event("mem0.delete", self, {"memory_id": memory_id})
|
||||
self._delete_memory_tool(memory_id)
|
||||
self._delete_memory(memory_id)
|
||||
return {"message": "Memory deleted successfully!"}
|
||||
|
||||
def delete_all(self, user_id=None, agent_id=None, run_id=None):
|
||||
@@ -439,7 +419,7 @@ class Memory(MemoryBase):
|
||||
capture_event("mem0.delete_all", self, {"filters": len(filters)})
|
||||
memories = self.vector_store.list(filters=filters)[0]
|
||||
for memory in memories:
|
||||
self._delete_memory_tool(memory.id)
|
||||
self._delete_memory(memory.id)
|
||||
|
||||
if self.version == "v1.1" and self.enable_graph:
|
||||
self.graph.delete_all(filters)
|
||||
@@ -459,7 +439,7 @@ class Memory(MemoryBase):
|
||||
capture_event("mem0.history", self, {"memory_id": memory_id})
|
||||
return self.db.get_history(memory_id)
|
||||
|
||||
def _create_memory_tool(self, data, metadata=None):
|
||||
def _create_memory(self, data, metadata=None):
|
||||
logging.info(f"Creating memory with {data=}")
|
||||
embeddings = self.embedding_model.embed(data)
|
||||
memory_id = str(uuid.uuid4())
|
||||
@@ -478,7 +458,7 @@ class Memory(MemoryBase):
|
||||
)
|
||||
return memory_id
|
||||
|
||||
def _update_memory_tool(self, memory_id, data, metadata=None):
|
||||
def _update_memory(self, memory_id, data, metadata=None):
|
||||
existing_memory = self.vector_store.get(vector_id=memory_id)
|
||||
prev_value = existing_memory.payload.get("data")
|
||||
|
||||
@@ -513,7 +493,7 @@ class Memory(MemoryBase):
|
||||
updated_at=new_metadata["updated_at"],
|
||||
)
|
||||
|
||||
def _delete_memory_tool(self, memory_id):
|
||||
def _delete_memory(self, memory_id):
|
||||
logging.info(f"Deleting memory with {memory_id=}")
|
||||
existing_memory = self.vector_store.get(vector_id=memory_id)
|
||||
prev_value = existing_memory.payload["data"]
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
from mem0.configs.prompts import UPDATE_MEMORY_PROMPT
|
||||
from mem0.configs.prompts import FACT_RETRIEVAL_PROMPT
|
||||
|
||||
|
||||
def get_update_memory_prompt(existing_memories, memory, template=UPDATE_MEMORY_PROMPT):
|
||||
return template.format(existing_memories=existing_memories, memory=memory)
|
||||
def get_fact_retrieval_messages(message):
|
||||
return FACT_RETRIEVAL_PROMPT, f"Input: {message}"
|
||||
|
||||
|
||||
def get_update_memory_messages(existing_memories, memory):
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"content": get_update_memory_prompt(existing_memories, memory),
|
||||
},
|
||||
]
|
||||
def parse_messages(messages):
|
||||
response = ""
|
||||
for msg in messages:
|
||||
if msg["role"] == "system":
|
||||
response += f"system: {msg['content']}\n"
|
||||
if msg["role"] == "user":
|
||||
response += f"user: {msg['content']}\n"
|
||||
if msg["role"] == "assistant":
|
||||
response += f"assistant: {msg['content']}\n"
|
||||
return response
|
||||
|
||||
Reference in New Issue
Block a user