Reset function for VectorDBs (#2584)

This commit is contained in:
Dev Khant
2025-04-25 00:01:53 +05:30
committed by GitHub
parent ff6ae478f1
commit 64c3d34deb
18 changed files with 224 additions and 20 deletions

View File

@@ -1,10 +1,10 @@
import os
import asyncio
import concurrent
import gc
import hashlib
import json
import logging
import os
import uuid
import warnings
from datetime import datetime
@@ -20,7 +20,7 @@ from mem0.configs.prompts import (
get_update_memory_messages,
)
from mem0.memory.base import MemoryBase
from mem0.memory.setup import setup_config, mem0_dir
from mem0.memory.setup import mem0_dir, setup_config
from mem0.memory.storage import SQLiteManager
from mem0.memory.telemetry import capture_event
from mem0.memory.utils import (
@@ -67,7 +67,7 @@ class Memory(MemoryBase):
self.graph = MemoryGraph(self.config)
self.enable_graph = True
self.config.vector_store.config.collection_name = "mem0_migrations"
self.config.vector_store.config.collection_name = "mem0-migrations"
if self.config.vector_store.provider in ["faiss", "qdrant"]:
provider_path = f"migrations_{self.config.vector_store.provider}"
self.config.vector_store.config.path = os.path.join(mem0_dir, provider_path)
@@ -765,13 +765,6 @@ class Memory(MemoryBase):
Recreates the vector store with a new client
"""
logger.warning("Resetting all memories")
self.vector_store.delete_col()
gc.collect()
# Close the client if it has a close method
if hasattr(self.vector_store, 'client') and hasattr(self.vector_store.client, 'close'):
self.vector_store.client.close()
# Close the old connection if possible
if hasattr(self.db, 'connection') and self.db.connection:
@@ -780,10 +773,14 @@ class Memory(MemoryBase):
self.db = SQLiteManager(self.config.history_db_path)
# Create a new vector store with the same configuration
self.vector_store = VectorStoreFactory.create(
self.config.vector_store.provider, self.config.vector_store.config
)
if hasattr(self.vector_store, 'reset'):
self.vector_store = VectorStoreFactory.reset(self.vector_store)
else:
logger.warning("Vector store does not support reset. Skipping.")
self.vector_store.delete_col()
self.vector_store = VectorStoreFactory.create(
self.config.vector_store.provider, self.config.vector_store.config
)
capture_event("mem0.reset", self, {"sync_type": "sync"})
def chat(self, query):

View File

@@ -97,3 +97,9 @@ class VectorStoreFactory:
return vector_store_instance(**config)
else:
raise ValueError(f"Unsupported VectorStore provider: {provider_name}")
@classmethod
def reset(cls, instance):
instance.reset()
return instance

View File

@@ -65,6 +65,8 @@ class AzureAISearch(VectorStoreBase):
hybrid_search (bool): Whether to use hybrid search. Default is False.
vector_filter_mode (Optional[str]): Mode for vector filtering. Default is "preFilter".
"""
self.service_name = service_name
self.api_key = api_key
self.index_name = collection_name
self.collection_name = collection_name
self.embedding_model_dims = embedding_model_dims
@@ -341,3 +343,38 @@ class AzureAISearch(VectorStoreBase):
"""Close the search client when the object is deleted."""
self.search_client.close()
self.index_client.close()
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.index_name}...")
try:
# Close the existing clients
self.search_client.close()
self.index_client.close()
# Delete the collection
self.delete_col()
# Reinitialize the clients
service_endpoint = f"https://{self.service_name}.search.windows.net"
self.search_client = SearchClient(
endpoint=service_endpoint,
index_name=self.index_name,
credential=AzureKeyCredential(self.api_key),
)
self.index_client = SearchIndexClient(
endpoint=service_endpoint,
credential=AzureKeyCredential(self.api_key),
)
# Add user agent
self.search_client._client._config.user_agent_policy.add_user_agent("mem0")
self.index_client._client._config.user_agent_policy.add_user_agent("mem0")
# Create the collection
self.create_col()
except Exception as e:
logger.error(f"Error resetting index {self.index_name}: {e}")
raise

View File

@@ -51,3 +51,8 @@ class VectorStoreBase(ABC):
def list(self, filters=None, limit=None):
"""List all memories."""
pass
@abstractmethod
def reset(self):
"""Reset by delete the collection and recreate it."""
pass

View File

@@ -221,3 +221,9 @@ class ChromaDB(VectorStoreBase):
"""
results = self.collection.get(where=filters, limit=limit)
return [self._parse_output(results)]
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
self.delete_col()
self.collection = self.create_col(self.collection_name)

View File

@@ -222,3 +222,9 @@ class ElasticsearchDB(VectorStoreBase):
)
return [results]
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
self.delete_col()
self.create_index()

View File

@@ -465,3 +465,9 @@ class FAISS(VectorStoreBase):
break
return [results]
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
self.delete_col()
self.create_col(self.collection_name)

View File

@@ -1,3 +1,4 @@
import logging
from typing import Dict, List, Optional
from pydantic import BaseModel
@@ -11,6 +12,7 @@ except ImportError:
from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
id: Optional[str] # memory id
@@ -135,7 +137,14 @@ class Langchain(VectorStoreBase):
"""
Delete a collection.
"""
self.client.delete(ids=None)
logger.warning("Deleting collection")
if hasattr(self.client, "delete_collection"):
self.client.delete_collection()
elif hasattr(self.client, "reset_collection"):
self.client.reset_collection()
else:
# Fallback to the generic delete method
self.client.delete(ids=None)
def col_info(self):
"""
@@ -147,5 +156,27 @@ class Langchain(VectorStoreBase):
"""
List all vectors in a collection.
"""
# This would require implementation-specific access to the underlying store
raise NotImplementedError("Listing all vectors not directly supported by LangChain vectorstores")
try:
if hasattr(self.client, "_collection") and hasattr(self.client._collection, "get"):
# Convert mem0 filters to Chroma where clause if needed
where_clause = None
if filters and "user_id" in filters:
where_clause = {"user_id": filters["user_id"]}
result = self.client._collection.get(
where=where_clause,
limit=limit
)
# Convert the result to the expected format
if result and isinstance(result, dict):
return [self._parse_output(result)]
return []
except Exception as e:
logger.error(f"Error listing vectors from Chroma: {e}")
return []
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting collection: {self.collection_name}")
self.delete_col()

View File

@@ -237,3 +237,9 @@ class MilvusDB(VectorStoreBase):
obj = OutputData(id=data.get("id"), score=None, payload=data.get("metadata"))
memories.append(obj)
return [memories]
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
self.delete_col()
self.create_col(self.collection_name, self.embedding_model_dims, self.metric_type)

View File

@@ -199,3 +199,9 @@ class OpenSearchDB(VectorStoreBase):
for hit in response["hits"]["hits"]
]
]
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
self.delete_col()
self.create_index()

View File

@@ -286,3 +286,9 @@ class PGVector(VectorStoreBase):
self.cur.close()
if hasattr(self, "conn"):
self.conn.close()
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
self.delete_col()
self.create_col(self.embedding_model_dims)

View File

@@ -369,5 +369,6 @@ class PineconeDB(VectorStoreBase):
"""
Reset the index by deleting and recreating it.
"""
logger.warning(f"Resetting index {self.collection_name}...")
self.delete_col()
self.create_col(self.embedding_model_dims, self.metric)

View File

@@ -67,6 +67,7 @@ class Qdrant(VectorStoreBase):
self.collection_name = collection_name
self.embedding_model_dims = embedding_model_dims
self.on_disk = on_disk
self.create_col(embedding_model_dims, on_disk)
def create_col(self, vector_size: int, on_disk: bool, distance: Distance = Distance.COSINE):
@@ -231,3 +232,9 @@ class Qdrant(VectorStoreBase):
with_vectors=False,
)
return result
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
self.delete_col()
self.create_col(self.embedding_model_dims, self.on_disk)

View File

@@ -75,9 +75,48 @@ class RedisDB(VectorStoreBase):
self.index.set_client(self.client)
self.index.create(overwrite=True)
# TODO: Implement multiindex support.
def create_col(self, name, vector_size, distance):
raise NotImplementedError("Collection/Index creation not supported yet.")
def create_col(self, name=None, vector_size=None, distance=None):
"""
Create a new collection (index) in Redis.
Args:
name (str, optional): Name for the collection. Defaults to None, which uses the current collection_name.
vector_size (int, optional): Size of the vector embeddings. Defaults to None, which uses the current embedding_model_dims.
distance (str, optional): Distance metric to use. Defaults to None, which uses 'cosine'.
Returns:
The created index object.
"""
# Use provided parameters or fall back to instance attributes
collection_name = name or self.schema['index']['name']
embedding_dims = vector_size or self.embedding_model_dims
distance_metric = distance or "cosine"
# Create a new schema with the specified parameters
index_schema = {
"name": collection_name,
"prefix": f"mem0:{collection_name}",
}
# Copy the default fields and update the vector field with the specified dimensions
fields = DEFAULT_FIELDS.copy()
fields[-1]["attrs"]["dims"] = embedding_dims
fields[-1]["attrs"]["distance_metric"] = distance_metric
# Create the schema
schema = {"index": index_schema, "fields": fields}
# Create the index
index = SearchIndex.from_dict(schema)
index.set_client(self.client)
index.create(overwrite=True)
# Update instance attributes if creating a new collection
if name:
self.schema = schema
self.index = index
return index
def insert(self, vectors: list, payloads: list = None, ids: list = None):
data = []
@@ -194,6 +233,25 @@ class RedisDB(VectorStoreBase):
def col_info(self, name):
return self.index.info()
def reset(self):
"""
Reset the index by deleting and recreating it.
"""
collection_name = self.schema['index']['name']
logger.warning(f"Resetting index {collection_name}...")
self.delete_col()
self.index = SearchIndex.from_dict(self.schema)
self.index.set_client(self.client)
self.index.create(overwrite=True)
#or use
#self.create_col(collection_name, self.embedding_model_dims)
# Recreate the index with the same parameters
self.create_col(collection_name, self.embedding_model_dims)
def list(self, filters: dict = None, limit: int = None) -> list:
"""
List all recent created memories from the vector store.

View File

@@ -229,3 +229,9 @@ class Supabase(VectorStoreBase):
records = self.collection.fetch(ids=ids)
return [[OutputData(id=str(record[0]), score=None, payload=record[2]) for record in records]]
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
self.delete_col()
self.create_col(self.embedding_model_dims)

View File

@@ -285,3 +285,10 @@ class UpstashVector(VectorStoreBase):
- Per-namespace vector and pending vector counts
"""
return self.client.info()
def reset(self):
"""
Reset the Upstash Vector index.
"""
self.delete_col()

View File

@@ -620,3 +620,10 @@ class GoogleMatchingEngine(VectorStoreBase):
logger.debug("Starting similarity search")
docs_and_scores = self.similarity_search_with_score(query, k, filter)
return [doc for doc, _ in docs_and_scores]
def reset(self):
"""
Reset the Google Matching Engine index.
"""
logger.warning("Reset operation is not supported for Google Matching Engine")
pass

View File

@@ -308,3 +308,9 @@ class Weaviate(VectorStoreBase):
payload["id"] = str(obj.uuid).split("'")[0]
results.append(OutputData(id=str(obj.uuid).split("'")[0], score=1.0, payload=payload))
return [results]
def reset(self):
"""Reset the index by deleting and recreating it."""
logger.warning(f"Resetting index {self.collection_name}...")
self.delete_col()
self.create_col()