Modified the return statement for the ADD call | Added tests for main.py and graph_memory.py (#1812)
@@ -34,7 +34,6 @@
 
 <!-- Start of Selection -->
 <p style="display: flex;">
-<img src="https://media.tenor.com/K3j9pwWlME0AAAAi/fire-flame.gif" alt="Graph Memory Integration" style="width: 25px; margin-right: 10px;" />
 <span style="font-size: 1.2em;">New Feature: Introducing Graph Memory. Check out our <a href="https://docs.mem0.ai/open-source/graph-memory" target="_blank">documentation</a>.</span>
 </p>
 <!-- End of Selection -->
@@ -1,4 +1,3 @@
-from mem0.llms.openai import OpenAILLM
 
 UPDATE_GRAPH_PROMPT = """
 You are an AI expert specializing in graph memory management and optimization. Your task is to analyze existing graph memories alongside new information, and update the relationships in the memory list to ensure the most accurate, current, and coherent representation of knowledge.
@@ -67,42 +66,3 @@ def get_update_memory_messages(existing_memories, memory):
             "content": get_update_memory_prompt(existing_memories, memory, UPDATE_GRAPH_PROMPT),
         },
     ]
-
-def get_search_results(entities, query):
-
-    search_graph_prompt = f"""
-You are an expert at searching through graph entity memories.
-When provided with existing graph entities and a query, your task is to search through the provided graph entities to find the most relevant information from the graph entities related to the query.
-The output should be from the graph entities only.
-
-Here are the details of the task:
-- Existing Graph Entities (source -> relationship -> target):
-{entities}
-
-- Query: {query}
-
-The output should be from the graph entities only.
-The output should be in the following JSON format:
-{{
-    "search_results": [
-        {{
-            "source_node": "source_node",
-            "relationship": "relationship",
-            "target_node": "target_node"
-        }}
-    ]
-}}
-    """
-
-    messages = [
-        {
-            "role": "user",
-            "content": search_graph_prompt,
-        }
-    ]
-
-    llm = OpenAILLM()
-
-    results = llm.generate_response(messages=messages, response_format={"type": "json_object"})
-
-    return results
@@ -101,6 +101,8 @@ class MemoryGraph:
             elif item['name'] == "noop":
                 continue
 
+        returned_entities = []
+
         for item in to_be_added:
             source = item['source'].lower().replace(" ", "_")
             source_type = item['source_type'].lower().replace(" ", "_")
@@ -108,6 +110,12 @@ class MemoryGraph:
             destination = item['destination'].lower().replace(" ", "_")
             destination_type = item['destination_type'].lower().replace(" ", "_")
 
+            returned_entities.append({
+                "source" : source,
+                "relationship" : relation,
+                "target" : destination
+            })
+
             # Create embeddings
             source_embedding = self.embedding_model.embed(source)
             dest_embedding = self.embedding_model.embed(destination)
@@ -137,6 +145,7 @@ class MemoryGraph:
 
         logger.info(f"Added {len(to_be_added)} new memories to the graph")
 
+        return returned_entities
 
     def _search(self, query, filters):
         _tools = [SEARCH_TOOL]
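Taken together, these three hunks make the graph add path collect the triplets it writes and hand them back to the caller. A minimal sketch of the returned value — the entities here are hypothetical, not output from a real run:

    # Hypothetical value returned by MemoryGraph.add after this change:
    [
        {"source": "alice", "relationship": "likes", "target": "pizza"},
        {"source": "alice", "relationship": "lives_in", "target": "new_york"},
    ]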
@@ -155,8 +164,10 @@ class MemoryGraph:
 
         for item in search_results['tool_calls']:
             if item['name'] == "search":
-                node_list.extend(item['arguments']['nodes'])
-                relation_list.extend(item['arguments']['relations'])
+                try:
+                    node_list.extend(item['arguments']['nodes'])
+                except Exception as e:
+                    logger.error(f"Error in search tool: {e}")
 
         node_list = list(set(node_list))
         relation_list = list(set(relation_list))
@@ -228,8 +239,8 @@ class MemoryGraph:
         for item in reranked_results:
             search_results.append({
                 "source": item[0],
-                "relation": item[1],
-                "destination": item[2]
+                "relationship": item[1],
+                "target": item[2]
             })
 
         logger.info(f"Returned {len(search_results)} search results")
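After this rename the search path emits the same source/relationship/target keys as the add path above, so v1.1 callers see one consistent triplet shape. A hypothetical element of search_results, assuming the reranker yielded the triple ("alice", "likes", "pizza"):

    # Hypothetical search result entry after the rename:
    {"source": "alice", "relationship": "likes", "target": "pizza"}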
@@ -2,7 +2,6 @@ import concurrent
 import hashlib
 import json
 import logging
-import threading
 import uuid
 import warnings
 from datetime import datetime
@@ -11,14 +10,14 @@ from typing import Any, Dict
 import pytz
 from pydantic import ValidationError
 
-from mem0.configs.base import MemoryConfig, MemoryItem
 from mem0.configs.prompts import get_update_memory_messages
 from mem0.memory.base import MemoryBase
 from mem0.memory.setup import setup_config
 from mem0.memory.storage import SQLiteManager
 from mem0.memory.telemetry import capture_event
 from mem0.memory.utils import get_fact_retrieval_messages, parse_messages
-from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
+from mem0.utils.factory import LlmFactory, EmbedderFactory, VectorStoreFactory
+from mem0.configs.base import MemoryItem, MemoryConfig
 
 # Setup user config
 setup_config()
@@ -49,6 +48,7 @@ class Memory(MemoryBase):
 
         capture_event("mem0.init", self)
 
+
     @classmethod
     def from_config(cls, config_dict: Dict[str, Any]):
         try:
@@ -58,6 +58,7 @@ class Memory(MemoryBase):
             raise
         return cls(config)
 
+
     def add(
         self,
         messages,
@@ -81,7 +82,7 @@ class Memory(MemoryBase):
             prompt (str, optional): Prompt to use for memory deduction. Defaults to None.
 
         Returns:
-            dict: Memory addition operation message.
+            dict: A dictionary containing the result of the memory addition operation.
         """
         if metadata is None:
             metadata = {}
@@ -102,17 +103,31 @@ class Memory(MemoryBase):
         if isinstance(messages, str):
             messages = [{"role": "user", "content": messages}]
 
-        thread1 = threading.Thread(target=self._add_to_vector_store, args=(messages, metadata, filters))
-        thread2 = threading.Thread(target=self._add_to_graph, args=(messages, filters))
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            future1 = executor.submit(self._add_to_vector_store, messages, metadata, filters)
+            future2 = executor.submit(self._add_to_graph, messages, filters)
 
-        thread1.start()
-        thread2.start()
+            concurrent.futures.wait([future1, future2])
 
-        thread1.join()
-        thread2.join()
+            vector_store_result = future1.result()
+            graph_result = future2.result()
 
+        if self.version == "v1.1":
+            return {
+                "results" : vector_store_result,
+                "relations" : graph_result,
+            }
+        else:
+            warnings.warn(
+                "The current add API output format is deprecated. "
+                "To use the latest format, set `api_version='v1.1'`. "
+                "The current format will be removed in mem0ai 1.1.0 and later versions.",
+                category=DeprecationWarning,
+                stacklevel=2
+            )
+            return {"message": "ok"}
 
-        return {"message": "ok"}
 
     def _add_to_vector_store(self, messages, metadata, filters):
         parsed_messages = parse_messages(messages)
 
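Futures, unlike the bare threads they replace, propagate return values (and exceptions) back to the caller, which is what lets add() surface both the vector-store and graph results. A usage sketch of the new v1.1 return shape — the config dict keys and all values are illustrative assumptions, not output from a real run:

    from mem0 import Memory  # top-level import path assumed; adjust to your install

    m = Memory.from_config({"version": "v1.1"})
    result = m.add("I like pizza", user_id="alice")
    # Expected shape under v1.1 (values hypothetical):
    # {
    #     "results": [{"memory": "Likes pizza", "event": "ADD"}],
    #     "relations": [{"source": "alice", "relationship": "likes", "target": "pizza"}],
    # }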
@@ -151,16 +166,30 @@ class Memory(MemoryBase):
         )
         new_memories_with_actions = json.loads(new_memories_with_actions)
 
+        returned_memories = []
         try:
             for resp in new_memories_with_actions["memory"]:
                 logging.info(resp)
                 try:
                     if resp["event"] == "ADD":
-                        self._create_memory(data=resp["text"], metadata=metadata)
+                        memory_id = self._create_memory(data=resp["text"], metadata=metadata)
+                        returned_memories.append({
+                            "memory" : resp["text"],
+                            "event" : resp["event"],
+                        })
                     elif resp["event"] == "UPDATE":
                         self._update_memory(memory_id=resp["id"], data=resp["text"], metadata=metadata)
+                        returned_memories.append({
+                            "memory" : resp["text"],
+                            "event" : resp["event"],
+                            "previous_memory" : resp["old_memory"],
+                        })
                     elif resp["event"] == "DELETE":
                         self._delete_memory(memory_id=resp["id"])
+                        returned_memories.append({
+                            "memory" : resp["text"],
+                            "event" : resp["event"],
+                        })
                     elif resp["event"] == "NONE":
                         logging.info("NOOP for Memory.")
                 except Exception as e:
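For reference, the entry shapes collected above, one per event type (field values are placeholders):

    # ADD    -> {"memory": "<text>", "event": "ADD"}
    # UPDATE -> {"memory": "<text>", "event": "UPDATE", "previous_memory": "<old text>"}
    # DELETE -> {"memory": "<text>", "event": "DELETE"}
    # NONE   -> nothing is appended; the memory is left untouched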
@@ -170,7 +199,11 @@ class Memory(MemoryBase):
 
         capture_event("mem0.add", self)
 
+        return returned_memories
+
+
     def _add_to_graph(self, messages, filters):
+        added_entities = []
         if self.version == "v1.1" and self.enable_graph:
             if filters["user_id"]:
                 self.graph.user_id = filters["user_id"]
@@ -179,6 +212,9 @@ class Memory(MemoryBase):
             data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"])
             self.graph.add(data, filters)
 
+        return added_entities
+
+
     def get(self, memory_id):
         """
         Retrieve a memory by ID.
@@ -229,6 +265,7 @@ class Memory(MemoryBase):
 
         return result
 
+
     def get_all(self, user_id=None, agent_id=None, run_id=None, limit=100):
         """
         List all memories.
@@ -255,9 +292,9 @@ class Memory(MemoryBase):
 
         if self.version == "v1.1":
             if self.enable_graph:
-                return {"memories": all_memories, "entities": graph_entities}
+                return {"results": all_memories, "relations": graph_entities}
             else:
-                return {"memories": all_memories}
+                return {"results": all_memories}
         else:
             warnings.warn(
                 "The current get_all API output format is deprecated. "
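A consumption sketch for the renamed keys, assuming a hypothetical v1.1 instance m with the graph store enabled; search() receives the identical rename further down, so the same consumer code works for both calls:

    out = m.get_all(user_id="alice")
    for memory in out["results"]:             # vector-store memories
        print(memory)
    for triplet in out.get("relations", []):  # graph triplets, present only when the graph is enabled
        print(triplet)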
@@ -268,6 +305,7 @@ class Memory(MemoryBase):
             )
             return all_memories
 
+
     def _get_all_from_vector_store(self, filters, limit):
         memories = self.vector_store.list(filters=filters, limit=limit)
 
@@ -302,6 +340,7 @@ class Memory(MemoryBase):
         ]
         return all_memories
 
+
     def search(
         self, query, user_id=None, agent_id=None, run_id=None, limit=100, filters=None
     ):
@@ -343,9 +382,9 @@ class Memory(MemoryBase):
 
         if self.version == "v1.1":
             if self.enable_graph:
-                return {"memories": original_memories, "entities": graph_entities}
+                return {"results": original_memories, "relations": graph_entities}
             else:
-                return {"memories" : original_memories}
+                return {"results" : original_memories}
         else:
             warnings.warn(
                 "The current get_all API output format is deprecated. "
@@ -356,6 +395,7 @@ class Memory(MemoryBase):
             )
             return original_memories
 
+
     def _search_vector_store(self, query, filters, limit):
         embeddings = self.embedding_model.embed(query)
         memories = self.vector_store.search(
@@ -404,6 +444,7 @@ class Memory(MemoryBase):
 
         return original_memories
 
+
     def update(self, memory_id, data):
         """
         Update a memory by ID.
@@ -419,6 +460,7 @@ class Memory(MemoryBase):
         self._update_memory(memory_id, data)
         return {"message": "Memory updated successfully!"}
 
+
     def delete(self, memory_id):
         """
         Delete a memory by ID.
@@ -430,6 +472,7 @@ class Memory(MemoryBase):
         self._delete_memory(memory_id)
         return {"message": "Memory deleted successfully!"}
 
+
     def delete_all(self, user_id=None, agent_id=None, run_id=None):
         """
         Delete all memories.
@@ -464,6 +507,7 @@ class Memory(MemoryBase):
 
         return {'message': 'Memories deleted successfully!'}
 
+
     def history(self, memory_id):
         """
         Get the history of changes for a memory by ID.
@@ -477,6 +521,7 @@ class Memory(MemoryBase):
         capture_event("mem0.history", self, {"memory_id": memory_id})
         return self.db.get_history(memory_id)
 
+
     def _create_memory(self, data, metadata=None):
         logging.info(f"Creating memory with {data=}")
         embeddings = self.embedding_model.embed(data)
@@ -496,6 +541,7 @@ class Memory(MemoryBase):
         )
         return memory_id
 
+
     def _update_memory(self, memory_id, data, metadata=None):
         logger.info(f"Updating memory with {data=}")
         existing_memory = self.vector_store.get(vector_id=memory_id)
@@ -532,6 +578,7 @@ class Memory(MemoryBase):
             updated_at=new_metadata["updated_at"],
         )
 
+
     def _delete_memory(self, memory_id):
         logging.info(f"Deleting memory with {memory_id=}")
         existing_memory = self.vector_store.get(vector_id=memory_id)
@@ -539,6 +586,7 @@ class Memory(MemoryBase):
         self.vector_store.delete(vector_id=memory_id)
         self.db.add_history(memory_id, prev_value, None, "DELETE", is_deleted=1)
 
+
     def reset(self):
         """
         Reset the memory store.
@@ -548,5 +596,6 @@ class Memory(MemoryBase):
         self.db.reset()
         capture_event("mem0.reset", self)
 
+
     def chat(self, query):
         raise NotImplementedError("Chat function not implemented yet.")
tests/test_main.py (new file, 219 lines)
@@ -0,0 +1,219 @@
+import pytest
+import os
+from unittest.mock import Mock, patch
+from mem0.memory.main import Memory
+from mem0.configs.base import MemoryConfig
+
+@pytest.fixture(autouse=True)
+def mock_openai():
+    os.environ['OPENAI_API_KEY'] = "123"
+    with patch('openai.OpenAI') as mock:
+        mock.return_value = Mock()
+        yield mock
+
+@pytest.fixture
+def memory_instance():
+    with patch('mem0.utils.factory.EmbedderFactory') as mock_embedder, \
+         patch('mem0.utils.factory.VectorStoreFactory') as mock_vector_store, \
+         patch('mem0.utils.factory.LlmFactory') as mock_llm, \
+         patch('mem0.memory.telemetry.capture_event'), \
+         patch('mem0.memory.graph_memory.MemoryGraph'):
+        mock_embedder.create.return_value = Mock()
+        mock_vector_store.create.return_value = Mock()
+        mock_llm.create.return_value = Mock()
+
+        config = MemoryConfig(version="v1.1")
+        config.graph_store.config = {"some_config": "value"}
+        return Memory(config)
+
+@pytest.mark.parametrize("version, enable_graph", [
+    ("v1.0", False),
+    ("v1.1", True)
+])
+def test_add(memory_instance, version, enable_graph):
+    memory_instance.config.version = version
+    memory_instance.enable_graph = enable_graph
+    memory_instance._add_to_vector_store = Mock(return_value=[{"memory": "Test memory", "event": "ADD"}])
+    memory_instance._add_to_graph = Mock(return_value=[])
+
+    result = memory_instance.add(
+        messages=[{"role": "user", "content": "Test message"}],
+        user_id="test_user"
+    )
+
+    assert "results" in result
+    assert result["results"] == [{"memory": "Test memory", "event": "ADD"}]
+    assert "relations" in result
+    assert result["relations"] == []
+
+    memory_instance._add_to_vector_store.assert_called_once_with(
+        [{"role": "user", "content": "Test message"}],
+        {"user_id": "test_user"},
+        {"user_id": "test_user"}
+    )
+
+    # Remove the conditional assertion for _add_to_graph
+    memory_instance._add_to_graph.assert_called_once_with(
+        [{"role": "user", "content": "Test message"}],
+        {"user_id": "test_user"}
+    )
+
+def test_get(memory_instance):
+    mock_memory = Mock(id="test_id", payload={
+        "data": "Test memory",
+        "user_id": "test_user",
+        "hash": "test_hash",
+        "created_at": "2023-01-01T00:00:00",
+        "updated_at": "2023-01-02T00:00:00",
+        "extra_field": "extra_value"
+    })
+    memory_instance.vector_store.get = Mock(return_value=mock_memory)
+
+    result = memory_instance.get("test_id")
+
+    assert result["id"] == "test_id"
+    assert result["memory"] == "Test memory"
+    assert result["user_id"] == "test_user"
+    assert result["hash"] == "test_hash"
+    assert result["created_at"] == "2023-01-01T00:00:00"
+    assert result["updated_at"] == "2023-01-02T00:00:00"
+    assert result["metadata"] == {"extra_field": "extra_value"}
+
+@pytest.mark.parametrize("version, enable_graph", [
+    ("v1.0", False),
+    ("v1.1", True)
+])
+def test_search(memory_instance, version, enable_graph):
+    memory_instance.config.version = version
+    memory_instance.enable_graph = enable_graph
+    mock_memories = [
+        Mock(id="1", payload={"data": "Memory 1", "user_id": "test_user"}, score=0.9),
+        Mock(id="2", payload={"data": "Memory 2", "user_id": "test_user"}, score=0.8)
+    ]
+    memory_instance.vector_store.search = Mock(return_value=mock_memories)
+    memory_instance.embedding_model.embed = Mock(return_value=[0.1, 0.2, 0.3])
+    memory_instance.graph.search = Mock(return_value=[{"relation": "test_relation"}])
+
+    result = memory_instance.search("test query", user_id="test_user")
+
+    if version == "v1.1":
+        assert "results" in result
+        assert len(result["results"]) == 2
+        assert result["results"][0]["id"] == "1"
+        assert result["results"][0]["memory"] == "Memory 1"
+        assert result["results"][0]["user_id"] == "test_user"
+        assert result["results"][0]["score"] == 0.9
+        if enable_graph:
+            assert "relations" in result
+            assert result["relations"] == [{"relation": "test_relation"}]
+        else:
+            assert "relations" not in result
+    else:
+        assert isinstance(result, dict)
+        assert "results" in result
+        assert len(result["results"]) == 2
+        assert result["results"][0]["id"] == "1"
+        assert result["results"][0]["memory"] == "Memory 1"
+        assert result["results"][0]["user_id"] == "test_user"
+        assert result["results"][0]["score"] == 0.9
+
+    memory_instance.vector_store.search.assert_called_once_with(
+        query=[0.1, 0.2, 0.3],
+        limit=100,
+        filters={"user_id": "test_user"}
+    )
+    memory_instance.embedding_model.embed.assert_called_once_with("test query")
+
+    if enable_graph:
+        memory_instance.graph.search.assert_called_once_with("test query", {"user_id": "test_user"})
+    else:
+        memory_instance.graph.search.assert_not_called()
+
+def test_update(memory_instance):
+    memory_instance._update_memory = Mock()
+
+    result = memory_instance.update("test_id", "Updated memory")
+
+    memory_instance._update_memory.assert_called_once_with("test_id", "Updated memory")
+    assert result["message"] == "Memory updated successfully!"
+
+def test_delete(memory_instance):
+    memory_instance._delete_memory = Mock()
+
+    result = memory_instance.delete("test_id")
+
+    memory_instance._delete_memory.assert_called_once_with("test_id")
+    assert result["message"] == "Memory deleted successfully!"
+
+@pytest.mark.parametrize("version, enable_graph", [
+    ("v1.0", False),
+    ("v1.1", True)
+])
+def test_delete_all(memory_instance, version, enable_graph):
+    memory_instance.config.version = version
+    memory_instance.enable_graph = enable_graph
+    mock_memories = [Mock(id="1"), Mock(id="2")]
+    memory_instance.vector_store.list = Mock(return_value=(mock_memories, None))
+    memory_instance._delete_memory = Mock()
+    memory_instance.graph.delete_all = Mock()
+
+    result = memory_instance.delete_all(user_id="test_user")
+
+    assert memory_instance._delete_memory.call_count == 2
+
+    if enable_graph:
+        memory_instance.graph.delete_all.assert_called_once_with({"user_id": "test_user"})
+    else:
+        memory_instance.graph.delete_all.assert_not_called()
+
+    assert result["message"] == "Memories deleted successfully!"
+
+def test_reset(memory_instance):
+    memory_instance.vector_store.delete_col = Mock()
+    memory_instance.db.reset = Mock()
+
+    memory_instance.reset()
+
+    memory_instance.vector_store.delete_col.assert_called_once()
+    memory_instance.db.reset.assert_called_once()
+
+@pytest.mark.parametrize("version, enable_graph, expected_result", [
+    ("v1.0", False, {"results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}]}),
+    ("v1.1", False, {"results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}]}),
+    ("v1.1", True, {
+        "results": [{"id": "1", "memory": "Memory 1", "user_id": "test_user"}],
+        "relations": [{"source": "entity1", "relationship": "rel", "target": "entity2"}]
+    })
+])
+def test_get_all(memory_instance, version, enable_graph, expected_result):
+    memory_instance.config.version = version
+    memory_instance.enable_graph = enable_graph
+    mock_memories = [Mock(id="1", payload={"data": "Memory 1", "user_id": "test_user"})]
+    memory_instance.vector_store.list = Mock(return_value=(mock_memories, None))
+    memory_instance.graph.get_all = Mock(return_value=[
+        {"source": "entity1", "relationship": "rel", "target": "entity2"}
+    ])
+
+    result = memory_instance.get_all(user_id="test_user")
+
+    assert isinstance(result, dict)
+    assert "results" in result
+    assert len(result["results"]) == len(expected_result["results"])
+    for expected_item, result_item in zip(expected_result["results"], result["results"]):
+        assert all(key in result_item for key in expected_item)
+        assert result_item["id"] == expected_item["id"]
+        assert result_item["memory"] == expected_item["memory"]
+        assert result_item["user_id"] == expected_item["user_id"]
+
+    if enable_graph:
+        assert "relations" in result
+        assert result["relations"] == expected_result["relations"]
+    else:
+        assert "relations" not in result
+
+    memory_instance.vector_store.list.assert_called_once_with(filters={"user_id": "test_user"}, limit=100)
+
+    if enable_graph:
+        memory_instance.graph.get_all.assert_called_once_with({"user_id": "test_user"})
+    else:
+        memory_instance.graph.get_all.assert_not_called()
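The suite stubs the embedder, vector store, LLM, and graph factories, so it should run without any live services. To execute it locally (assuming pytest is installed in the environment):

    pytest tests/test_main.py -v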