Reset function for VectorDBs (#2584)
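This commit gives every vector store backend a reset() method with the same contract: drop the collection, then recreate it empty, so callers can wipe a store without re-instantiating it. A minimal sketch of that contract, using a toy stand-in class rather than any real mem0 backend (ToyStore and its attributes are illustrative only):

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


class ToyStore:
    """Illustrative stand-in for a VectorStoreBase backend; not part of this commit."""

    def __init__(self, collection_name):
        self.collection_name = collection_name
        self.create_col(collection_name)

    def create_col(self, name):
        # A fresh, empty "collection"
        self._data = {}

    def delete_col(self):
        # Drop the collection entirely
        self._data = None

    def insert(self, vectors, payloads=None, ids=None):
        for i, v, p in zip(ids, vectors, payloads):
            self._data[i] = (v, p)

    def reset(self):
        """Reset the index by deleting and recreating it (same shape as the implementations below)."""
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.create_col(self.collection_name)


store = ToyStore("mem0")
store.insert(vectors=[[0.1, 0.2]], payloads=[{"user_id": "alice"}], ids=["m1"])
store.reset()
assert store._data == {}  # the collection exists again, but is empty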
@@ -65,6 +65,8 @@ class AzureAISearch(VectorStoreBase):
            hybrid_search (bool): Whether to use hybrid search. Default is False.
            vector_filter_mode (Optional[str]): Mode for vector filtering. Default is "preFilter".
        """
        self.service_name = service_name
        self.api_key = api_key
        self.index_name = collection_name
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims

@@ -341,3 +343,38 @@ class AzureAISearch(VectorStoreBase):
        """Close the search client when the object is deleted."""
        self.search_client.close()
        self.index_client.close()

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.index_name}...")

        try:
            # Close the existing clients
            self.search_client.close()
            self.index_client.close()

            # Delete the collection
            self.delete_col()

            # Reinitialize the clients
            service_endpoint = f"https://{self.service_name}.search.windows.net"
            self.search_client = SearchClient(
                endpoint=service_endpoint,
                index_name=self.index_name,
                credential=AzureKeyCredential(self.api_key),
            )
            self.index_client = SearchIndexClient(
                endpoint=service_endpoint,
                credential=AzureKeyCredential(self.api_key),
            )

            # Add user agent
            self.search_client._client._config.user_agent_policy.add_user_agent("mem0")
            self.index_client._client._config.user_agent_policy.add_user_agent("mem0")

            # Create the collection
            self.create_col()
        except Exception as e:
            logger.error(f"Error resetting index {self.index_name}: {e}")
            raise

@@ -51,3 +51,8 @@ class VectorStoreBase(ABC):
    def list(self, filters=None, limit=None):
        """List all memories."""
        pass

    @abstractmethod
    def reset(self):
        """Reset by deleting the collection and recreating it."""
        pass

@@ -221,3 +221,9 @@ class ChromaDB(VectorStoreBase):
        """
        results = self.collection.get(where=filters, limit=limit)
        return [self._parse_output(results)]

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.collection = self.create_col(self.collection_name)

@@ -222,3 +222,9 @@ class ElasticsearchDB(VectorStoreBase):
        )

        return [results]

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.create_index()

@@ -465,3 +465,9 @@ class FAISS(VectorStoreBase):
                break

        return [results]

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.create_col(self.collection_name)

@@ -1,3 +1,4 @@
import logging
from typing import Dict, List, Optional

from pydantic import BaseModel

@@ -11,6 +12,7 @@ except ImportError:

from mem0.vector_stores.base import VectorStoreBase

logger = logging.getLogger(__name__)

class OutputData(BaseModel):
    id: Optional[str]  # memory id

@@ -135,7 +137,14 @@ class Langchain(VectorStoreBase):
        """
        Delete a collection.
        """
        # (Previously this method was just `self.client.delete(ids=None)`.)
        logger.warning("Deleting collection")
        if hasattr(self.client, "delete_collection"):
            self.client.delete_collection()
        elif hasattr(self.client, "reset_collection"):
            self.client.reset_collection()
        else:
            # Fallback to the generic delete method
            self.client.delete(ids=None)

    def col_info(self):
        """

@@ -147,5 +156,27 @@ class Langchain(VectorStoreBase):
        """
        List all vectors in a collection.
        """
        # (Previously this method raised NotImplementedError("Listing all vectors not directly
        #  supported by LangChain vectorstores"); it now tries the Chroma-style API.)
        try:
            if hasattr(self.client, "_collection") and hasattr(self.client._collection, "get"):
                # Convert mem0 filters to a Chroma where clause if needed
                where_clause = None
                if filters and "user_id" in filters:
                    where_clause = {"user_id": filters["user_id"]}

                result = self.client._collection.get(
                    where=where_clause,
                    limit=limit,
                )

                # Convert the result to the expected format
                if result and isinstance(result, dict):
                    return [self._parse_output(result)]
            return []
        except Exception as e:
            logger.error(f"Error listing vectors from Chroma: {e}")
            return []

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting collection: {self.collection_name}")
        self.delete_col()

@@ -237,3 +237,9 @@ class MilvusDB(VectorStoreBase):
            obj = OutputData(id=data.get("id"), score=None, payload=data.get("metadata"))
            memories.append(obj)
        return [memories]

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.create_col(self.collection_name, self.embedding_model_dims, self.metric_type)

@@ -199,3 +199,9 @@ class OpenSearchDB(VectorStoreBase):
                for hit in response["hits"]["hits"]
            ]
        ]

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.create_index()

@@ -286,3 +286,9 @@ class PGVector(VectorStoreBase):
            self.cur.close()
        if hasattr(self, "conn"):
            self.conn.close()

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.create_col(self.embedding_model_dims)

@@ -369,5 +369,6 @@ class PineconeDB(VectorStoreBase):
        """
        Reset the index by deleting and recreating it.
        """
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.create_col(self.embedding_model_dims, self.metric)

@@ -67,6 +67,7 @@ class Qdrant(VectorStoreBase):

        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.on_disk = on_disk
        self.create_col(embedding_model_dims, on_disk)

    def create_col(self, vector_size: int, on_disk: bool, distance: Distance = Distance.COSINE):

@@ -231,3 +232,9 @@ class Qdrant(VectorStoreBase):
            with_vectors=False,
        )
        return result

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.create_col(self.embedding_model_dims, self.on_disk)

@@ -75,9 +75,48 @@ class RedisDB(VectorStoreBase):
        self.index.set_client(self.client)
        self.index.create(overwrite=True)

    # TODO: Implement multiindex support.
    # (Previously create_col was a stub that raised
    #  NotImplementedError("Collection/Index creation not supported yet.").)
    def create_col(self, name=None, vector_size=None, distance=None):
        """
        Create a new collection (index) in Redis.

        Args:
            name (str, optional): Name for the collection. Defaults to None, which uses the current index name.
            vector_size (int, optional): Size of the vector embeddings. Defaults to None, which uses the current embedding_model_dims.
            distance (str, optional): Distance metric to use. Defaults to None, which uses "cosine".

        Returns:
            The created index object.
        """
        # Use the provided parameters or fall back to instance attributes
        collection_name = name or self.schema["index"]["name"]
        embedding_dims = vector_size or self.embedding_model_dims
        distance_metric = distance or "cosine"

        # Create a new schema with the specified parameters
        index_schema = {
            "name": collection_name,
            "prefix": f"mem0:{collection_name}",
        }

        # Copy the default fields and update the vector field with the specified dimensions
        fields = DEFAULT_FIELDS.copy()
        fields[-1]["attrs"]["dims"] = embedding_dims
        fields[-1]["attrs"]["distance_metric"] = distance_metric

        # Build the schema and create the index
        schema = {"index": index_schema, "fields": fields}
        index = SearchIndex.from_dict(schema)
        index.set_client(self.client)
        index.create(overwrite=True)

        # Update instance attributes when creating a new collection
        if name:
            self.schema = schema
            self.index = index

        return index

    def insert(self, vectors: list, payloads: list = None, ids: list = None):
        data = []

@@ -194,6 +233,25 @@ class RedisDB(VectorStoreBase):
    def col_info(self, name):
        return self.index.info()

    def reset(self):
        """
        Reset the index by deleting and recreating it.
        """
        collection_name = self.schema["index"]["name"]
        logger.warning(f"Resetting index {collection_name}...")
        self.delete_col()

        # Recreate the index with the same parameters
        self.index = SearchIndex.from_dict(self.schema)
        self.index.set_client(self.client)
        self.index.create(overwrite=True)

    def list(self, filters: dict = None, limit: int = None) -> list:
        """
        List all recently created memories from the vector store.

@@ -229,3 +229,9 @@ class Supabase(VectorStoreBase):
        records = self.collection.fetch(ids=ids)

        return [[OutputData(id=str(record[0]), score=None, payload=record[2]) for record in records]]

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.create_col(self.embedding_model_dims)

@@ -285,3 +285,10 @@ class UpstashVector(VectorStoreBase):
            - Per-namespace vector and pending vector counts
        """
        return self.client.info()

    def reset(self):
        """
        Reset the Upstash Vector index.
        """
        self.delete_col()

@@ -620,3 +620,10 @@ class GoogleMatchingEngine(VectorStoreBase):
        logger.debug("Starting similarity search")
        docs_and_scores = self.similarity_search_with_score(query, k, filter)
        return [doc for doc, _ in docs_and_scores]

    def reset(self):
        """
        Reset the Google Matching Engine index.
        """
        logger.warning("Reset operation is not supported for Google Matching Engine")
        pass

@@ -308,3 +308,9 @@ class Weaviate(VectorStoreBase):
            payload["id"] = str(obj.uuid).split("'")[0]
            results.append(OutputData(id=str(obj.uuid).split("'")[0], score=1.0, payload=payload))
        return [results]

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.create_col()
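Taken together, most backends implement reset() as delete_col() followed by create_col(); the exceptions in this diff are UpstashVector and Langchain (delete only) and GoogleMatchingEngine (logs that reset is unsupported), while AzureAISearch also rebuilds its SDK clients. At the public API level this is what makes a full wipe cheap regardless of provider. A hedged usage sketch: it assumes the default mem0 configuration (a local vector store plus an embedding provider key available in the environment) and that Memory.reset() routes through the configured store's reset path.

from mem0 import Memory

m = Memory()  # default configuration; swap in your preferred vector_store provider
m.add("Alice prefers window seats", user_id="alice")

m.reset()  # wipes stored memories; the underlying vector store is deleted and recreated

print(m.get_all(user_id="alice"))  # expected to contain no memories after the reset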