Modified the return statement for ADD call | Added tests to main.py and graph_memory.py (#1812)

This commit is contained in:
Prateek Chhikara
2024-09-09 10:04:11 -07:00
committed by GitHub
parent 58f29d8781
commit b081e43b8d
5 changed files with 300 additions and 62 deletions

View File

@@ -34,7 +34,6 @@
<!-- Start of Selection -->
<p style="display: flex;">
<img src="https://media.tenor.com/K3j9pwWlME0AAAAi/fire-flame.gif" alt="Graph Memory Integration" style="width: 25px; margin-right: 10px;" />
<span style="font-size: 1.2em;">New Feature: Introducing Graph Memory. Check out our <a href="https://docs.mem0.ai/open-source/graph-memory" target="_blank">documentation</a>.</span>
</p>
<!-- End of Selection -->

View File

@@ -1,4 +1,3 @@
from mem0.llms.openai import OpenAILLM
UPDATE_GRAPH_PROMPT = """
You are an AI expert specializing in graph memory management and optimization. Your task is to analyze existing graph memories alongside new information, and update the relationships in the memory list to ensure the most accurate, current, and coherent representation of knowledge.
@@ -67,42 +66,3 @@ def get_update_memory_messages(existing_memories, memory):
"content": get_update_memory_prompt(existing_memories, memory, UPDATE_GRAPH_PROMPT),
},
]
def get_search_results(entities, query):
    """Ask an LLM to pick the entries from existing graph entity memories
    that are most relevant to *query*.

    Args:
        entities: Existing graph entities, one per line, formatted as
            "source -> relationship -> target".
        query: Natural-language query to match against the entities.

    Returns:
        The raw LLM response — a JSON object with a "search_results" list of
        {"source_node", "relationship", "target_node"} items drawn only from
        the provided entities.
    """
    prompt = f"""
You are an expert at searching through graph entity memories.
When provided with existing graph entities and a query, your task is to search through the provided graph entities to find the most relevant information from the graph entities related to the query.
The output should be from the graph entities only.
Here are the details of the task:
- Existing Graph Entities (source -> relationship -> target):
{entities}
- Query: {query}
The output should be from the graph entities only.
The output should be in the following JSON format:
{{
"search_results": [
{{
"source_node": "source_node",
"relationship": "relationship",
"target_node": "target_node"
}}
]
}}
"""
    # Single-turn user message; force the model to emit strict JSON.
    llm_client = OpenAILLM()
    return llm_client.generate_response(
        messages=[{"role": "user", "content": prompt}],
        response_format={"type": "json_object"},
    )

View File

@@ -101,6 +101,8 @@ class MemoryGraph:
elif item['name'] == "noop":
continue
returned_entities = []
for item in to_be_added:
source = item['source'].lower().replace(" ", "_")
source_type = item['source_type'].lower().replace(" ", "_")
@@ -108,6 +110,12 @@ class MemoryGraph:
destination = item['destination'].lower().replace(" ", "_")
destination_type = item['destination_type'].lower().replace(" ", "_")
returned_entities.append({
"source" : source,
"relationship" : relation,
"target" : destination
})
# Create embeddings
source_embedding = self.embedding_model.embed(source)
dest_embedding = self.embedding_model.embed(destination)
@@ -137,6 +145,7 @@ class MemoryGraph:
logger.info(f"Added {len(to_be_added)} new memories to the graph")
return returned_entities
def _search(self, query, filters):
_tools = [SEARCH_TOOL]
@@ -155,8 +164,10 @@ class MemoryGraph:
for item in search_results['tool_calls']:
if item['name'] == "search":
node_list.extend(item['arguments']['nodes'])
relation_list.extend(item['arguments']['relations'])
try:
node_list.extend(item['arguments']['nodes'])
except Exception as e:
logger.error(f"Error in search tool: {e}")
node_list = list(set(node_list))
relation_list = list(set(relation_list))
@@ -228,8 +239,8 @@ class MemoryGraph:
for item in reranked_results:
search_results.append({
"source": item[0],
"relation": item[1],
"destination": item[2]
"relationship": item[1],
"target": item[2]
})
logger.info(f"Returned {len(search_results)} search results")

View File

@@ -2,7 +2,6 @@ import concurrent
import hashlib
import json
import logging
import threading
import uuid
import warnings
from datetime import datetime
@@ -11,14 +10,14 @@ from typing import Any, Dict
import pytz
from pydantic import ValidationError
from mem0.configs.base import MemoryConfig, MemoryItem
from mem0.configs.prompts import get_update_memory_messages
from mem0.memory.base import MemoryBase
from mem0.memory.setup import setup_config
from mem0.memory.storage import SQLiteManager
from mem0.memory.telemetry import capture_event
from mem0.memory.utils import get_fact_retrieval_messages, parse_messages
from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
from mem0.utils.factory import LlmFactory, EmbedderFactory, VectorStoreFactory
from mem0.configs.base import MemoryItem, MemoryConfig
# Setup user config
setup_config()
@@ -49,6 +48,7 @@ class Memory(MemoryBase):
capture_event("mem0.init", self)
@classmethod
def from_config(cls, config_dict: Dict[str, Any]):
try:
@@ -58,6 +58,7 @@ class Memory(MemoryBase):
raise
return cls(config)
def add(
self,
messages,
@@ -81,7 +82,7 @@ class Memory(MemoryBase):
prompt (str, optional): Prompt to use for memory deduction. Defaults to None.
Returns:
dict: Memory addition operation message.
dict: A dictionary containing the result of the memory addition operation.
"""
if metadata is None:
metadata = {}
@@ -102,17 +103,31 @@ class Memory(MemoryBase):
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
thread1 = threading.Thread(target=self._add_to_vector_store, args=(messages, metadata, filters))
thread2 = threading.Thread(target=self._add_to_graph, args=(messages, filters))
with concurrent.futures.ThreadPoolExecutor() as executor:
future1 = executor.submit(self._add_to_vector_store, messages, metadata, filters)
future2 = executor.submit(self._add_to_graph, messages, filters)
thread1.start()
thread2.start()
concurrent.futures.wait([future1, future2])
thread1.join()
thread2.join()
vector_store_result = future1.result()
graph_result = future2.result()
if self.version == "v1.1":
return {
"results" : vector_store_result,
"relations" : graph_result,
}
else:
warnings.warn(
"The current add API output format is deprecated. "
"To use the latest format, set `api_version='v1.1'`. "
"The current format will be removed in mem0ai 1.1.0 and later versions.",
category=DeprecationWarning,
stacklevel=2
)
return {"message": "ok"}
return {"message": "ok"}
def _add_to_vector_store(self, messages, metadata, filters):
parsed_messages = parse_messages(messages)
@@ -151,16 +166,30 @@ class Memory(MemoryBase):
)
new_memories_with_actions = json.loads(new_memories_with_actions)
returned_memories = []
try:
for resp in new_memories_with_actions["memory"]:
logging.info(resp)
try:
if resp["event"] == "ADD":
self._create_memory(data=resp["text"], metadata=metadata)
memory_id = self._create_memory(data=resp["text"], metadata=metadata)
returned_memories.append({
"memory" : resp["text"],
"event" : resp["event"],
})
elif resp["event"] == "UPDATE":
self._update_memory(memory_id=resp["id"], data=resp["text"], metadata=metadata)
returned_memories.append({
"memory" : resp["text"],
"event" : resp["event"],
"previous_memory" : resp["old_memory"],
})
elif resp["event"] == "DELETE":
self._delete_memory(memory_id=resp["id"])
returned_memories.append({
"memory" : resp["text"],
"event" : resp["event"],
})
elif resp["event"] == "NONE":
logging.info("NOOP for Memory.")
except Exception as e:
@@ -170,7 +199,11 @@ class Memory(MemoryBase):
capture_event("mem0.add", self)
return returned_memories
def _add_to_graph(self, messages, filters):
added_entities = []
if self.version == "v1.1" and self.enable_graph:
if filters["user_id"]:
self.graph.user_id = filters["user_id"]
@@ -179,6 +212,9 @@ class Memory(MemoryBase):
data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"])
self.graph.add(data, filters)
return added_entities
def get(self, memory_id):
"""
Retrieve a memory by ID.
@@ -229,6 +265,7 @@ class Memory(MemoryBase):
return result
def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
"""
List all memories.
@@ -255,9 +292,9 @@ class Memory(MemoryBase):
if self.version == "v1.1":
if self.enable_graph:
return {"memories": all_memories, "entities": graph_entities}
return {"results": all_memories, "relations": graph_entities}
else:
return {"memories": all_memories}
return {"results": all_memories}
else:
warnings.warn(
"The current get_all API output format is deprecated. "
@@ -268,6 +305,7 @@ class Memory(MemoryBase):
)
return all_memories
def _get_all_from_vector_store(self, filters, limit):
memories = self.vector_store.list(filters=filters, limit=limit)
@@ -302,6 +340,7 @@ class Memory(MemoryBase):
]
return all_memories
def search(
self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None
):
@@ -343,9 +382,9 @@ class Memory(MemoryBase):
if self.version == "v1.1":
if self.enable_graph:
return {"memories": original_memories, "entities": graph_entities}
return {"results": original_memories, "relations": graph_entities}
else:
return {"memories" : original_memories}
return {"results" : original_memories}
else:
warnings.warn(
"The current get_all API output format is deprecated. "
@@ -356,6 +395,7 @@ class Memory(MemoryBase):
)
return original_memories
def _search_vector_store(self, query, filters, limit):
embeddings = self.embedding_model.embed(query)
memories = self.vector_store.search(
@@ -404,6 +444,7 @@ class Memory(MemoryBase):
return original_memories
def update(self, memory_id, data):
"""
Update a memory by ID.
@@ -419,6 +460,7 @@ class Memory(MemoryBase):
self._update_memory(memory_id, data)
return {"message": "Memory updated successfully!"}
def delete(self, memory_id):
"""
Delete a memory by ID.
@@ -430,6 +472,7 @@ class Memory(MemoryBase):
self._delete_memory(memory_id)
return {"message": "Memory deleted successfully!"}
def delete_all(self, user_id=None, agent_id=None, run_id=None):
"""
Delete all memories.
@@ -464,6 +507,7 @@ class Memory(MemoryBase):
return {'message': 'Memories deleted successfully!'}
def history(self, memory_id):
"""
Get the history of changes for a memory by ID.
@@ -477,6 +521,7 @@ class Memory(MemoryBase):
capture_event("mem0.history", self, {"memory_id": memory_id})
return self.db.get_history(memory_id)
def _create_memory(self, data, metadata=None):
logging.info(f"Creating memory with {data=}")
embeddings = self.embedding_model.embed(data)
@@ -496,6 +541,7 @@ class Memory(MemoryBase):
)
return memory_id
def _update_memory(self, memory_id, data, metadata=None):
logger.info(f"Updating memory with {data=}")
existing_memory = self.vector_store.get(vector_id=memory_id)
@@ -532,6 +578,7 @@ class Memory(MemoryBase):
updated_at=new_metadata["updated_at"],
)
def _delete_memory(self, memory_id):
logging.info(f"Deleting memory with {memory_id=}")
existing_memory = self.vector_store.get(vector_id=memory_id)
@@ -539,6 +586,7 @@ class Memory(MemoryBase):
self.vector_store.delete(vector_id=memory_id)
self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1)
def reset(self):
"""
Reset the memory store.
@@ -548,5 +596,6 @@ class Memory(MemoryBase):
self.db.reset()
capture_event("mem0.reset", self)
def chat(self, query):
raise NotImplementedError("Chat function not implemented yet.")

219
tests/test_main.py Normal file
View File

@@ -0,0 +1,219 @@
import pytest
import os
from unittest.mock import Mock, patch
from mem0.memory.main import Memory
from mem0.configs.base import MemoryConfig
@pytest.fixture(autouse=True)
def mock_openai():
    """Autouse fixture: plant a dummy API key and patch the OpenAI client.

    Keeps every test in this module from ever touching the real OpenAI SDK
    or requiring real credentials.
    """
    os.environ['OPENAI_API_KEY'] = "123"
    with patch('openai.OpenAI') as patched_client:
        # Any instantiation of openai.OpenAI hands back a bare Mock.
        patched_client.return_value = Mock()
        yield patched_client
@pytest.fixture
def memory_instance():
    """Build a Memory (v1.1) with all external factories and telemetry patched out."""
    with patch('mem0.utils.factory.EmbedderFactory') as embedder_factory, \
         patch('mem0.utils.factory.VectorStoreFactory') as vector_store_factory, \
         patch('mem0.utils.factory.LlmFactory') as llm_factory, \
         patch('mem0.memory.telemetry.capture_event'), \
         patch('mem0.memory.graph_memory.MemoryGraph'):
        # Each factory hands back a bare Mock instead of a real backend.
        for factory in (embedder_factory, vector_store_factory, llm_factory):
            factory.create.return_value = Mock()
        config = MemoryConfig(version="v1.1")
        config.graph_store.config = {"some_config": "value"}
        # Construct while the patches are still active so __init__ sees the mocks.
        return Memory(config)
@pytest.mark.parametrize("version, enable_graph", [
    ("v1.0", False),
    ("v1.1", True)
])
def test_add(memory_instance, version, enable_graph):
    """add() should merge vector-store and graph results into one payload."""
    memory_instance.config.version = version
    memory_instance.enable_graph = enable_graph
    stored = [{"memory": "Test memory", "event": "ADD"}]
    memory_instance._add_to_vector_store = Mock(return_value=stored)
    memory_instance._add_to_graph = Mock(return_value=[])

    messages = [{"role": "user", "content": "Test message"}]
    result = memory_instance.add(messages=messages, user_id="test_user")

    assert "results" in result
    assert result["results"] == stored
    assert "relations" in result
    assert result["relations"] == []

    filters = {"user_id": "test_user"}
    memory_instance._add_to_vector_store.assert_called_once_with(
        messages, filters, filters
    )
    # _add_to_graph is invoked unconditionally, regardless of version/graph flags.
    memory_instance._add_to_graph.assert_called_once_with(messages, filters)
def test_get(memory_instance):
    """get() should expose core payload fields directly and fold extras into metadata."""
    payload = {
        "data": "Test memory",
        "user_id": "test_user",
        "hash": "test_hash",
        "created_at": "2023-01-01T00:00:00",
        "updated_at": "2023-01-02T00:00:00",
        "extra_field": "extra_value",
    }
    memory_instance.vector_store.get = Mock(
        return_value=Mock(id="test_id", payload=payload)
    )

    result = memory_instance.get("test_id")

    expected = {
        "id": "test_id",
        "memory": "Test memory",
        "user_id": "test_user",
        "hash": "test_hash",
        "created_at": "2023-01-01T00:00:00",
        "updated_at": "2023-01-02T00:00:00",
        # Non-core payload keys are collected under "metadata".
        "metadata": {"extra_field": "extra_value"},
    }
    for key, value in expected.items():
        assert result[key] == value
@pytest.mark.parametrize("version, enable_graph", [
    ("v1.0", False),
    ("v1.1", True)
])
def test_search(memory_instance, version, enable_graph):
    """search() should surface scored vector hits and, with graph enabled, relations."""
    memory_instance.config.version = version
    memory_instance.enable_graph = enable_graph
    hits = [
        Mock(id="1", payload={"data": "Memory 1", "user_id": "test_user"}, score=0.9),
        Mock(id="2", payload={"data": "Memory 2", "user_id": "test_user"}, score=0.8),
    ]
    memory_instance.vector_store.search = Mock(return_value=hits)
    memory_instance.embedding_model.embed = Mock(return_value=[0.1, 0.2, 0.3])
    memory_instance.graph.search = Mock(return_value=[{"relation": "test_relation"}])

    result = memory_instance.search("test query", user_id="test_user")

    def check_top_hit(res):
        # Shared checks: both output formats carry the vector hits under "results".
        assert "results" in res
        assert len(res["results"]) == 2
        first = res["results"][0]
        assert first["id"] == "1"
        assert first["memory"] == "Memory 1"
        assert first["user_id"] == "test_user"
        assert first["score"] == 0.9

    if version == "v1.1":
        check_top_hit(result)
        if enable_graph:
            assert "relations" in result
            assert result["relations"] == [{"relation": "test_relation"}]
        else:
            assert "relations" not in result
    else:
        assert isinstance(result, dict)
        check_top_hit(result)

    memory_instance.vector_store.search.assert_called_once_with(
        query=[0.1, 0.2, 0.3],
        limit=100,
        filters={"user_id": "test_user"},
    )
    memory_instance.embedding_model.embed.assert_called_once_with("test query")
    if enable_graph:
        memory_instance.graph.search.assert_called_once_with(
            "test query", {"user_id": "test_user"}
        )
    else:
        memory_instance.graph.search.assert_not_called()
def test_update(memory_instance):
    """update() should delegate to _update_memory and confirm success."""
    memory_instance._update_memory = Mock()

    response = memory_instance.update("test_id", "Updated memory")

    # The public API is a thin wrapper around the private helper.
    memory_instance._update_memory.assert_called_once_with("test_id", "Updated memory")
    assert response["message"] == "Memory updated successfully!"
def test_delete(memory_instance):
    """delete() should delegate to _delete_memory and confirm success."""
    memory_instance._delete_memory = Mock()

    response = memory_instance.delete("test_id")

    assert response["message"] == "Memory deleted successfully!"
    memory_instance._delete_memory.assert_called_once_with("test_id")
@pytest.mark.parametrize("version, enable_graph", [
    ("v1.0", False),
    ("v1.1", True)
])
def test_delete_all(memory_instance, version, enable_graph):
    """delete_all() should remove every listed memory; graph wipe only when enabled."""
    memory_instance.config.version = version
    memory_instance.enable_graph = enable_graph
    listed = [Mock(id="1"), Mock(id="2")]
    # vector_store.list returns a (memories, cursor) pair.
    memory_instance.vector_store.list = Mock(return_value=(listed, None))
    memory_instance._delete_memory = Mock()
    memory_instance.graph.delete_all = Mock()

    response = memory_instance.delete_all(user_id="test_user")

    assert response["message"] == "Memories deleted successfully!"
    # One _delete_memory call per listed memory.
    assert memory_instance._delete_memory.call_count == 2
    if enable_graph:
        memory_instance.graph.delete_all.assert_called_once_with({"user_id": "test_user"})
    else:
        memory_instance.graph.delete_all.assert_not_called()
def test_reset(memory_instance):
    """reset() should drop the vector collection and wipe the history DB."""
    memory_instance.vector_store.delete_col = Mock()
    memory_instance.db.reset = Mock()

    memory_instance.reset()

    # Both storage layers must be cleared exactly once.
    memory_instance.vector_store.delete_col.assert_called_once()
    memory_instance.db.reset.assert_called_once()
@pytest.mark.parametrize("version, enable_graph, expected_result", [
    ("v1.0", False, {"results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}]}),
    ("v1.1", False, {"results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}]}),
    ("v1.1", True, {
        "results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}],
        "relations": [{"source": "entity1", "relationship": "rel", "target": "entity2"}]
    })
])
def test_get_all(memory_instance, version, enable_graph, expected_result):
    """get_all() should return stored memories, plus relations when graph is on."""
    memory_instance.config.version = version
    memory_instance.enable_graph = enable_graph
    stored = [Mock(id="1", payload={"data": "Memory 1", "user_id": "test_user"})]
    # vector_store.list returns a (memories, cursor) pair.
    memory_instance.vector_store.list = Mock(return_value=(stored, None))
    memory_instance.graph.get_all = Mock(return_value=[
        {"source": "entity1", "relationship": "rel", "target": "entity2"}
    ])

    result = memory_instance.get_all(user_id="test_user")

    assert isinstance(result, dict)
    assert "results" in result
    expected_rows = expected_result["results"]
    rows = result["results"]
    assert len(rows) == len(expected_rows)
    for want, got in zip(expected_rows, rows):
        # Every expected key must be present, and core fields must match exactly.
        assert all(key in got for key in want)
        assert got["id"] == want["id"]
        assert got["memory"] == want["memory"]
        assert got["user_id"] == want["user_id"]

    if enable_graph:
        assert "relations" in result
        assert result["relations"] == expected_result["relations"]
    else:
        assert "relations" not in result

    memory_instance.vector_store.list.assert_called_once_with(
        filters={"user_id": "test_user"}, limit=100
    )
    if enable_graph:
        memory_instance.graph.get_all.assert_called_once_with({"user_id": "test_user"})
    else:
        memory_instance.graph.get_all.assert_not_called()