diff --git a/mem0/memory/main.py b/mem0/memory/main.py index f6a9423c..725fb12a 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -542,7 +542,7 @@ class Memory(MemoryBase): try: existing_memory = self.vector_store.get(vector_id=memory_id) - except Exception as e: + except Exception: raise ValueError(f"Error getting memory with ID {memory_id}. Please provide a valid 'memory_id'") prev_value = existing_memory.payload.get("data") @@ -591,6 +591,9 @@ class Memory(MemoryBase): """ logger.warning("Resetting all memories") self.vector_store.delete_col() + self.vector_store = VectorStoreFactory.create( + self.config.vector_store.provider, self.config.vector_store.config + ) self.db.reset() capture_event("mem0.reset", self) diff --git a/mem0/memory/storage.py b/mem0/memory/storage.py index 87a256dc..8f10a5b1 100644 --- a/mem0/memory/storage.py +++ b/mem0/memory/storage.py @@ -140,3 +140,4 @@ class SQLiteManager: def reset(self): with self.connection: self.connection.execute("DROP TABLE IF EXISTS history") + self._create_history_table() diff --git a/tests/test_main.py b/tests/test_main.py index ed348ec3..66bc6d55 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -5,6 +5,7 @@ import pytest from mem0.configs.base import MemoryConfig from mem0.memory.main import Memory +from mem0.utils.factory import VectorStoreFactory @pytest.fixture(autouse=True) @@ -171,13 +172,19 @@ def test_delete_all(memory_instance, version, enable_graph): def test_reset(memory_instance): memory_instance.vector_store.delete_col = Mock() + # persisting vector store to make sure previous collection is deleted + initial_vector_store = memory_instance.vector_store memory_instance.db.reset = Mock() - memory_instance.reset() - - memory_instance.vector_store.delete_col.assert_called_once() - memory_instance.db.reset.assert_called_once() + with patch.object(VectorStoreFactory, "create", return_value=Mock()) as mock_create: + + memory_instance.reset() + initial_vector_store.delete_col.assert_called_once() + memory_instance.db.reset.assert_called_once() + mock_create.assert_called_once_with( + memory_instance.config.vector_store.provider, memory_instance.config.vector_store.config + ) @pytest.mark.parametrize( "version, enable_graph, expected_result",