Memory Reset (#2558)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import asyncio
|
||||
import concurrent
|
||||
import gc
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
@@ -755,14 +756,31 @@ class Memory(MemoryBase):
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the memory store.
|
||||
Reset the memory store by:
|
||||
Deletes the vector store collection
|
||||
Resets the database
|
||||
Recreates the vector store with a new client
|
||||
"""
|
||||
logger.warning("Resetting all memories")
|
||||
self.vector_store.delete_col()
|
||||
|
||||
gc.collect()
|
||||
|
||||
# Close the client if it has a close method
|
||||
if hasattr(self.vector_store, 'client') and hasattr(self.vector_store.client, 'close'):
|
||||
self.vector_store.client.close()
|
||||
|
||||
# Close the old connection if possible
|
||||
if hasattr(self.db, 'connection') and self.db.connection:
|
||||
self.db.connection.execute("DROP TABLE IF EXISTS history")
|
||||
self.db.connection.close()
|
||||
|
||||
self.db = SQLiteManager(self.config.history_db_path)
|
||||
|
||||
# Create a new vector store with the same configuration
|
||||
self.vector_store = VectorStoreFactory.create(
|
||||
self.config.vector_store.provider, self.config.vector_store.config
|
||||
)
|
||||
self.db.reset()
|
||||
capture_event("mem0.reset", self, {"sync_type": "sync"})
|
||||
|
||||
def chat(self, query):
|
||||
@@ -1519,14 +1537,28 @@ class AsyncMemory(MemoryBase):
|
||||
|
||||
async def reset(self):
|
||||
"""
|
||||
Reset the memory store asynchronously.
|
||||
Reset the memory store asynchronously by:
|
||||
Deletes the vector store collection
|
||||
Resets the database
|
||||
Recreates the vector store with a new client
|
||||
"""
|
||||
logger.warning("Resetting all memories")
|
||||
await asyncio.to_thread(self.vector_store.delete_col)
|
||||
|
||||
gc.collect()
|
||||
|
||||
if hasattr(self.vector_store, 'client') and hasattr(self.vector_store.client, 'close'):
|
||||
await asyncio.to_thread(self.vector_store.client.close)
|
||||
|
||||
if hasattr(self.db, 'connection') and self.db.connection:
|
||||
await asyncio.to_thread(lambda: self.db.connection.execute("DROP TABLE IF EXISTS history"))
|
||||
await asyncio.to_thread(self.db.connection.close)
|
||||
|
||||
self.db = SQLiteManager(self.config.history_db_path)
|
||||
|
||||
self.vector_store = VectorStoreFactory.create(
|
||||
self.config.vector_store.provider, self.config.vector_store.config
|
||||
)
|
||||
await asyncio.to_thread(self.db.reset)
|
||||
capture_event("mem0.reset", self, {"sync_type": "async"})
|
||||
|
||||
async def chat(self, query):
|
||||
|
||||
@@ -142,9 +142,3 @@ class SQLiteManager:
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
def reset(self):
|
||||
with self._lock:
|
||||
with self.connection:
|
||||
self.connection.execute("DROP TABLE IF EXISTS history")
|
||||
self._create_history_table()
|
||||
|
||||
@@ -195,21 +195,6 @@ def test_delete_all(memory_instance, version, enable_graph):
|
||||
assert result["message"] == "Memories deleted successfully!"
|
||||
|
||||
|
||||
def test_reset(memory_instance):
|
||||
memory_instance.vector_store.delete_col = Mock()
|
||||
# persisting vector store to make sure previous collection is deleted
|
||||
initial_vector_store = memory_instance.vector_store
|
||||
memory_instance.db.reset = Mock()
|
||||
|
||||
with patch.object(VectorStoreFactory, "create", return_value=Mock()) as mock_create:
|
||||
memory_instance.reset()
|
||||
|
||||
initial_vector_store.delete_col.assert_called_once()
|
||||
memory_instance.db.reset.assert_called_once()
|
||||
mock_create.assert_called_once_with(
|
||||
memory_instance.config.vector_store.provider, memory_instance.config.vector_store.config
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"version, enable_graph, expected_result",
|
||||
|
||||
Reference in New Issue
Block a user