Memory Reset (#2558)
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
import concurrent
|
import concurrent
|
||||||
|
import gc
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@@ -755,14 +756,31 @@ class Memory(MemoryBase):
|
|||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""
|
"""
|
||||||
Reset the memory store.
|
Reset the memory store by:
|
||||||
|
Deletes the vector store collection
|
||||||
|
Resets the database
|
||||||
|
Recreates the vector store with a new client
|
||||||
"""
|
"""
|
||||||
logger.warning("Resetting all memories")
|
logger.warning("Resetting all memories")
|
||||||
self.vector_store.delete_col()
|
self.vector_store.delete_col()
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
# Close the client if it has a close method
|
||||||
|
if hasattr(self.vector_store, 'client') and hasattr(self.vector_store.client, 'close'):
|
||||||
|
self.vector_store.client.close()
|
||||||
|
|
||||||
|
# Close the old connection if possible
|
||||||
|
if hasattr(self.db, 'connection') and self.db.connection:
|
||||||
|
self.db.connection.execute("DROP TABLE IF EXISTS history")
|
||||||
|
self.db.connection.close()
|
||||||
|
|
||||||
|
self.db = SQLiteManager(self.config.history_db_path)
|
||||||
|
|
||||||
|
# Create a new vector store with the same configuration
|
||||||
self.vector_store = VectorStoreFactory.create(
|
self.vector_store = VectorStoreFactory.create(
|
||||||
self.config.vector_store.provider, self.config.vector_store.config
|
self.config.vector_store.provider, self.config.vector_store.config
|
||||||
)
|
)
|
||||||
self.db.reset()
|
|
||||||
capture_event("mem0.reset", self, {"sync_type": "sync"})
|
capture_event("mem0.reset", self, {"sync_type": "sync"})
|
||||||
|
|
||||||
def chat(self, query):
|
def chat(self, query):
|
||||||
@@ -1519,14 +1537,28 @@ class AsyncMemory(MemoryBase):
|
|||||||
|
|
||||||
async def reset(self):
|
async def reset(self):
|
||||||
"""
|
"""
|
||||||
Reset the memory store asynchronously.
|
Reset the memory store asynchronously by:
|
||||||
|
Deletes the vector store collection
|
||||||
|
Resets the database
|
||||||
|
Recreates the vector store with a new client
|
||||||
"""
|
"""
|
||||||
logger.warning("Resetting all memories")
|
logger.warning("Resetting all memories")
|
||||||
await asyncio.to_thread(self.vector_store.delete_col)
|
await asyncio.to_thread(self.vector_store.delete_col)
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
if hasattr(self.vector_store, 'client') and hasattr(self.vector_store.client, 'close'):
|
||||||
|
await asyncio.to_thread(self.vector_store.client.close)
|
||||||
|
|
||||||
|
if hasattr(self.db, 'connection') and self.db.connection:
|
||||||
|
await asyncio.to_thread(lambda: self.db.connection.execute("DROP TABLE IF EXISTS history"))
|
||||||
|
await asyncio.to_thread(self.db.connection.close)
|
||||||
|
|
||||||
|
self.db = SQLiteManager(self.config.history_db_path)
|
||||||
|
|
||||||
self.vector_store = VectorStoreFactory.create(
|
self.vector_store = VectorStoreFactory.create(
|
||||||
self.config.vector_store.provider, self.config.vector_store.config
|
self.config.vector_store.provider, self.config.vector_store.config
|
||||||
)
|
)
|
||||||
await asyncio.to_thread(self.db.reset)
|
|
||||||
capture_event("mem0.reset", self, {"sync_type": "async"})
|
capture_event("mem0.reset", self, {"sync_type": "async"})
|
||||||
|
|
||||||
async def chat(self, query):
|
async def chat(self, query):
|
||||||
|
|||||||
@@ -142,9 +142,3 @@ class SQLiteManager:
|
|||||||
}
|
}
|
||||||
for row in rows
|
for row in rows
|
||||||
]
|
]
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
with self._lock:
|
|
||||||
with self.connection:
|
|
||||||
self.connection.execute("DROP TABLE IF EXISTS history")
|
|
||||||
self._create_history_table()
|
|
||||||
|
|||||||
@@ -195,21 +195,6 @@ def test_delete_all(memory_instance, version, enable_graph):
|
|||||||
assert result["message"] == "Memories deleted successfully!"
|
assert result["message"] == "Memories deleted successfully!"
|
||||||
|
|
||||||
|
|
||||||
def test_reset(memory_instance):
|
|
||||||
memory_instance.vector_store.delete_col = Mock()
|
|
||||||
# persisting vector store to make sure previous collection is deleted
|
|
||||||
initial_vector_store = memory_instance.vector_store
|
|
||||||
memory_instance.db.reset = Mock()
|
|
||||||
|
|
||||||
with patch.object(VectorStoreFactory, "create", return_value=Mock()) as mock_create:
|
|
||||||
memory_instance.reset()
|
|
||||||
|
|
||||||
initial_vector_store.delete_col.assert_called_once()
|
|
||||||
memory_instance.db.reset.assert_called_once()
|
|
||||||
mock_create.assert_called_once_with(
|
|
||||||
memory_instance.config.vector_store.provider, memory_instance.config.vector_store.config
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"version, enable_graph, expected_result",
|
"version, enable_graph, expected_result",
|
||||||
|
|||||||
Reference in New Issue
Block a user