From b081e43b8de4c59091c86432e8993eda38bd9b4b Mon Sep 17 00:00:00 2001
From: Prateek Chhikara <46902268+prateekchhikara@users.noreply.github.com>
Date: Mon, 9 Sep 2024 10:04:11 -0700
Subject: [PATCH] Modified the return statement for ADD call | Added tests to
 main.py and graph_memory.py (#1812)

---
 README.md                   |   1 -
 mem0/graphs/utils.py        |  40 -------
 mem0/memory/graph_memory.py |  19 +++-
 mem0/memory/main.py         |  83 +++++++++++---
 tests/test_main.py          | 219 ++++++++++++++++++++++++++++++++++++
 5 files changed, 300 insertions(+), 62 deletions(-)
 create mode 100644 tests/test_main.py

diff --git a/README.md b/README.md
index b5f03a73..0364e1ec 100644
--- a/README.md
+++ b/README.md
@@ -34,7 +34,6 @@

- Graph Memory Integration New Feature: Introducing Graph Memory. Check out our documentation.

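
Note (illustrative, not part of the patch): the graph-memory changes below standardize the relation payload keys, renaming "relation"/"destination" to "relationship"/"target", so graph add and graph search now both yield entries shaped like:

    {"source": "alice", "relationship": "likes", "target": "pizza"}

The values here are hypothetical; real entries come from the LLM's entity extraction.
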
diff --git a/mem0/graphs/utils.py b/mem0/graphs/utils.py
index 4613952e..e9ed827e 100644
--- a/mem0/graphs/utils.py
+++ b/mem0/graphs/utils.py
@@ -1,4 +1,3 @@
-from mem0.llms.openai import OpenAILLM

 UPDATE_GRAPH_PROMPT = """
 You are an AI expert specializing in graph memory management and optimization. Your task is to analyze existing graph memories alongside new information, and update the relationships in the memory list to ensure the most accurate, current, and coherent representation of knowledge.
@@ -67,42 +66,3 @@ def get_update_memory_messages(existing_memories, memory):
             "content": get_update_memory_prompt(existing_memories, memory, UPDATE_GRAPH_PROMPT),
         },
     ]
-
-def get_search_results(entities, query):
-
-    search_graph_prompt = f"""
-You are an expert at searching through graph entity memories.
-When provided with existing graph entities and a query, your task is to search through the provided graph entities to find the most relevant information from the graph entities related to the query.
-The output should be from the graph entities only.
-
-Here are the details of the task:
-- Existing Graph Entities (source -> relationship -> target):
-{entities}
-
-- Query: {query}
-
-The output should be from the graph entities only.
-The output should be in the following JSON format:
-{{
-    "search_results": [
-        {{
-            "source_node": "source_node",
-            "relationship": "relationship",
-            "target_node": "target_node"
-        }}
-    ]
-}}
-"""
-
-    messages = [
-        {
-            "role": "user",
-            "content": search_graph_prompt,
-        }
-    ]
-
-    llm = OpenAILLM()
-
-    results = llm.generate_response(messages=messages, response_format={"type": "json_object"})
-
-    return results
diff --git a/mem0/memory/graph_memory.py b/mem0/memory/graph_memory.py
index d4f588bf..138c5119 100644
--- a/mem0/memory/graph_memory.py
+++ b/mem0/memory/graph_memory.py
@@ -101,6 +101,8 @@ class MemoryGraph:
             elif item['name'] == "noop":
                 continue

+        returned_entities = []
+
         for item in to_be_added:
             source = item['source'].lower().replace(" ", "_")
             source_type = item['source_type'].lower().replace(" ", "_")
@@ -108,6 +110,12 @@
             destination = item['destination'].lower().replace(" ", "_")
             destination_type = item['destination_type'].lower().replace(" ", "_")

+            returned_entities.append({
+                "source" : source,
+                "relationship" : relation,
+                "target" : destination
+            })
+
             # Create embeddings
             source_embedding = self.embedding_model.embed(source)
             dest_embedding = self.embedding_model.embed(destination)
@@ -137,6 +145,7 @@

         logger.info(f"Added {len(to_be_added)} new memories to the graph")

+        return returned_entities

     def _search(self, query, filters):
         _tools = [SEARCH_TOOL]
@@ -155,8 +164,10 @@

         for item in search_results['tool_calls']:
             if item['name'] == "search":
-                node_list.extend(item['arguments']['nodes'])
-                relation_list.extend(item['arguments']['relations'])
+                try:
+                    node_list.extend(item['arguments']['nodes'])
+                except Exception as e:
+                    logger.error(f"Error in search tool: {e}")

         node_list = list(set(node_list))
         relation_list = list(set(relation_list))
@@ -228,8 +239,8 @@
         for item in reranked_results:
             search_results.append({
                 "source": item[0],
-                "relation": item[1],
-                "destination": item[2]
+                "relationship": item[1],
+                "target": item[2]
             })

         logger.info(f"Returned {len(search_results)} search results")
diff --git a/mem0/memory/main.py b/mem0/memory/main.py
index f158ec82..a3bb5024 100644
--- a/mem0/memory/main.py
+++ b/mem0/memory/main.py
@@ -2,7 +2,6 @@ import concurrent
 import hashlib
 import json
 import logging
-import threading
 import uuid
 import warnings
 from datetime import datetime
@@ -11,14 +10,14 @@ from typing import Any, Dict
 import pytz
 from pydantic import ValidationError

-from mem0.configs.base import MemoryConfig, MemoryItem
 from mem0.configs.prompts import get_update_memory_messages
 from mem0.memory.base import MemoryBase
 from mem0.memory.setup import setup_config
 from mem0.memory.storage import SQLiteManager
 from mem0.memory.telemetry import capture_event
 from mem0.memory.utils import get_fact_retrieval_messages, parse_messages
-from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
+from mem0.utils.factory import LlmFactory, EmbedderFactory, VectorStoreFactory
+from mem0.configs.base import MemoryItem, MemoryConfig

 # Setup user config
 setup_config()
@@ -49,6 +48,7 @@ class Memory(MemoryBase):

         capture_event("mem0.init", self)

+
     @classmethod
     def from_config(cls, config_dict: Dict[str, Any]):
         try:
@@ -58,6 +58,7 @@ class Memory(MemoryBase):
             raise
         return cls(config)

+
     def add(
         self,
         messages,
@@ -81,7 +82,7 @@
             prompt (str, optional): Prompt to use for memory deduction. Defaults to None.

         Returns:
-            dict: Memory addition operation message.
+            dict: A dictionary containing the result of the memory addition operation.
         """
         if metadata is None:
             metadata = {}
@@ -102,17 +103,31 @@
         if isinstance(messages, str):
             messages = [{"role": "user", "content": messages}]

-        thread1 = threading.Thread(target=self._add_to_vector_store, args=(messages, metadata, filters))
-        thread2 = threading.Thread(target=self._add_to_graph, args=(messages, filters))
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            future1 = executor.submit(self._add_to_vector_store, messages, metadata, filters)
+            future2 = executor.submit(self._add_to_graph, messages, filters)

-        thread1.start()
-        thread2.start()
+            concurrent.futures.wait([future1, future2])

-        thread1.join()
-        thread2.join()
+            vector_store_result = future1.result()
+            graph_result = future2.result()
+
+        if self.version == "v1.1":
+            return {
+                "results" : vector_store_result,
+                "relations" : graph_result,
+            }
+        else:
+            warnings.warn(
+                "The current add API output format is deprecated. "
+                "To use the latest format, set `api_version='v1.1'`. "
" + "The current format will be removed in mem0ai 1.1.0 and later versions.", + category=DeprecationWarning, + stacklevel=2 + ) + return {"message": "ok"} + - return {"message": "ok"} - def _add_to_vector_store(self, messages, metadata, filters): parsed_messages = parse_messages(messages) @@ -151,16 +166,30 @@ class Memory(MemoryBase): ) new_memories_with_actions = json.loads(new_memories_with_actions) + returned_memories = [] try: for resp in new_memories_with_actions["memory"]: logging.info(resp) try: if resp["event"] == "ADD": - self._create_memory(data=resp["text"], metadata=metadata) + memory_id = self._create_memory(data=resp["text"], metadata=metadata) + returned_memories.append({ + "memory" : resp["text"], + "event" : resp["event"], + }) elif resp["event"] == "UPDATE": self._update_memory(memory_id=resp["id"], data=resp["text"], metadata=metadata) + returned_memories.append({ + "memory" : resp["text"], + "event" : resp["event"], + "previous_memory" : resp["old_memory"], + }) elif resp["event"] == "DELETE": self._delete_memory(memory_id=resp["id"]) + returned_memories.append({ + "memory" : resp["text"], + "event" : resp["event"], + }) elif resp["event"] == "NONE": logging.info("NOOP for Memory.") except Exception as e: @@ -170,7 +199,11 @@ class Memory(MemoryBase): capture_event("mem0.add", self) + return returned_memories + + def _add_to_graph(self, messages, filters): + added_entities = [] if self.version == "v1.1" and self.enable_graph: if filters["user_id"]: self.graph.user_id = filters["user_id"] @@ -179,6 +212,9 @@ class Memory(MemoryBase): data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"]) self.graph.add(data, filters) + return added_entities + + def get(self, memory_id): """ Retrieve a memory by ID. @@ -229,6 +265,7 @@ class Memory(MemoryBase): return result + def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100): """ List all memories. @@ -255,9 +292,9 @@ class Memory(MemoryBase): if self.version == "v1.1": if self.enable_graph: - return {"memories": all_memories, "entities": graph_entities} + return {"results": all_memories, "relations": graph_entities} else: - return {"memories": all_memories} + return {"results": all_memories} else: warnings.warn( "The current get_all API output format is deprecated. " @@ -268,6 +305,7 @@ class Memory(MemoryBase): ) return all_memories + def _get_all_from_vector_store(self, filters, limit): memories = self.vector_store.list(filters=filters, limit=limit) @@ -302,6 +340,7 @@ class Memory(MemoryBase): ] return all_memories + def search( self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None ): @@ -343,9 +382,9 @@ class Memory(MemoryBase): if self.version == "v1.1": if self.enable_graph: - return {"memories": original_memories, "entities": graph_entities} + return {"results": original_memories, "relations": graph_entities} else: - return {"memories" : original_memories} + return {"results" : original_memories} else: warnings.warn( "The current get_all API output format is deprecated. " @@ -356,6 +395,7 @@ class Memory(MemoryBase): ) return original_memories + def _search_vector_store(self, query, filters, limit): embeddings = self.embedding_model.embed(query) memories = self.vector_store.search( @@ -404,6 +444,7 @@ class Memory(MemoryBase): return original_memories + def update(self, memory_id, data): """ Update a memory by ID. 
@@ -419,6 +460,7 @@
         self._update_memory(memory_id, data)
         return {"message": "Memory updated successfully!"}

+
     def delete(self, memory_id):
         """
         Delete a memory by ID.
@@ -430,6 +472,7 @@
         self._delete_memory(memory_id)
         return {"message": "Memory deleted successfully!"}

+
     def delete_all(self, user_id=None, agent_id=None, run_id=None):
         """
         Delete all memories.
@@ -464,6 +507,7 @@

         return {'message': 'Memories deleted successfully!'}

+
     def history(self, memory_id):
         """
         Get the history of changes for a memory by ID.
@@ -477,6 +521,7 @@
         capture_event("mem0.history", self, {"memory_id": memory_id})
         return self.db.get_history(memory_id)

+
     def _create_memory(self, data, metadata=None):
         logging.info(f"Creating memory with {data=}")
         embeddings = self.embedding_model.embed(data)
@@ -496,6 +541,7 @@
         )
         return memory_id

+
     def _update_memory(self, memory_id, data, metadata=None):
         logger.info(f"Updating memory with {data=}")
         existing_memory = self.vector_store.get(vector_id=memory_id)
@@ -532,6 +578,7 @@
             updated_at=new_metadata["updated_at"],
         )

+
     def _delete_memory(self, memory_id):
         logging.info(f"Deleting memory with {memory_id=}")
         existing_memory = self.vector_store.get(vector_id=memory_id)
@@ -539,6 +586,7 @@
         self.vector_store.delete(vector_id=memory_id)
         self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1)

+
     def reset(self):
         """
         Reset the memory store.
@@ -548,5 +596,6 @@
         self.db.reset()
         capture_event("mem0.reset", self)

+
     def chat(self, query):
         raise NotImplementedError("Chat function not implemented yet.")
diff --git a/tests/test_main.py b/tests/test_main.py
new file mode 100644
index 00000000..16a672e3
--- /dev/null
+++ b/tests/test_main.py
@@ -0,0 +1,219 @@
+import pytest
+import os
+from unittest.mock import Mock, patch
+from mem0.memory.main import Memory
+from mem0.configs.base import MemoryConfig
+
+@pytest.fixture(autouse=True)
+def mock_openai():
+    os.environ['OPENAI_API_KEY'] = "123"
+    with patch('openai.OpenAI') as mock:
+        mock.return_value = Mock()
+        yield mock
+
+@pytest.fixture
+def memory_instance():
+    with patch('mem0.utils.factory.EmbedderFactory') as mock_embedder, \
+         patch('mem0.utils.factory.VectorStoreFactory') as mock_vector_store, \
+         patch('mem0.utils.factory.LlmFactory') as mock_llm, \
+         patch('mem0.memory.telemetry.capture_event'), \
+         patch('mem0.memory.graph_memory.MemoryGraph'):
+        mock_embedder.create.return_value = Mock()
+        mock_vector_store.create.return_value = Mock()
+        mock_llm.create.return_value = Mock()
+
+        config = MemoryConfig(version="v1.1")
+        config.graph_store.config = {"some_config": "value"}
+        return Memory(config)
+
+@pytest.mark.parametrize("version, enable_graph", [
+    ("v1.0", False),
+    ("v1.1", True)
+])
+def test_add(memory_instance, version, enable_graph):
+    memory_instance.config.version = version
+    memory_instance.enable_graph = enable_graph
+    memory_instance._add_to_vector_store = Mock(return_value=[{"memory": "Test memory", "event": "ADD"}])
+    memory_instance._add_to_graph = Mock(return_value=[])
+
+    result = memory_instance.add(
+        messages=[{"role": "user", "content": "Test message"}],
+        user_id="test_user"
+    )
+
+    assert "results" in result
+    assert result["results"] == [{"memory": "Test memory", "event": "ADD"}]
+    assert "relations" in result
+    assert result["relations"] == []
+
+    memory_instance._add_to_vector_store.assert_called_once_with(
+        [{"role": "user", "content": "Test message"}],
+        {"user_id": "test_user"},
+        {"user_id": "test_user"}
+    )
+
+    # _add_to_graph is always invoked by add(), regardless of version or graph settings
+    memory_instance._add_to_graph.assert_called_once_with(
+        [{"role": "user", "content": "Test message"}],
+        {"user_id": "test_user"}
+    )
+
+def test_get(memory_instance):
+    mock_memory = Mock(id="test_id", payload={
+        "data": "Test memory",
+        "user_id": "test_user",
+        "hash": "test_hash",
+        "created_at": "2023-01-01T00:00:00",
+        "updated_at": "2023-01-02T00:00:00",
+        "extra_field": "extra_value"
+    })
+    memory_instance.vector_store.get = Mock(return_value=mock_memory)
+
+    result = memory_instance.get("test_id")
+
+    assert result["id"] == "test_id"
+    assert result["memory"] == "Test memory"
+    assert result["user_id"] == "test_user"
+    assert result["hash"] == "test_hash"
+    assert result["created_at"] == "2023-01-01T00:00:00"
+    assert result["updated_at"] == "2023-01-02T00:00:00"
+    assert result["metadata"] == {"extra_field": "extra_value"}
+
+@pytest.mark.parametrize("version, enable_graph", [
+    ("v1.0", False),
+    ("v1.1", True)
+])
+def test_search(memory_instance, version, enable_graph):
+    memory_instance.config.version = version
+    memory_instance.enable_graph = enable_graph
+    mock_memories = [
+        Mock(id="1", payload={"data": "Memory 1", "user_id": "test_user"}, score=0.9),
+        Mock(id="2", payload={"data": "Memory 2", "user_id": "test_user"}, score=0.8)
+    ]
+    memory_instance.vector_store.search = Mock(return_value=mock_memories)
+    memory_instance.embedding_model.embed = Mock(return_value=[0.1, 0.2, 0.3])
+    memory_instance.graph.search = Mock(return_value=[{"relation": "test_relation"}])
+
+    result = memory_instance.search("test query", user_id="test_user")
+
+    if version == "v1.1":
+        assert "results" in result
+        assert len(result["results"]) == 2
+        assert result["results"][0]["id"] == "1"
+        assert result["results"][0]["memory"] == "Memory 1"
+        assert result["results"][0]["user_id"] == "test_user"
+        assert result["results"][0]["score"] == 0.9
+        if enable_graph:
+            assert "relations" in result
+            assert result["relations"] == [{"relation": "test_relation"}]
+        else:
+            assert "relations" not in result
+    else:
+        assert isinstance(result, dict)
+        assert "results" in result
+        assert len(result["results"]) == 2
+        assert result["results"][0]["id"] == "1"
+        assert result["results"][0]["memory"] == "Memory 1"
+        assert result["results"][0]["user_id"] == "test_user"
+        assert result["results"][0]["score"] == 0.9
+
+    memory_instance.vector_store.search.assert_called_once_with(
+        query=[0.1, 0.2, 0.3],
+        limit=100,
+        filters={"user_id": "test_user"}
+    )
+    memory_instance.embedding_model.embed.assert_called_once_with("test query")
+
+    if enable_graph:
+        memory_instance.graph.search.assert_called_once_with("test query", {"user_id": "test_user"})
+    else:
+        memory_instance.graph.search.assert_not_called()
+
+def test_update(memory_instance):
+    memory_instance._update_memory = Mock()
+
+    result = memory_instance.update("test_id", "Updated memory")
+
+    memory_instance._update_memory.assert_called_once_with("test_id", "Updated memory")
+    assert result["message"] == "Memory updated successfully!"
+
+def test_delete(memory_instance):
+    memory_instance._delete_memory = Mock()
+
+    result = memory_instance.delete("test_id")
+
+    memory_instance._delete_memory.assert_called_once_with("test_id")
+    assert result["message"] == "Memory deleted successfully!"
+
+@pytest.mark.parametrize("version, enable_graph", [
+    ("v1.0", False),
+    ("v1.1", True)
+])
+def test_delete_all(memory_instance, version, enable_graph):
+    memory_instance.config.version = version
+    memory_instance.enable_graph = enable_graph
+    mock_memories = [Mock(id="1"), Mock(id="2")]
+    memory_instance.vector_store.list = Mock(return_value=(mock_memories, None))
+    memory_instance._delete_memory = Mock()
+    memory_instance.graph.delete_all = Mock()
+
+    result = memory_instance.delete_all(user_id="test_user")
+
+    assert memory_instance._delete_memory.call_count == 2
+
+    if enable_graph:
+        memory_instance.graph.delete_all.assert_called_once_with({"user_id": "test_user"})
+    else:
+        memory_instance.graph.delete_all.assert_not_called()
+
+    assert result["message"] == "Memories deleted successfully!"
+
+def test_reset(memory_instance):
+    memory_instance.vector_store.delete_col = Mock()
+    memory_instance.db.reset = Mock()
+
+    memory_instance.reset()
+
+    memory_instance.vector_store.delete_col.assert_called_once()
+    memory_instance.db.reset.assert_called_once()
+
+@pytest.mark.parametrize("version, enable_graph, expected_result", [
+    ("v1.0", False, {"results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}]}),
+    ("v1.1", False, {"results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}]}),
+    ("v1.1", True, {
+        "results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}],
+        "relations": [{"source": "entity1", "relationship": "rel", "target": "entity2"}]
+    })
+])
+def test_get_all(memory_instance, version, enable_graph, expected_result):
+    memory_instance.config.version = version
+    memory_instance.enable_graph = enable_graph
+    mock_memories = [Mock(id="1", payload={"data": "Memory 1", "user_id": "test_user"})]
+    memory_instance.vector_store.list = Mock(return_value=(mock_memories, None))
+    memory_instance.graph.get_all = Mock(return_value=[
+        {"source": "entity1", "relationship": "rel", "target": "entity2"}
+    ])
+
+    result = memory_instance.get_all(user_id="test_user")
+
+    assert isinstance(result, dict)
+    assert "results" in result
+    assert len(result["results"]) == len(expected_result["results"])
+    for expected_item, result_item in zip(expected_result["results"], result["results"]):
+        assert all(key in result_item for key in expected_item)
+        assert result_item["id"] == expected_item["id"]
+        assert result_item["memory"] == expected_item["memory"]
+        assert result_item["user_id"] == expected_item["user_id"]
+
+    if enable_graph:
+        assert "relations" in result
+        assert result["relations"] == expected_result["relations"]
+    else:
+        assert "relations" not in result
+
+    memory_instance.vector_store.list.assert_called_once_with(filters={"user_id": "test_user"}, limit=100)
+
+    if enable_graph:
+        memory_instance.graph.get_all.assert_called_once_with({"user_id": "test_user"})
+    else:
+        memory_instance.graph.get_all.assert_not_called()
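
Note (illustrative, not part of the patch): a minimal sketch of the new v1.1 add/search output after this change, mirroring the setup used in tests/test_main.py. The memory text and IDs are hypothetical; actual values depend on the configured LLM's fact extraction.

    from mem0.configs.base import MemoryConfig
    from mem0.memory.main import Memory

    m = Memory(MemoryConfig(version="v1.1"))

    # add() now returns the operations it performed instead of {"message": "ok"}
    result = m.add("I love pizza", user_id="alice")
    # result == {
    #     "results": [{"memory": "Loves pizza", "event": "ADD"}],
    #     "relations": [],  # populated only when graph memory is enabled
    # }

    # search() and get_all() rename the "memories"/"entities" keys to "results"/"relations"
    hits = m.search("food", user_id="alice")
    # hits == {"results": [{"id": "...", "memory": "Loves pizza", "score": 0.9, ...}]}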