Reset function for VectorDBs (#2584)

Dev Khant
2025-04-25 00:01:53 +05:30
committed by GitHub
parent ff6ae478f1
commit 64c3d34deb
18 changed files with 224 additions and 20 deletions

View File

@@ -1,10 +1,10 @@
-import os
 import asyncio
 import concurrent
 import gc
 import hashlib
 import json
 import logging
+import os
 import uuid
 import warnings
 from datetime import datetime
@@ -20,7 +20,7 @@ from mem0.configs.prompts import (
     get_update_memory_messages,
 )
 from mem0.memory.base import MemoryBase
-from mem0.memory.setup import setup_config, mem0_dir
+from mem0.memory.setup import mem0_dir, setup_config
 from mem0.memory.storage import SQLiteManager
 from mem0.memory.telemetry import capture_event
 from mem0.memory.utils import (
@@ -67,7 +67,7 @@ class Memory(MemoryBase):
             self.graph = MemoryGraph(self.config)
             self.enable_graph = True

-        self.config.vector_store.config.collection_name = "mem0_migrations"
+        self.config.vector_store.config.collection_name = "mem0-migrations"
         if self.config.vector_store.provider in ["faiss", "qdrant"]:
             provider_path = f"migrations_{self.config.vector_store.provider}"
             self.config.vector_store.config.path = os.path.join(mem0_dir, provider_path)
@@ -765,13 +765,6 @@ class Memory(MemoryBase):
         Recreates the vector store with a new client
         """
         logger.warning("Resetting all memories")
-        self.vector_store.delete_col()
-        gc.collect()
-
-        # Close the client if it has a close method
-        if hasattr(self.vector_store, 'client') and hasattr(self.vector_store.client, 'close'):
-            self.vector_store.client.close()
-
         # Close the old connection if possible
         if hasattr(self.db, 'connection') and self.db.connection:
@@ -780,10 +773,14 @@ class Memory(MemoryBase):
         self.db = SQLiteManager(self.config.history_db_path)

-        # Create a new vector store with the same configuration
-        self.vector_store = VectorStoreFactory.create(
-            self.config.vector_store.provider, self.config.vector_store.config
-        )
+        if hasattr(self.vector_store, 'reset'):
+            self.vector_store = VectorStoreFactory.reset(self.vector_store)
+        else:
+            logger.warning("Vector store does not support reset. Skipping.")
+            self.vector_store.delete_col()
+            self.vector_store = VectorStoreFactory.create(
+                self.config.vector_store.provider, self.config.vector_store.config
+            )
         capture_event("mem0.reset", self, {"sync_type": "sync"})

     def chat(self, query):
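
Taken together, Memory.reset() now prefers a provider-level reset and only falls back to rebuilding the store from its config. A minimal usage sketch (assumes a default configuration with embedder/LLM credentials in the environment; the memory text and user id are illustrative):

    from mem0 import Memory

    m = Memory()  # default config
    m.add("Alice prefers window seats", user_id="alice")

    # Drops the vector collection and the history DB, then recreates both.
    # Providers exposing reset() are reset in place; others are deleted and
    # rebuilt via VectorStoreFactory.create().
    m.reset()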

View File

@@ -97,3 +97,9 @@ class VectorStoreFactory:
             return vector_store_instance(**config)
         else:
             raise ValueError(f"Unsupported VectorStore provider: {provider_name}")
+
+    @classmethod
+    def reset(cls, instance):
+        instance.reset()
+        return instance
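
The new classmethod delegates to the instance and hands the same object back, so callers can rebind in one expression. A toy sketch (DummyStore is hypothetical; the factory is assumed to live at mem0.utils.factory):

    from mem0.utils.factory import VectorStoreFactory

    class DummyStore:
        """Hypothetical store exposing only the reset() surface."""
        def reset(self):
            print("dropping and recreating the collection")

    store = DummyStore()
    same = VectorStoreFactory.reset(store)  # delegates to the instance's reset()
    assert same is store                    # the same object is returned, now reset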

View File

@@ -65,6 +65,8 @@ class AzureAISearch(VectorStoreBase):
             hybrid_search (bool): Whether to use hybrid search. Default is False.
             vector_filter_mode (Optional[str]): Mode for vector filtering. Default is "preFilter".
         """
+        self.service_name = service_name
+        self.api_key = api_key
         self.index_name = collection_name
         self.collection_name = collection_name
         self.embedding_model_dims = embedding_model_dims
@@ -341,3 +343,38 @@ class AzureAISearch(VectorStoreBase):
         """Close the search client when the object is deleted."""
         self.search_client.close()
         self.index_client.close()
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.index_name}...")
+        try:
+            # Close the existing clients
+            self.search_client.close()
+            self.index_client.close()
+
+            # Delete the collection
+            self.delete_col()
+
+            # Reinitialize the clients
+            service_endpoint = f"https://{self.service_name}.search.windows.net"
+            self.search_client = SearchClient(
+                endpoint=service_endpoint,
+                index_name=self.index_name,
+                credential=AzureKeyCredential(self.api_key),
+            )
+            self.index_client = SearchIndexClient(
+                endpoint=service_endpoint,
+                credential=AzureKeyCredential(self.api_key),
+            )
+
+            # Add the mem0 user agent to the new clients
+            self.search_client._client._config.user_agent_policy.add_user_agent("mem0")
+            self.index_client._client._config.user_agent_policy.add_user_agent("mem0")
+
+            # Recreate the collection
+            self.create_col()
+        except Exception as e:
+            logger.error(f"Error resetting index {self.index_name}: {e}")
+            raise

View File

@@ -51,3 +51,8 @@ class VectorStoreBase(ABC):
     def list(self, filters=None, limit=None):
         """List all memories."""
         pass
+
+    @abstractmethod
+    def reset(self):
+        """Reset by deleting the collection and recreating it."""
+        pass
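
With reset() added as an abstract method, every provider must implement the same delete-then-recreate pattern. A minimal sketch of what an implementation typically looks like (the class and its storage are hypothetical; the other abstract methods of VectorStoreBase are omitted):

    import logging

    logger = logging.getLogger(__name__)

    class ToyStore:
        """Hypothetical in-memory store mirroring the VectorStoreBase contract."""

        def __init__(self, collection_name="demo"):
            self.collection_name = collection_name
            self.vectors = {}

        def delete_col(self):
            self.vectors.clear()

        def create_col(self, name):
            self.collection_name = name
            self.vectors = {}

        def reset(self):
            # The pattern shared by the concrete stores below: drop, then recreate.
            logger.warning(f"Resetting index {self.collection_name}...")
            self.delete_col()
            self.create_col(self.collection_name)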

View File

@@ -221,3 +221,9 @@ class ChromaDB(VectorStoreBase):
         """
         results = self.collection.get(where=filters, limit=limit)
         return [self._parse_output(results)]
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.collection = self.create_col(self.collection_name)

View File

@@ -222,3 +222,9 @@ class ElasticsearchDB(VectorStoreBase):
         )
         return [results]
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.create_index()

View File

@@ -465,3 +465,9 @@ class FAISS(VectorStoreBase):
                 break
         return [results]
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.create_col(self.collection_name)

View File

@@ -1,3 +1,4 @@
+import logging
 from typing import Dict, List, Optional

 from pydantic import BaseModel
@@ -11,6 +12,7 @@ except ImportError:
 from mem0.vector_stores.base import VectorStoreBase

+logger = logging.getLogger(__name__)

 class OutputData(BaseModel):
     id: Optional[str]  # memory id
@@ -135,7 +137,14 @@ class Langchain(VectorStoreBase):
         """
         Delete a collection.
         """
-        self.client.delete(ids=None)
+        logger.warning("Deleting collection")
+        if hasattr(self.client, "delete_collection"):
+            self.client.delete_collection()
+        elif hasattr(self.client, "reset_collection"):
+            self.client.reset_collection()
+        else:
+            # Fall back to the generic delete method
+            self.client.delete(ids=None)

     def col_info(self):
         """
@@ -147,5 +156,27 @@ class Langchain(VectorStoreBase):
         """
         List all vectors in a collection.
         """
-        # This would require implementation-specific access to the underlying store
-        raise NotImplementedError("Listing all vectors not directly supported by LangChain vectorstores")
+        try:
+            if hasattr(self.client, "_collection") and hasattr(self.client._collection, "get"):
+                # Convert mem0 filters to a Chroma where clause if needed
+                where_clause = None
+                if filters and "user_id" in filters:
+                    where_clause = {"user_id": filters["user_id"]}
+                result = self.client._collection.get(
+                    where=where_clause,
+                    limit=limit,
+                )
+                # Convert the result to the expected format
+                if result and isinstance(result, dict):
+                    return [self._parse_output(result)]
+            return []
+        except Exception as e:
+            logger.error(f"Error listing vectors from Chroma: {e}")
+            return []
+
+    def reset(self):
+        """Reset the collection by deleting it."""
+        logger.warning(f"Resetting collection: {self.collection_name}")
+        self.delete_col()

View File

@@ -237,3 +237,9 @@ class MilvusDB(VectorStoreBase):
             obj = OutputData(id=data.get("id"), score=None, payload=data.get("metadata"))
             memories.append(obj)
         return [memories]
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.create_col(self.collection_name, self.embedding_model_dims, self.metric_type)

View File

@@ -199,3 +199,9 @@ class OpenSearchDB(VectorStoreBase):
                 for hit in response["hits"]["hits"]
             ]
         ]
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.create_index()

View File

@@ -286,3 +286,9 @@ class PGVector(VectorStoreBase):
             self.cur.close()
         if hasattr(self, "conn"):
             self.conn.close()
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.create_col(self.embedding_model_dims)

View File

@@ -369,5 +369,6 @@ class PineconeDB(VectorStoreBase):
         """
         Reset the index by deleting and recreating it.
         """
+        logger.warning(f"Resetting index {self.collection_name}...")
         self.delete_col()
         self.create_col(self.embedding_model_dims, self.metric)

View File

@@ -67,6 +67,7 @@ class Qdrant(VectorStoreBase):
         self.collection_name = collection_name
         self.embedding_model_dims = embedding_model_dims
+        self.on_disk = on_disk
         self.create_col(embedding_model_dims, on_disk)

     def create_col(self, vector_size: int, on_disk: bool, distance: Distance = Distance.COSINE):

@@ -231,3 +232,9 @@ class Qdrant(VectorStoreBase):
             with_vectors=False,
         )
         return result
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.create_col(self.embedding_model_dims, self.on_disk)

View File

@@ -75,9 +75,48 @@ class RedisDB(VectorStoreBase):
         self.index.set_client(self.client)
         self.index.create(overwrite=True)

-    # TODO: Implement multiindex support.
-    def create_col(self, name, vector_size, distance):
-        raise NotImplementedError("Collection/Index creation not supported yet.")
+    def create_col(self, name=None, vector_size=None, distance=None):
+        """
+        Create a new collection (index) in Redis.
+
+        Args:
+            name (str, optional): Name for the collection. Defaults to the current collection name.
+            vector_size (int, optional): Size of the vector embeddings. Defaults to the current embedding_model_dims.
+            distance (str, optional): Distance metric to use. Defaults to "cosine".
+
+        Returns:
+            The created index object.
+        """
+        # Use the provided parameters or fall back to instance attributes
+        collection_name = name or self.schema["index"]["name"]
+        embedding_dims = vector_size or self.embedding_model_dims
+        distance_metric = distance or "cosine"
+
+        # Build a new schema with the specified parameters
+        index_schema = {
+            "name": collection_name,
+            "prefix": f"mem0:{collection_name}",
+        }
+
+        # Copy the default fields and update the vector field with the specified dimensions
+        fields = DEFAULT_FIELDS.copy()
+        fields[-1]["attrs"]["dims"] = embedding_dims
+        fields[-1]["attrs"]["distance_metric"] = distance_metric
+
+        schema = {"index": index_schema, "fields": fields}
+
+        # Create the index
+        index = SearchIndex.from_dict(schema)
+        index.set_client(self.client)
+        index.create(overwrite=True)
+
+        # Update instance attributes when creating a new collection
+        if name:
+            self.schema = schema
+            self.index = index
+
+        return index

     def insert(self, vectors: list, payloads: list = None, ids: list = None):
         data = []
@@ -194,6 +233,17 @@ class RedisDB(VectorStoreBase):
     def col_info(self, name):
         return self.index.info()

+    def reset(self):
+        """
+        Reset the index by deleting and recreating it.
+        """
+        collection_name = self.schema["index"]["name"]
+        logger.warning(f"Resetting index {collection_name}...")
+        self.delete_col()
+        # Recreate the index with the same parameters
+        self.create_col(collection_name, self.embedding_model_dims)
+
     def list(self, filters: dict = None, limit: int = None) -> list:
         """
         List all recently created memories from the vector store.
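
Because the Redis store previously raised NotImplementedError from create_col, a short sketch of the new surface may help (assumes a local Redis Stack with RediSearch; the constructor arguments are illustrative and may differ from the actual RedisDB signature):

    from mem0.vector_stores.redis import RedisDB

    store = RedisDB(
        redis_url="redis://localhost:6379",
        collection_name="mem0_demo",
        embedding_model_dims=768,
    )

    store.create_col()                     # recreate with the instance's own settings
    store.create_col(name="mem0_demo_v2",  # new index; also rebinds self.schema/self.index
                     vector_size=1024)
    store.reset()                          # delete_col() + create_col() in one call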

View File

@@ -229,3 +229,9 @@ class Supabase(VectorStoreBase):
         records = self.collection.fetch(ids=ids)
         return [[OutputData(id=str(record[0]), score=None, payload=record[2]) for record in records]]
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.create_col(self.embedding_model_dims)

View File

@@ -285,3 +285,10 @@ class UpstashVector(VectorStoreBase):
         - Per-namespace vector and pending vector counts
         """
         return self.client.info()
+
+    def reset(self):
+        """
+        Reset the Upstash Vector index.
+        """
+        self.delete_col()

View File

@@ -620,3 +620,9 @@ class GoogleMatchingEngine(VectorStoreBase):
         logger.debug("Starting similarity search")
         docs_and_scores = self.similarity_search_with_score(query, k, filter)
         return [doc for doc, _ in docs_and_scores]
+
+    def reset(self):
+        """
+        Reset is not supported for Google Matching Engine.
+        """
+        logger.warning("Reset operation is not supported for Google Matching Engine")

View File

@@ -308,3 +308,9 @@ class Weaviate(VectorStoreBase):
             payload["id"] = str(obj.uuid).split("'")[0]
             results.append(OutputData(id=str(obj.uuid).split("'")[0], score=1.0, payload=payload))
         return [results]
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.create_col()