Fix/new memories wrong type (#2635)
This commit is contained in:
@@ -66,6 +66,8 @@ class Memory(MemoryBase):
|
||||
|
||||
self.graph = MemoryGraph(self.config)
|
||||
self.enable_graph = True
|
||||
else:
|
||||
self.graph = None
|
||||
|
||||
self.config.vector_store.config.collection_name = "mem0migrations"
|
||||
if self.config.vector_store.provider in ["faiss", "qdrant"]:
|
||||
@@ -263,20 +265,20 @@ class Memory(MemoryBase):
|
||||
)
|
||||
|
||||
try:
|
||||
new_memories_with_actions = self.llm.generate_response(
|
||||
response: str = self.llm.generate_response(
|
||||
messages=[{"role": "user", "content": function_calling_prompt}],
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error in new_memories_with_actions: {e}")
|
||||
new_memories_with_actions = []
|
||||
logging.error(f"Error in new memory actions response: {e}")
|
||||
response = ""
|
||||
|
||||
try:
|
||||
new_memories_with_actions = remove_code_blocks(new_memories_with_actions)
|
||||
new_memories_with_actions = json.loads(new_memories_with_actions)
|
||||
response = remove_code_blocks(response)
|
||||
new_memories_with_actions = json.loads(response)
|
||||
except Exception as e:
|
||||
logging.error(f"Invalid JSON response: {e}")
|
||||
new_memories_with_actions = []
|
||||
new_memories_with_actions = {}
|
||||
|
||||
returned_memories = []
|
||||
try:
|
||||
@@ -767,13 +769,13 @@ class Memory(MemoryBase):
|
||||
logger.warning("Resetting all memories")
|
||||
|
||||
# Close the old connection if possible
|
||||
if hasattr(self.db, 'connection') and self.db.connection:
|
||||
if hasattr(self.db, "connection") and self.db.connection:
|
||||
self.db.connection.execute("DROP TABLE IF EXISTS history")
|
||||
self.db.connection.close()
|
||||
|
||||
self.db = SQLiteManager(self.config.history_db_path)
|
||||
|
||||
if hasattr(self.vector_store, 'reset'):
|
||||
if hasattr(self.vector_store, "reset"):
|
||||
self.vector_store = VectorStoreFactory.reset(self.vector_store)
|
||||
else:
|
||||
logger.warning("Vector store does not support reset. Skipping.")
|
||||
@@ -811,6 +813,8 @@ class AsyncMemory(MemoryBase):
|
||||
|
||||
self.graph = MemoryGraph(self.config)
|
||||
self.enable_graph = True
|
||||
else:
|
||||
self.graph = None
|
||||
|
||||
capture_event("mem0.init", self, {"sync_type": "async"})
|
||||
|
||||
@@ -1007,21 +1011,21 @@ class AsyncMemory(MemoryBase):
|
||||
)
|
||||
|
||||
try:
|
||||
new_memories_with_actions = await asyncio.to_thread(
|
||||
response: str = await asyncio.to_thread(
|
||||
self.llm.generate_response,
|
||||
messages=[{"role": "user", "content": function_calling_prompt}],
|
||||
response_format={"type": "json_object"},
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error in new_memories_with_actions: {e}")
|
||||
new_memories_with_actions = []
|
||||
logging.error(f"Error in new memory actions response: {e}")
|
||||
response = ""
|
||||
|
||||
try:
|
||||
new_memories_with_actions = remove_code_blocks(new_memories_with_actions)
|
||||
new_memories_with_actions = json.loads(new_memories_with_actions)
|
||||
response = remove_code_blocks(response)
|
||||
new_memories_with_actions = json.loads(response)
|
||||
except Exception as e:
|
||||
logging.error(f"Invalid JSON response: {e}")
|
||||
new_memories_with_actions = []
|
||||
new_memories_with_actions = {}
|
||||
|
||||
returned_memories = []
|
||||
try:
|
||||
@@ -1092,7 +1096,9 @@ class AsyncMemory(MemoryBase):
|
||||
except Exception as e:
|
||||
logging.error(f"Error in new_memories_with_actions: {e}")
|
||||
|
||||
capture_event("mem0.add", self, {"version": self.api_version, "keys": list(filters.keys()), "sync_type": "async"})
|
||||
capture_event(
|
||||
"mem0.add", self, {"version": self.api_version, "keys": list(filters.keys()), "sync_type": "async"}
|
||||
)
|
||||
|
||||
return returned_memories
|
||||
|
||||
@@ -1547,10 +1553,10 @@ class AsyncMemory(MemoryBase):
|
||||
|
||||
gc.collect()
|
||||
|
||||
if hasattr(self.vector_store, 'client') and hasattr(self.vector_store.client, 'close'):
|
||||
if hasattr(self.vector_store, "client") and hasattr(self.vector_store.client, "close"):
|
||||
await asyncio.to_thread(self.vector_store.client.close)
|
||||
|
||||
if hasattr(self.db, 'connection') and self.db.connection:
|
||||
if hasattr(self.db, "connection") and self.db.connection:
|
||||
await asyncio.to_thread(lambda: self.db.connection.execute("DROP TABLE IF EXISTS history"))
|
||||
await asyncio.to_thread(self.db.connection.close)
|
||||
|
||||
|
||||
39
poetry.lock
generated
39
poetry.lock
generated
@@ -1544,6 +1544,43 @@ tomli = {version = ">=1", markers = "python_version < \"3.11\""}
|
||||
[package.extras]
|
||||
dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-asyncio"
|
||||
version = "0.23.8"
|
||||
description = "Pytest support for asyncio"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["test"]
|
||||
files = [
|
||||
{file = "pytest_asyncio-0.23.8-py3-none-any.whl", hash = "sha256:50265d892689a5faefb84df80819d1ecef566eb3549cf915dfb33569359d1ce2"},
|
||||
{file = "pytest_asyncio-0.23.8.tar.gz", hash = "sha256:759b10b33a6dc61cce40a8bd5205e302978bbbcc00e279a8b61d9a6a3c82e4d3"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pytest = ">=7.0.0,<9"
|
||||
|
||||
[package.extras]
|
||||
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
|
||||
testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-mock"
|
||||
version = "3.14.0"
|
||||
description = "Thin-wrapper around the mock package for easier use with pytest"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["test"]
|
||||
files = [
|
||||
{file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"},
|
||||
{file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pytest = ">=6.2.5"
|
||||
|
||||
[package.extras]
|
||||
dev = ["pre-commit", "pytest-asyncio", "tox"]
|
||||
|
||||
[[package]]
|
||||
name = "python-dateutil"
|
||||
version = "2.9.0.post0"
|
||||
@@ -2163,4 +2200,4 @@ graph = ["langchain-neo4j", "neo4j", "rank-bm25"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.9,<4.0"
|
||||
content-hash = "6d54994c8286378de8dc348b8b11c57ce922dd8104a26cd63eb171563d1bbf5f"
|
||||
content-hash = "07f2aee9c596c2d2470df085b92551b7b7e3c19cabe61ae5bee7505395601417"
|
||||
|
||||
@@ -36,6 +36,8 @@ graph = ["langchain-neo4j", "neo4j", "rank-bm25"]
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
pytest = "^8.2.2"
|
||||
pytest-mock = "^3.14.0"
|
||||
pytest-asyncio = "^0.23.7"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ruff = "^0.6.5"
|
||||
@@ -49,4 +51,3 @@ build-backend = "poetry.core.masonry.api"
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
exclude = ["embedchain/"]
|
||||
|
||||
|
||||
135
tests/memory/test_main.py
Normal file
135
tests/memory/test_main.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import logging
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from mem0.memory.main import AsyncMemory, Memory
|
||||
|
||||
|
||||
def _setup_mocks(mocker):
|
||||
"""Helper to setup common mocks for both sync and async fixtures"""
|
||||
mock_embedder = mocker.MagicMock()
|
||||
mock_embedder.return_value.embed.return_value = [0.1, 0.2, 0.3]
|
||||
mocker.patch('mem0.utils.factory.EmbedderFactory.create', mock_embedder)
|
||||
|
||||
mock_vector_store = mocker.MagicMock()
|
||||
mock_vector_store.return_value.search.return_value = []
|
||||
mocker.patch('mem0.utils.factory.VectorStoreFactory.create',
|
||||
side_effect=[mock_vector_store.return_value, mocker.MagicMock()])
|
||||
|
||||
mock_llm = mocker.MagicMock()
|
||||
mocker.patch('mem0.utils.factory.LlmFactory.create', mock_llm)
|
||||
|
||||
mocker.patch('mem0.memory.storage.SQLiteManager', mocker.MagicMock())
|
||||
|
||||
return mock_llm, mock_vector_store
|
||||
|
||||
|
||||
class TestAddToVectorStoreErrors:
|
||||
@pytest.fixture
|
||||
def mock_memory(self, mocker):
|
||||
"""Fixture that returns a Memory instance with mocker-based mocks"""
|
||||
mock_llm, _ = _setup_mocks(mocker)
|
||||
|
||||
memory = Memory()
|
||||
memory.config = mocker.MagicMock()
|
||||
memory.config.custom_fact_extraction_prompt = None
|
||||
memory.config.custom_update_memory_prompt = None
|
||||
memory.api_version = "v1.1"
|
||||
|
||||
return memory
|
||||
|
||||
def test_empty_llm_response_fact_extraction(self, mock_memory, caplog):
|
||||
"""Test empty response from LLM during fact extraction"""
|
||||
# Setup
|
||||
mock_memory.llm.generate_response.return_value = ""
|
||||
|
||||
# Execute
|
||||
with caplog.at_level(logging.ERROR):
|
||||
result = mock_memory._add_to_vector_store(
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
metadata={},
|
||||
filters={},
|
||||
infer=True
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert mock_memory.llm.generate_response.call_count == 2
|
||||
assert result == [] # Should return empty list when no memories processed
|
||||
assert "Error in new_retrieved_facts" in caplog.text
|
||||
|
||||
def test_empty_llm_response_memory_actions(self, mock_memory, caplog):
|
||||
"""Test empty response from LLM during memory actions"""
|
||||
# Setup
|
||||
# First call returns valid JSON, second call returns empty string
|
||||
mock_memory.llm.generate_response.side_effect = [
|
||||
'{"facts": ["test fact"]}',
|
||||
""
|
||||
]
|
||||
|
||||
# Execute
|
||||
with caplog.at_level(logging.ERROR):
|
||||
result = mock_memory._add_to_vector_store(
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
metadata={},
|
||||
filters={},
|
||||
infer=True
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert mock_memory.llm.generate_response.call_count == 2
|
||||
assert result == [] # Should return empty list when no memories processed
|
||||
assert "Invalid JSON response" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAsyncAddToVectorStoreErrors:
|
||||
@pytest.fixture
|
||||
def mock_async_memory(self, mocker):
|
||||
"""Fixture for AsyncMemory with mocker-based mocks"""
|
||||
mock_llm, _ = _setup_mocks(mocker)
|
||||
|
||||
memory = AsyncMemory()
|
||||
memory.config = mocker.MagicMock()
|
||||
memory.config.custom_fact_extraction_prompt = None
|
||||
memory.config.custom_update_memory_prompt = None
|
||||
memory.api_version = "v1.1"
|
||||
|
||||
return memory
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_empty_llm_response_fact_extraction(self, mock_async_memory, caplog, mocker):
|
||||
"""Test empty response in AsyncMemory._add_to_vector_store"""
|
||||
mocker.patch('mem0.utils.factory.EmbedderFactory.create', return_value=MagicMock())
|
||||
mock_async_memory.llm.generate_response.return_value = ""
|
||||
|
||||
with caplog.at_level(logging.ERROR):
|
||||
result = await mock_async_memory._add_to_vector_store(
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
metadata={},
|
||||
filters={},
|
||||
infer=True
|
||||
)
|
||||
|
||||
assert result == []
|
||||
assert "Error in new_retrieved_facts" in caplog.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_empty_llm_response_memory_actions(self, mock_async_memory, caplog, mocker):
|
||||
"""Test empty response in AsyncMemory._add_to_vector_store"""
|
||||
mocker.patch('mem0.utils.factory.EmbedderFactory.create', return_value=MagicMock())
|
||||
mock_async_memory.llm.generate_response.side_effect = [
|
||||
'{"facts": ["test fact"]}',
|
||||
""
|
||||
]
|
||||
|
||||
with caplog.at_level(logging.ERROR):
|
||||
result = await mock_async_memory._add_to_vector_store(
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
metadata={},
|
||||
filters={},
|
||||
infer=True
|
||||
)
|
||||
|
||||
assert result == []
|
||||
assert "Invalid JSON response" in caplog.text
|
||||
Reference in New Issue
Block a user