Fix/new memories wrong type (#2635)

John Lockwood
2025-05-07 05:05:24 -07:00
committed by GitHub
parent eb7f5a774c
commit 641be2878d
4 changed files with 200 additions and 21 deletions
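The bug this fixes: in _add_to_vector_store, the raw LLM reply and the parsed result shared the variable new_memories_with_actions, and both error fallbacks set it to an empty list, even though the happy path ends with a JSON object (a dict). Any downstream dict-style access, presumably something like new_memories_with_actions.get("memories", []) (the consuming loop is not shown in this diff), then fails on the list fallback. The diff below keeps the raw string in a separate response variable and makes the parse fallback an empty dict. A minimal, hypothetical sketch of the corrected idea, not the mem0 code itself:

import json

def parse_actions(raw: str) -> dict:
    """Parse the LLM's JSON reply; fall back to an empty dict, never a list."""
    try:
        return json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        return {}  # a dict fallback keeps later .get("memories", []) lookups safe

print(parse_actions('{"memories": [{"event": "ADD", "text": "likes tea"}]}'))
print(parse_actions(""))       # -> {} instead of a type error downstream
print({}.get("memories", []))  # -> []; [].get(...) would raise AttributeError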

mem0/memory/main.py

@@ -66,6 +66,8 @@ class Memory(MemoryBase):
             self.graph = MemoryGraph(self.config)
             self.enable_graph = True
+        else:
+            self.graph = None
 
         self.config.vector_store.config.collection_name = "mem0migrations"
         if self.config.vector_store.provider in ["faiss", "qdrant"]:
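The new else branch guarantees that self.graph is always defined whether or not a graph store is configured; as far as this hunk shows, the attribute was previously only set inside the if branch, so later checks such as "if self.graph:" could raise AttributeError on instances without a graph store. A tiny, hypothetical illustration of the difference, not mem0 code:

class WithoutElse:
    def __init__(self, graph_config=None):
        if graph_config:
            self.graph = object()   # stand-in for MemoryGraph

class WithElse:
    def __init__(self, graph_config=None):
        if graph_config:
            self.graph = object()
        else:
            self.graph = None       # attribute always exists

print(WithElse().graph)             # None, safe to test with "if self.graph:"
try:
    print(WithoutElse().graph)
except AttributeError as exc:
    print(exc)                      # 'WithoutElse' object has no attribute 'graph'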
@@ -263,20 +265,20 @@ class Memory(MemoryBase):
         )
 
         try:
-            new_memories_with_actions = self.llm.generate_response(
+            response: str = self.llm.generate_response(
                 messages=[{"role": "user", "content": function_calling_prompt}],
                 response_format={"type": "json_object"},
             )
         except Exception as e:
-            logging.error(f"Error in new_memories_with_actions: {e}")
-            new_memories_with_actions = []
+            logging.error(f"Error in new memory actions response: {e}")
+            response = ""
 
         try:
-            new_memories_with_actions = remove_code_blocks(new_memories_with_actions)
-            new_memories_with_actions = json.loads(new_memories_with_actions)
+            response = remove_code_blocks(response)
+            new_memories_with_actions = json.loads(response)
         except Exception as e:
             logging.error(f"Invalid JSON response: {e}")
-            new_memories_with_actions = []
+            new_memories_with_actions = {}
 
         returned_memories = []
         try:
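Both the sync and async paths pass the raw reply through remove_code_blocks before json.loads. That helper is not shown in this diff; its role is presumably to strip the markdown fences that models often wrap around JSON even when a json_object response format is requested. A rough, hypothetical equivalent for illustration only:

import json
import re

def strip_code_fences(text: str) -> str:
    """Drop a leading ```json fence and a trailing ``` fence if present."""
    match = re.match(r"^```(?:json)?\s*(.*?)\s*```$", text.strip(), re.DOTALL)
    return match.group(1) if match else text.strip()

raw = '```json\n{"memories": []}\n```'
print(json.loads(strip_code_fences(raw)))  # {'memories': []}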
@@ -767,13 +769,13 @@ class Memory(MemoryBase):
         logger.warning("Resetting all memories")
 
         # Close the old connection if possible
-        if hasattr(self.db, 'connection') and self.db.connection:
+        if hasattr(self.db, "connection") and self.db.connection:
             self.db.connection.execute("DROP TABLE IF EXISTS history")
             self.db.connection.close()
             self.db = SQLiteManager(self.config.history_db_path)
 
-        if hasattr(self.vector_store, 'reset'):
+        if hasattr(self.vector_store, "reset"):
             self.vector_store = VectorStoreFactory.reset(self.vector_store)
         else:
             logger.warning("Vector store does not support reset. Skipping.")
@@ -811,6 +813,8 @@ class AsyncMemory(MemoryBase):
             self.graph = MemoryGraph(self.config)
             self.enable_graph = True
+        else:
+            self.graph = None
 
         capture_event("mem0.init", self, {"sync_type": "async"})
@@ -1007,21 +1011,21 @@ class AsyncMemory(MemoryBase):
         )
 
         try:
-            new_memories_with_actions = await asyncio.to_thread(
+            response: str = await asyncio.to_thread(
                 self.llm.generate_response,
                 messages=[{"role": "user", "content": function_calling_prompt}],
                 response_format={"type": "json_object"},
             )
         except Exception as e:
-            logging.error(f"Error in new_memories_with_actions: {e}")
-            new_memories_with_actions = []
+            logging.error(f"Error in new memory actions response: {e}")
+            response = ""
 
         try:
-            new_memories_with_actions = remove_code_blocks(new_memories_with_actions)
-            new_memories_with_actions = json.loads(new_memories_with_actions)
+            response = remove_code_blocks(response)
+            new_memories_with_actions = json.loads(response)
         except Exception as e:
             logging.error(f"Invalid JSON response: {e}")
-            new_memories_with_actions = []
+            new_memories_with_actions = {}
 
         returned_memories = []
         try:
@@ -1092,7 +1096,9 @@ class AsyncMemory(MemoryBase):
         except Exception as e:
             logging.error(f"Error in new_memories_with_actions: {e}")
 
-        capture_event("mem0.add", self, {"version": self.api_version, "keys": list(filters.keys()), "sync_type": "async"})
+        capture_event(
+            "mem0.add", self, {"version": self.api_version, "keys": list(filters.keys()), "sync_type": "async"}
+        )
 
         return returned_memories
@@ -1547,10 +1553,10 @@ class AsyncMemory(MemoryBase):
         gc.collect()
 
-        if hasattr(self.vector_store, 'client') and hasattr(self.vector_store.client, 'close'):
+        if hasattr(self.vector_store, "client") and hasattr(self.vector_store.client, "close"):
             await asyncio.to_thread(self.vector_store.client.close)
 
-        if hasattr(self.db, 'connection') and self.db.connection:
+        if hasattr(self.db, "connection") and self.db.connection:
             await asyncio.to_thread(lambda: self.db.connection.execute("DROP TABLE IF EXISTS history"))
             await asyncio.to_thread(self.db.connection.close)
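The async reset keeps blocking work (the sqlite connection and the vector-store client's close) off the event loop by routing it through asyncio.to_thread, available since Python 3.9, which matches the python-versions floor in the lock file below. A self-contained sketch of the same pattern; the database path and table name here are made up for the example:

import asyncio
import sqlite3

async def reset_history(db_path: str = ":memory:") -> None:
    # check_same_thread=False because to_thread runs the calls in a worker thread
    conn = sqlite3.connect(db_path, check_same_thread=False)
    await asyncio.to_thread(conn.execute, "CREATE TABLE IF NOT EXISTS history (id TEXT)")
    await asyncio.to_thread(conn.execute, "DROP TABLE IF EXISTS history")
    await asyncio.to_thread(conn.close)

asyncio.run(reset_history())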

poetry.lock (generated), 39 lines changed

@@ -1544,6 +1544,43 @@ tomli = {version = ">=1", markers = "python_version < \"3.11\""}
 [package.extras]
 dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
 
+[[package]]
+name = "pytest-asyncio"
+version = "0.23.8"
+description = "Pytest support for asyncio"
+optional = false
+python-versions = ">=3.8"
+groups = ["test"]
+files = [
+    {file = "pytest_asyncio-0.23.8-py3-none-any.whl", hash = "sha256:50265d892689a5faefb84df80819d1ecef566eb3549cf915dfb33569359d1ce2"},
+    {file = "pytest_asyncio-0.23.8.tar.gz", hash = "sha256:759b10b33a6dc61cce40a8bd5205e302978bbbcc00e279a8b61d9a6a3c82e4d3"},
+]
+
+[package.dependencies]
+pytest = ">=7.0.0,<9"
+
+[package.extras]
+docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
+testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
+
+[[package]]
+name = "pytest-mock"
+version = "3.14.0"
+description = "Thin-wrapper around the mock package for easier use with pytest"
+optional = false
+python-versions = ">=3.8"
+groups = ["test"]
+files = [
+    {file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"},
+    {file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"},
+]
+
+[package.dependencies]
+pytest = ">=6.2.5"
+
+[package.extras]
+dev = ["pre-commit", "pytest-asyncio", "tox"]
+
 [[package]]
 name = "python-dateutil"
 version = "2.9.0.post0"
@@ -2163,4 +2200,4 @@ graph = ["langchain-neo4j", "neo4j", "rank-bm25"]
 [metadata]
 lock-version = "2.1"
 python-versions = ">=3.9,<4.0"
-content-hash = "6d54994c8286378de8dc348b8b11c57ce922dd8104a26cd63eb171563d1bbf5f"
+content-hash = "07f2aee9c596c2d2470df085b92551b7b7e3c19cabe61ae5bee7505395601417"

pyproject.toml

@@ -36,6 +36,8 @@ graph = ["langchain-neo4j", "neo4j", "rank-bm25"]
 [tool.poetry.group.test.dependencies]
 pytest = "^8.2.2"
 pytest-mock = "^3.14.0"
+pytest-asyncio = "^0.23.7"
 
 [tool.poetry.group.dev.dependencies]
 ruff = "^0.6.5"
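The new pytest-asyncio entry is what lets the @pytest.mark.asyncio tests in tests/memory/test_main.py (added below) run under pytest; pytest-mock, already listed in the test group, supplies the mocker fixture that both new test classes rely on.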
@@ -49,4 +51,3 @@ build-backend = "poetry.core.masonry.api"
[tool.ruff]
line-length = 120
exclude = ["embedchain/"]

tests/memory/test_main.py (new file), 135 lines added

@@ -0,0 +1,135 @@
import logging
from unittest.mock import MagicMock

import pytest

from mem0.memory.main import AsyncMemory, Memory


def _setup_mocks(mocker):
    """Helper to setup common mocks for both sync and async fixtures"""
    mock_embedder = mocker.MagicMock()
    mock_embedder.return_value.embed.return_value = [0.1, 0.2, 0.3]
    mocker.patch('mem0.utils.factory.EmbedderFactory.create', mock_embedder)

    mock_vector_store = mocker.MagicMock()
    mock_vector_store.return_value.search.return_value = []
    mocker.patch('mem0.utils.factory.VectorStoreFactory.create',
                 side_effect=[mock_vector_store.return_value, mocker.MagicMock()])

    mock_llm = mocker.MagicMock()
    mocker.patch('mem0.utils.factory.LlmFactory.create', mock_llm)

    mocker.patch('mem0.memory.storage.SQLiteManager', mocker.MagicMock())

    return mock_llm, mock_vector_store


class TestAddToVectorStoreErrors:
    @pytest.fixture
    def mock_memory(self, mocker):
        """Fixture that returns a Memory instance with mocker-based mocks"""
        mock_llm, _ = _setup_mocks(mocker)

        memory = Memory()
        memory.config = mocker.MagicMock()
        memory.config.custom_fact_extraction_prompt = None
        memory.config.custom_update_memory_prompt = None
        memory.api_version = "v1.1"
        return memory

    def test_empty_llm_response_fact_extraction(self, mock_memory, caplog):
        """Test empty response from LLM during fact extraction"""
        # Setup
        mock_memory.llm.generate_response.return_value = ""

        # Execute
        with caplog.at_level(logging.ERROR):
            result = mock_memory._add_to_vector_store(
                messages=[{"role": "user", "content": "test"}],
                metadata={},
                filters={},
                infer=True
            )

        # Verify
        assert mock_memory.llm.generate_response.call_count == 2
        assert result == []  # Should return empty list when no memories processed
        assert "Error in new_retrieved_facts" in caplog.text

    def test_empty_llm_response_memory_actions(self, mock_memory, caplog):
        """Test empty response from LLM during memory actions"""
        # Setup
        # First call returns valid JSON, second call returns empty string
        mock_memory.llm.generate_response.side_effect = [
            '{"facts": ["test fact"]}',
            ""
        ]

        # Execute
        with caplog.at_level(logging.ERROR):
            result = mock_memory._add_to_vector_store(
                messages=[{"role": "user", "content": "test"}],
                metadata={},
                filters={},
                infer=True
            )

        # Verify
        assert mock_memory.llm.generate_response.call_count == 2
        assert result == []  # Should return empty list when no memories processed
        assert "Invalid JSON response" in caplog.text


@pytest.mark.asyncio
class TestAsyncAddToVectorStoreErrors:
    @pytest.fixture
    def mock_async_memory(self, mocker):
        """Fixture for AsyncMemory with mocker-based mocks"""
        mock_llm, _ = _setup_mocks(mocker)

        memory = AsyncMemory()
        memory.config = mocker.MagicMock()
        memory.config.custom_fact_extraction_prompt = None
        memory.config.custom_update_memory_prompt = None
        memory.api_version = "v1.1"
        return memory

    @pytest.mark.asyncio
    async def test_async_empty_llm_response_fact_extraction(self, mock_async_memory, caplog, mocker):
        """Test empty response in AsyncMemory._add_to_vector_store"""
        mocker.patch('mem0.utils.factory.EmbedderFactory.create', return_value=MagicMock())

        mock_async_memory.llm.generate_response.return_value = ""

        with caplog.at_level(logging.ERROR):
            result = await mock_async_memory._add_to_vector_store(
                messages=[{"role": "user", "content": "test"}],
                metadata={},
                filters={},
                infer=True
            )

        assert result == []
        assert "Error in new_retrieved_facts" in caplog.text

    @pytest.mark.asyncio
    async def test_async_empty_llm_response_memory_actions(self, mock_async_memory, caplog, mocker):
        """Test empty response in AsyncMemory._add_to_vector_store"""
        mocker.patch('mem0.utils.factory.EmbedderFactory.create', return_value=MagicMock())

        mock_async_memory.llm.generate_response.side_effect = [
            '{"facts": ["test fact"]}',
            ""
        ]

        with caplog.at_level(logging.ERROR):
            result = await mock_async_memory._add_to_vector_store(
                messages=[{"role": "user", "content": "test"}],
                metadata={},
                filters={},
                infer=True
            )

        assert result == []
        assert "Invalid JSON response" in caplog.text