Modified the return statement for ADD call | Added tests to main.py and graph_memory.py (#1812)
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
from mem0.llms.openai import OpenAILLM
|
||||
|
||||
UPDATE_GRAPH_PROMPT = """
|
||||
You are an AI expert specializing in graph memory management and optimization. Your task is to analyze existing graph memories alongside new information, and update the relationships in the memory list to ensure the most accurate, current, and coherent representation of knowledge.
|
||||
@@ -67,42 +66,3 @@ def get_update_memory_messages(existing_memories, memory):
|
||||
"content": get_update_memory_prompt(existing_memories, memory, UPDATE_GRAPH_PROMPT),
|
||||
},
|
||||
]
|
||||
|
||||
def get_search_results(entities, query):
|
||||
|
||||
search_graph_prompt = f"""
|
||||
You are an expert at searching through graph entity memories.
|
||||
When provided with existing graph entities and a query, your task is to search through the provided graph entities to find the most relevant information from the graph entities related to the query.
|
||||
The output should be from the graph entities only.
|
||||
|
||||
Here are the details of the task:
|
||||
- Existing Graph Entities (source -> relationship -> target):
|
||||
{entities}
|
||||
|
||||
- Query: {query}
|
||||
|
||||
The output should be from the graph entities only.
|
||||
The output should be in the following JSON format:
|
||||
{{
|
||||
"search_results": [
|
||||
{{
|
||||
"source_node": "source_node",
|
||||
"relationship": "relationship",
|
||||
"target_node": "target_node"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": search_graph_prompt,
|
||||
}
|
||||
]
|
||||
|
||||
llm = OpenAILLM()
|
||||
|
||||
results = llm.generate_response(messages=messages, response_format={"type": "json_object"})
|
||||
|
||||
return results
|
||||
|
||||
@@ -101,6 +101,8 @@ class MemoryGraph:
|
||||
elif item['name'] == "noop":
|
||||
continue
|
||||
|
||||
returned_entities = []
|
||||
|
||||
for item in to_be_added:
|
||||
source = item['source'].lower().replace(" ", "_")
|
||||
source_type = item['source_type'].lower().replace(" ", "_")
|
||||
@@ -108,6 +110,12 @@ class MemoryGraph:
|
||||
destination = item['destination'].lower().replace(" ", "_")
|
||||
destination_type = item['destination_type'].lower().replace(" ", "_")
|
||||
|
||||
returned_entities.append({
|
||||
"source" : source,
|
||||
"relationship" : relation,
|
||||
"target" : destination
|
||||
})
|
||||
|
||||
# Create embeddings
|
||||
source_embedding = self.embedding_model.embed(source)
|
||||
dest_embedding = self.embedding_model.embed(destination)
|
||||
@@ -137,6 +145,7 @@ class MemoryGraph:
|
||||
|
||||
logger.info(f"Added {len(to_be_added)} new memories to the graph")
|
||||
|
||||
return returned_entities
|
||||
|
||||
def _search(self, query, filters):
|
||||
_tools = [SEARCH_TOOL]
|
||||
@@ -155,8 +164,10 @@ class MemoryGraph:
|
||||
|
||||
for item in search_results['tool_calls']:
|
||||
if item['name'] == "search":
|
||||
node_list.extend(item['arguments']['nodes'])
|
||||
relation_list.extend(item['arguments']['relations'])
|
||||
try:
|
||||
node_list.extend(item['arguments']['nodes'])
|
||||
except Exception as e:
|
||||
logger.error(f"Error in search tool: {e}")
|
||||
|
||||
node_list = list(set(node_list))
|
||||
relation_list = list(set(relation_list))
|
||||
@@ -228,8 +239,8 @@ class MemoryGraph:
|
||||
for item in reranked_results:
|
||||
search_results.append({
|
||||
"source": item[0],
|
||||
"relation": item[1],
|
||||
"destination": item[2]
|
||||
"relationship": item[1],
|
||||
"target": item[2]
|
||||
})
|
||||
|
||||
logger.info(f"Returned {len(search_results)} search results")
|
||||
|
||||
@@ -2,7 +2,6 @@ import concurrent
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
@@ -11,14 +10,14 @@ from typing import Any, Dict
|
||||
import pytz
|
||||
from pydantic import ValidationError
|
||||
|
||||
from mem0.configs.base import MemoryConfig, MemoryItem
|
||||
from mem0.configs.prompts import get_update_memory_messages
|
||||
from mem0.memory.base import MemoryBase
|
||||
from mem0.memory.setup import setup_config
|
||||
from mem0.memory.storage import SQLiteManager
|
||||
from mem0.memory.telemetry import capture_event
|
||||
from mem0.memory.utils import get_fact_retrieval_messages, parse_messages
|
||||
from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
|
||||
from mem0.utils.factory import LlmFactory, EmbedderFactory, VectorStoreFactory
|
||||
from mem0.configs.base import MemoryItem, MemoryConfig
|
||||
|
||||
# Setup user config
|
||||
setup_config()
|
||||
@@ -49,6 +48,7 @@ class Memory(MemoryBase):
|
||||
|
||||
capture_event("mem0.init", self)
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config_dict: Dict[str, Any]):
|
||||
try:
|
||||
@@ -58,6 +58,7 @@ class Memory(MemoryBase):
|
||||
raise
|
||||
return cls(config)
|
||||
|
||||
|
||||
def add(
|
||||
self,
|
||||
messages,
|
||||
@@ -81,7 +82,7 @@ class Memory(MemoryBase):
|
||||
prompt (str, optional): Prompt to use for memory deduction. Defaults to None.
|
||||
|
||||
Returns:
|
||||
dict: Memory addition operation message.
|
||||
dict: A dictionary containing the result of the memory addition operation.
|
||||
"""
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
@@ -102,17 +103,31 @@ class Memory(MemoryBase):
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
|
||||
thread1 = threading.Thread(target=self._add_to_vector_store, args=(messages, metadata, filters))
|
||||
thread2 = threading.Thread(target=self._add_to_graph, args=(messages, filters))
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future1 = executor.submit(self._add_to_vector_store, messages, metadata, filters)
|
||||
future2 = executor.submit(self._add_to_graph, messages, filters)
|
||||
|
||||
thread1.start()
|
||||
thread2.start()
|
||||
concurrent.futures.wait([future1, future2])
|
||||
|
||||
thread1.join()
|
||||
thread2.join()
|
||||
vector_store_result = future1.result()
|
||||
graph_result = future2.result()
|
||||
|
||||
if self.version == "v1.1":
|
||||
return {
|
||||
"results" : vector_store_result,
|
||||
"relations" : graph_result,
|
||||
}
|
||||
else:
|
||||
warnings.warn(
|
||||
"The current add API output format is deprecated. "
|
||||
"To use the latest format, set `api_version='v1.1'`. "
|
||||
"The current format will be removed in mem0ai 1.1.0 and later versions.",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
return {"message": "ok"}
|
||||
|
||||
|
||||
return {"message": "ok"}
|
||||
|
||||
def _add_to_vector_store(self, messages, metadata, filters):
|
||||
parsed_messages = parse_messages(messages)
|
||||
|
||||
@@ -151,16 +166,30 @@ class Memory(MemoryBase):
|
||||
)
|
||||
new_memories_with_actions = json.loads(new_memories_with_actions)
|
||||
|
||||
returned_memories = []
|
||||
try:
|
||||
for resp in new_memories_with_actions["memory"]:
|
||||
logging.info(resp)
|
||||
try:
|
||||
if resp["event"] == "ADD":
|
||||
self._create_memory(data=resp["text"], metadata=metadata)
|
||||
memory_id = self._create_memory(data=resp["text"], metadata=metadata)
|
||||
returned_memories.append({
|
||||
"memory" : resp["text"],
|
||||
"event" : resp["event"],
|
||||
})
|
||||
elif resp["event"] == "UPDATE":
|
||||
self._update_memory(memory_id=resp["id"], data=resp["text"], metadata=metadata)
|
||||
returned_memories.append({
|
||||
"memory" : resp["text"],
|
||||
"event" : resp["event"],
|
||||
"previous_memory" : resp["old_memory"],
|
||||
})
|
||||
elif resp["event"] == "DELETE":
|
||||
self._delete_memory(memory_id=resp["id"])
|
||||
returned_memories.append({
|
||||
"memory" : resp["text"],
|
||||
"event" : resp["event"],
|
||||
})
|
||||
elif resp["event"] == "NONE":
|
||||
logging.info("NOOP for Memory.")
|
||||
except Exception as e:
|
||||
@@ -170,7 +199,11 @@ class Memory(MemoryBase):
|
||||
|
||||
capture_event("mem0.add", self)
|
||||
|
||||
return returned_memories
|
||||
|
||||
|
||||
def _add_to_graph(self, messages, filters):
|
||||
added_entities = []
|
||||
if self.version == "v1.1" and self.enable_graph:
|
||||
if filters["user_id"]:
|
||||
self.graph.user_id = filters["user_id"]
|
||||
@@ -179,6 +212,9 @@ class Memory(MemoryBase):
|
||||
data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"])
|
||||
self.graph.add(data, filters)
|
||||
|
||||
return added_entities
|
||||
|
||||
|
||||
def get(self, memory_id):
|
||||
"""
|
||||
Retrieve a memory by ID.
|
||||
@@ -229,6 +265,7 @@ class Memory(MemoryBase):
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
|
||||
"""
|
||||
List all memories.
|
||||
@@ -255,9 +292,9 @@ class Memory(MemoryBase):
|
||||
|
||||
if self.version == "v1.1":
|
||||
if self.enable_graph:
|
||||
return {"memories": all_memories, "entities": graph_entities}
|
||||
return {"results": all_memories, "relations": graph_entities}
|
||||
else:
|
||||
return {"memories": all_memories}
|
||||
return {"results": all_memories}
|
||||
else:
|
||||
warnings.warn(
|
||||
"The current get_all API output format is deprecated. "
|
||||
@@ -268,6 +305,7 @@ class Memory(MemoryBase):
|
||||
)
|
||||
return all_memories
|
||||
|
||||
|
||||
def _get_all_from_vector_store(self, filters, limit):
|
||||
memories = self.vector_store.list(filters=filters, limit=limit)
|
||||
|
||||
@@ -302,6 +340,7 @@ class Memory(MemoryBase):
|
||||
]
|
||||
return all_memories
|
||||
|
||||
|
||||
def search(
|
||||
self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None
|
||||
):
|
||||
@@ -343,9 +382,9 @@ class Memory(MemoryBase):
|
||||
|
||||
if self.version == "v1.1":
|
||||
if self.enable_graph:
|
||||
return {"memories": original_memories, "entities": graph_entities}
|
||||
return {"results": original_memories, "relations": graph_entities}
|
||||
else:
|
||||
return {"memories" : original_memories}
|
||||
return {"results" : original_memories}
|
||||
else:
|
||||
warnings.warn(
|
||||
"The current get_all API output format is deprecated. "
|
||||
@@ -356,6 +395,7 @@ class Memory(MemoryBase):
|
||||
)
|
||||
return original_memories
|
||||
|
||||
|
||||
def _search_vector_store(self, query, filters, limit):
|
||||
embeddings = self.embedding_model.embed(query)
|
||||
memories = self.vector_store.search(
|
||||
@@ -404,6 +444,7 @@ class Memory(MemoryBase):
|
||||
|
||||
return original_memories
|
||||
|
||||
|
||||
def update(self, memory_id, data):
|
||||
"""
|
||||
Update a memory by ID.
|
||||
@@ -419,6 +460,7 @@ class Memory(MemoryBase):
|
||||
self._update_memory(memory_id, data)
|
||||
return {"message": "Memory updated successfully!"}
|
||||
|
||||
|
||||
def delete(self, memory_id):
|
||||
"""
|
||||
Delete a memory by ID.
|
||||
@@ -430,6 +472,7 @@ class Memory(MemoryBase):
|
||||
self._delete_memory(memory_id)
|
||||
return {"message": "Memory deleted successfully!"}
|
||||
|
||||
|
||||
def delete_all(self, user_id=None, agent_id=None, run_id=None):
|
||||
"""
|
||||
Delete all memories.
|
||||
@@ -464,6 +507,7 @@ class Memory(MemoryBase):
|
||||
|
||||
return {'message': 'Memories deleted successfully!'}
|
||||
|
||||
|
||||
def history(self, memory_id):
|
||||
"""
|
||||
Get the history of changes for a memory by ID.
|
||||
@@ -477,6 +521,7 @@ class Memory(MemoryBase):
|
||||
capture_event("mem0.history", self, {"memory_id": memory_id})
|
||||
return self.db.get_history(memory_id)
|
||||
|
||||
|
||||
def _create_memory(self, data, metadata=None):
|
||||
logging.info(f"Creating memory with {data=}")
|
||||
embeddings = self.embedding_model.embed(data)
|
||||
@@ -496,6 +541,7 @@ class Memory(MemoryBase):
|
||||
)
|
||||
return memory_id
|
||||
|
||||
|
||||
def _update_memory(self, memory_id, data, metadata=None):
|
||||
logger.info(f"Updating memory with {data=}")
|
||||
existing_memory = self.vector_store.get(vector_id=memory_id)
|
||||
@@ -532,6 +578,7 @@ class Memory(MemoryBase):
|
||||
updated_at=new_metadata["updated_at"],
|
||||
)
|
||||
|
||||
|
||||
def _delete_memory(self, memory_id):
|
||||
logging.info(f"Deleting memory with {memory_id=}")
|
||||
existing_memory = self.vector_store.get(vector_id=memory_id)
|
||||
@@ -539,6 +586,7 @@ class Memory(MemoryBase):
|
||||
self.vector_store.delete(vector_id=memory_id)
|
||||
self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1)
|
||||
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the memory store.
|
||||
@@ -548,5 +596,6 @@ class Memory(MemoryBase):
|
||||
self.db.reset()
|
||||
capture_event("mem0.reset", self)
|
||||
|
||||
|
||||
def chat(self, query):
|
||||
raise NotImplementedError("Chat function not implemented yet.")
|
||||
|
||||
Reference in New Issue
Block a user