Update add method and prompts (#1775)

This commit is contained in:
Dev Khant
2024-09-04 05:42:35 +05:30
committed by GitHub
parent d113037a4f
commit f21ca9b765
5 changed files with 306 additions and 129 deletions

View File

@@ -2,21 +2,17 @@ import logging
import hashlib
import uuid
import pytz
import json
from datetime import datetime
from typing import Any, Dict
import warnings
from pydantic import ValidationError
from mem0.llms.utils.tools import (
ADD_MEMORY_TOOL,
DELETE_MEMORY_TOOL,
UPDATE_MEMORY_TOOL,
)
from mem0.configs.prompts import MEMORY_DEDUCTION_PROMPT
from mem0.memory.base import MemoryBase
from mem0.memory.setup import setup_config
from mem0.memory.storage import SQLiteManager
from mem0.memory.telemetry import capture_event
from mem0.memory.utils import get_update_memory_messages
from mem0.memory.utils import get_fact_retrieval_messages, parse_messages
from mem0.configs.prompts import get_update_memory_messages
from mem0.utils.factory import LlmFactory, EmbedderFactory, VectorStoreFactory
from mem0.configs.base import MemoryItem, MemoryConfig
@@ -44,7 +40,7 @@ class Memory(MemoryBase):
from mem0.memory.main_graph import MemoryGraph
self.graph = MemoryGraph(self.config)
self.enable_graph = True
capture_event("mem0.init", self)
@classmethod
@@ -58,7 +54,7 @@ class Memory(MemoryBase):
def add(
self,
data,
messages,
user_id=None,
agent_id=None,
run_id=None,
@@ -70,7 +66,7 @@ class Memory(MemoryBase):
Create a new memory.
Args:
data (str): Data to store in the memory.
messages (str or List[Dict[str, str]]): Messages to store in the memory.
user_id (str, optional): ID of the user creating the memory. Defaults to None.
agent_id (str, optional): ID of the agent creating the memory. Defaults to None.
run_id (str, optional): ID of the run creating the memory. Defaults to None.
@@ -83,7 +79,6 @@ class Memory(MemoryBase):
"""
if metadata is None:
metadata = {}
embeddings = self.embedding_model.embed(data)
filters = filters or {}
if user_id:
@@ -98,78 +93,63 @@ class Memory(MemoryBase):
"One of the filters: user_id, agent_id or run_id is required!"
)
if not prompt:
prompt = MEMORY_DEDUCTION_PROMPT.format(user_input=data, metadata=metadata)
extracted_memories = self.llm.generate_response(
messages=[
{
"role": "system",
"content": "You are an expert at deducing facts, preferences and memories from unstructured text.",
},
{"role": "user", "content": prompt},
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
parsed_messages = parse_messages(messages)
system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages)
response = self.llm.generate_response(
messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
response_format={"type": "json_object"},
)
try:
new_retrieved_facts = json.loads(response)[
"facts"
]
)
existing_memories = self.vector_store.search(
query=embeddings,
limit=5,
filters=filters,
)
existing_memories = [
MemoryItem(
id=mem.id,
score=mem.score,
metadata=mem.payload,
memory=mem.payload["data"],
except Exception as e:
logging.error(f"Error in new_retrieved_facts: {e}")
new_retrieved_facts = []
retrieved_old_memory = []
for new_mem in new_retrieved_facts:
messages_embeddings = self.embedding_model.embed(new_mem)
existing_memories = self.vector_store.search(
query=messages_embeddings,
limit=5,
filters=filters,
)
for mem in existing_memories
]
serialized_existing_memories = [
item.model_dump(include={"id", "memory", "score"})
for item in existing_memories
]
logging.info(f"Total existing memories: {len(existing_memories)}")
messages = get_update_memory_messages(
serialized_existing_memories, extracted_memories
for mem in existing_memories:
retrieved_old_memory.append({"id": mem.id, "text": mem.payload["data"]})
logging.info(f"Total existing memories: {len(retrieved_old_memory)}")
function_calling_prompt = get_update_memory_messages(retrieved_old_memory, new_retrieved_facts)
new_memories_with_actions = self.llm.generate_response(
messages=[{"role": "user", "content": function_calling_prompt}],
response_format={"type": "json_object"},
)
# Add tools for noop, add, update, delete memory.
tools = [ADD_MEMORY_TOOL, UPDATE_MEMORY_TOOL, DELETE_MEMORY_TOOL]
response = self.llm.generate_response(messages=messages, tools=tools)
tool_calls = response["tool_calls"]
new_memories_with_actions = json.loads(new_memories_with_actions)
response = []
if tool_calls:
# Create a new memory
available_functions = {
"add_memory": self._create_memory_tool,
"update_memory": self._update_memory_tool,
"delete_memory": self._delete_memory_tool,
}
for tool_call in tool_calls:
function_name = tool_call["name"]
function_to_call = available_functions[function_name]
function_args = tool_call["arguments"]
logging.info(
f"[openai_func] func: {function_name}, args: {function_args}"
)
try:
for resp in new_memories_with_actions["memory"]:
logging.info(resp)
try:
if resp["event"] == "ADD":
memory_id = self._create_memory(data=resp["text"], metadata=metadata)
elif resp["event"] == "UPDATE":
self._update_memory(memory_id=resp["id"], data=resp["text"], metadata=metadata)
elif resp["event"] == "DELETE":
self._delete_memory(memory_id=resp["id"])
elif resp["event"] == "NONE":
logging.info("NOOP for Memory.")
except Exception as e:
logging.error(f"Error in new_memories_with_actions: {e}")
except Exception as e:
logging.error(f"Error in new_memories_with_actions: {e}")
# Pass metadata to the function if it requires it
if function_name in ["add_memory", "update_memory"]:
function_args["metadata"] = metadata
function_result = function_to_call(**function_args)
# Fetch the memory_id from the response
response.append(
{
"id": function_result,
"event": function_name.replace("_memory", ""),
"data": function_args.get("data"),
}
)
capture_event(
"mem0.add.function_call",
self,
{"memory_id": function_result, "function_name": function_name},
)
capture_event("mem0.add", self)
if self.version == "v1.1" and self.enable_graph:
@@ -177,6 +157,7 @@ class Memory(MemoryBase):
self.graph.user_id = user_id
else:
self.graph.user_id = "USER"
data = "\n".join([msg["content"] for msg in messages if "content" in msg])
added_entities = self.graph.add(data, filters)
return {"message": "ok"}
@@ -278,7 +259,7 @@ class Memory(MemoryBase):
}
for mem in memories[0]
]
if self.version == "v1.1":
if self.enable_graph:
graph_entities = self.graph.get_all(filters)
@@ -294,7 +275,6 @@ class Memory(MemoryBase):
stacklevel=2
)
return all_memories
def search(
self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None
@@ -400,7 +380,7 @@ class Memory(MemoryBase):
dict: Updated memory.
"""
capture_event("mem0.update", self, {"memory_id": memory_id})
self._update_memory_tool(memory_id, data)
self._update_memory(memory_id, data)
return {"message": "Memory updated successfully!"}
def delete(self, memory_id):
@@ -411,7 +391,7 @@ class Memory(MemoryBase):
memory_id (str): ID of the memory to delete.
"""
capture_event("mem0.delete", self, {"memory_id": memory_id})
self._delete_memory_tool(memory_id)
self._delete_memory(memory_id)
return {"message": "Memory deleted successfully!"}
def delete_all(self, user_id=None, agent_id=None, run_id=None):
@@ -439,7 +419,7 @@ class Memory(MemoryBase):
capture_event("mem0.delete_all", self, {"filters": len(filters)})
memories = self.vector_store.list(filters=filters)[0]
for memory in memories:
self._delete_memory_tool(memory.id)
self._delete_memory(memory.id)
if self.version == "v1.1" and self.enable_graph:
self.graph.delete_all(filters)
@@ -459,7 +439,7 @@ class Memory(MemoryBase):
capture_event("mem0.history", self, {"memory_id": memory_id})
return self.db.get_history(memory_id)
def _create_memory_tool(self, data, metadata=None):
def _create_memory(self, data, metadata=None):
logging.info(f"Creating memory with {data=}")
embeddings = self.embedding_model.embed(data)
memory_id = str(uuid.uuid4())
@@ -478,7 +458,7 @@ class Memory(MemoryBase):
)
return memory_id
def _update_memory_tool(self, memory_id, data, metadata=None):
def _update_memory(self, memory_id, data, metadata=None):
existing_memory = self.vector_store.get(vector_id=memory_id)
prev_value = existing_memory.payload.get("data")
@@ -513,7 +493,7 @@ class Memory(MemoryBase):
updated_at=new_metadata["updated_at"],
)
def _delete_memory_tool(self, memory_id):
def _delete_memory(self, memory_id):
logging.info(f"Deleting memory with {memory_id=}")
existing_memory = self.vector_store.get(vector_id=memory_id)
prev_value = existing_memory.payload["data"]

View File

@@ -1,14 +1,16 @@
from mem0.configs.prompts import UPDATE_MEMORY_PROMPT
from mem0.configs.prompts import FACT_RETRIEVAL_PROMPT
def get_update_memory_prompt(existing_memories, memory, template=UPDATE_MEMORY_PROMPT):
    """Fill *template* with the existing memories and the newly extracted memory.

    Defaults to the module-level UPDATE_MEMORY_PROMPT template.
    """
    prompt_fields = {"existing_memories": existing_memories, "memory": memory}
    return template.format(**prompt_fields)
def get_fact_retrieval_messages(message):
    """Build the (system_prompt, user_prompt) pair for fact retrieval.

    The system prompt is the module-level FACT_RETRIEVAL_PROMPT; the user
    prompt wraps *message* with an ``Input:`` prefix.
    """
    user_prompt = "Input: {}".format(message)
    return FACT_RETRIEVAL_PROMPT, user_prompt
def get_update_memory_messages(existing_memories, memory):
    """Wrap the rendered update-memory prompt as a single user chat message."""
    prompt = get_update_memory_prompt(existing_memories, memory)
    return [{"role": "user", "content": prompt}]
def parse_messages(messages):
    """Flatten a list of chat messages into one role-prefixed string.

    Each message whose role is system, user, or assistant contributes a
    ``role: content`` line (newline-terminated); any other role is skipped.
    """
    known_roles = ("system", "user", "assistant")
    lines = []
    for message in messages:
        role = message["role"]
        if role in known_roles:
            lines.append(f"{role}: {message['content']}\n")
    return "".join(lines)