Reset function for VectorDBs (#2584)
@@ -1,10 +1,10 @@
-import os
 import asyncio
 import concurrent
 import gc
 import hashlib
 import json
 import logging
+import os
 import uuid
 import warnings
 from datetime import datetime

@@ -20,7 +20,7 @@ from mem0.configs.prompts import (
     get_update_memory_messages,
 )
 from mem0.memory.base import MemoryBase
-from mem0.memory.setup import setup_config, mem0_dir
+from mem0.memory.setup import mem0_dir, setup_config
 from mem0.memory.storage import SQLiteManager
 from mem0.memory.telemetry import capture_event
 from mem0.memory.utils import (

@@ -67,7 +67,7 @@ class Memory(MemoryBase):
             self.graph = MemoryGraph(self.config)
             self.enable_graph = True
 
-        self.config.vector_store.config.collection_name = "mem0_migrations"
+        self.config.vector_store.config.collection_name = "mem0-migrations"
         if self.config.vector_store.provider in ["faiss", "qdrant"]:
             provider_path = f"migrations_{self.config.vector_store.provider}"
             self.config.vector_store.config.path = os.path.join(mem0_dir, provider_path)

@@ -765,13 +765,6 @@ class Memory(MemoryBase):
         Recreates the vector store with a new client
         """
         logger.warning("Resetting all memories")
-        self.vector_store.delete_col()
-
-        gc.collect()
-
-        # Close the client if it has a close method
-        if hasattr(self.vector_store, 'client') and hasattr(self.vector_store.client, 'close'):
-            self.vector_store.client.close()
 
         # Close the old connection if possible
         if hasattr(self.db, 'connection') and self.db.connection:

@@ -780,7 +773,11 @@ class Memory(MemoryBase):
 
         self.db = SQLiteManager(self.config.history_db_path)
 
-        # Create a new vector store with the same configuration
-        self.vector_store = VectorStoreFactory.create(
-            self.config.vector_store.provider, self.config.vector_store.config
-        )
+        if hasattr(self.vector_store, 'reset'):
+            self.vector_store = VectorStoreFactory.reset(self.vector_store)
+        else:
+            logger.warning("Vector store does not support reset. Skipping.")
+            self.vector_store.delete_col()
+            self.vector_store = VectorStoreFactory.create(
+                self.config.vector_store.provider, self.config.vector_store.config
+            )

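An illustrative usage sketch of the reset path after this change (not part of the commit; it assumes a default Memory() configuration with working LLM/embedder credentials in the environment):

    from mem0 import Memory

    m = Memory()  # assumed: default provider config resolvable from the environment
    m.add("Alice likes tennis", user_id="alice")
    m.reset()     # history DB is recreated; the vector store is reset in place via its own reset() when available
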
@@ -97,3 +97,9 @@ class VectorStoreFactory:
             return vector_store_instance(**config)
         else:
             raise ValueError(f"Unsupported VectorStore provider: {provider_name}")
+
+    @classmethod
+    def reset(cls, instance):
+        instance.reset()
+        return instance
+

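The new classmethod is a thin pass-through: it calls the instance's own reset() and hands the same object back. An illustrative call site (hedged: the "qdrant" provider name and the `config` value are assumptions, not taken from this commit):

    store = VectorStoreFactory.create("qdrant", config)  # config: the provider's config object (assumed)
    store = VectorStoreFactory.reset(store)              # collection dropped and recreated; same instance returned
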
@@ -65,6 +65,8 @@ class AzureAISearch(VectorStoreBase):
             hybrid_search (bool): Whether to use hybrid search. Default is False.
             vector_filter_mode (Optional[str]): Mode for vector filtering. Default is "preFilter".
         """
+        self.service_name = service_name
+        self.api_key = api_key
         self.index_name = collection_name
         self.collection_name = collection_name
         self.embedding_model_dims = embedding_model_dims

@@ -341,3 +343,38 @@ class AzureAISearch(VectorStoreBase):
         """Close the search client when the object is deleted."""
         self.search_client.close()
         self.index_client.close()
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.index_name}...")
+
+        try:
+            # Close the existing clients
+            self.search_client.close()
+            self.index_client.close()
+
+            # Delete the collection
+            self.delete_col()
+
+            # Reinitialize the clients
+            service_endpoint = f"https://{self.service_name}.search.windows.net"
+            self.search_client = SearchClient(
+                endpoint=service_endpoint,
+                index_name=self.index_name,
+                credential=AzureKeyCredential(self.api_key),
+            )
+            self.index_client = SearchIndexClient(
+                endpoint=service_endpoint,
+                credential=AzureKeyCredential(self.api_key),
+            )
+
+            # Add user agent
+            self.search_client._client._config.user_agent_policy.add_user_agent("mem0")
+            self.index_client._client._config.user_agent_policy.add_user_agent("mem0")
+
+            # Create the collection
+            self.create_col()
+        except Exception as e:
+            logger.error(f"Error resetting index {self.index_name}: {e}")
+            raise
+

@@ -51,3 +51,8 @@ class VectorStoreBase(ABC):
     def list(self, filters=None, limit=None):
         """List all memories."""
         pass
+
+    @abstractmethod
+    def reset(self):
+        """Reset by deleting the collection and recreating it."""
+        pass

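Because reset() is declared with @abstractmethod, every concrete store must now implement it, and the hunks below all follow the same delete-then-recreate pattern. A minimal sketch of a conforming subclass (the class MyStore is hypothetical and the remaining abstract methods are omitted, so it is not instantiable as written):

    import logging

    from mem0.vector_stores.base import VectorStoreBase

    logger = logging.getLogger(__name__)


    class MyStore(VectorStoreBase):
        ...  # create_col, insert, search, delete, delete_col, etc. would be implemented here

        def reset(self):
            """Reset the index by deleting and recreating it."""
            logger.warning(f"Resetting index {self.collection_name}...")  # collection_name assumed set in __init__
            self.delete_col()
            self.create_col(self.collection_name)
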
@@ -221,3 +221,9 @@ class ChromaDB(VectorStoreBase):
         """
         results = self.collection.get(where=filters, limit=limit)
         return [self._parse_output(results)]
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.collection = self.create_col(self.collection_name)

@@ -222,3 +222,9 @@ class ElasticsearchDB(VectorStoreBase):
         )
 
         return [results]
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.create_index()

@@ -465,3 +465,9 @@ class FAISS(VectorStoreBase):
                 break
 
         return [results]
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.create_col(self.collection_name)

@@ -1,3 +1,4 @@
+import logging
 from typing import Dict, List, Optional
 
 from pydantic import BaseModel

@@ -11,6 +12,7 @@ except ImportError:
 
 from mem0.vector_stores.base import VectorStoreBase
 
+logger = logging.getLogger(__name__)
 
 class OutputData(BaseModel):
     id: Optional[str]  # memory id

@@ -135,6 +137,13 @@ class Langchain(VectorStoreBase):
         """
         Delete a collection.
         """
-        self.client.delete(ids=None)
+        logger.warning("Deleting collection")
+        if hasattr(self.client, "delete_collection"):
+            self.client.delete_collection()
+        elif hasattr(self.client, "reset_collection"):
+            self.client.reset_collection()
+        else:
+            # Fallback to the generic delete method
+            self.client.delete(ids=None)
 
     def col_info(self):

@@ -147,5 +156,27 @@ class Langchain(VectorStoreBase):
         """
         List all vectors in a collection.
         """
-        # This would require implementation-specific access to the underlying store
-        raise NotImplementedError("Listing all vectors not directly supported by LangChain vectorstores")
+        try:
+            if hasattr(self.client, "_collection") and hasattr(self.client._collection, "get"):
+                # Convert mem0 filters to Chroma where clause if needed
+                where_clause = None
+                if filters and "user_id" in filters:
+                    where_clause = {"user_id": filters["user_id"]}
+
+                result = self.client._collection.get(
+                    where=where_clause,
+                    limit=limit
+                )
+
+                # Convert the result to the expected format
+                if result and isinstance(result, dict):
+                    return [self._parse_output(result)]
+            return []
+        except Exception as e:
+            logger.error(f"Error listing vectors from Chroma: {e}")
+            return []
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting collection: {self.collection_name}")
+        self.delete_col()

@@ -237,3 +237,9 @@ class MilvusDB(VectorStoreBase):
             obj = OutputData(id=data.get("id"), score=None, payload=data.get("metadata"))
             memories.append(obj)
         return [memories]
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.create_col(self.collection_name, self.embedding_model_dims, self.metric_type)

@@ -199,3 +199,9 @@ class OpenSearchDB(VectorStoreBase):
                 for hit in response["hits"]["hits"]
             ]
         ]
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.create_index()

@@ -286,3 +286,9 @@ class PGVector(VectorStoreBase):
             self.cur.close()
         if hasattr(self, "conn"):
             self.conn.close()
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.create_col(self.embedding_model_dims)

@@ -369,5 +369,6 @@ class PineconeDB(VectorStoreBase):
         """
         Reset the index by deleting and recreating it.
         """
+        logger.warning(f"Resetting index {self.collection_name}...")
         self.delete_col()
         self.create_col(self.embedding_model_dims, self.metric)

@@ -67,6 +67,7 @@ class Qdrant(VectorStoreBase):
 
         self.collection_name = collection_name
         self.embedding_model_dims = embedding_model_dims
+        self.on_disk = on_disk
         self.create_col(embedding_model_dims, on_disk)
 
     def create_col(self, vector_size: int, on_disk: bool, distance: Distance = Distance.COSINE):

@@ -231,3 +232,9 @@ class Qdrant(VectorStoreBase):
             with_vectors=False,
         )
         return result
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.create_col(self.embedding_model_dims, self.on_disk)

@@ -75,9 +75,48 @@ class RedisDB(VectorStoreBase):
         self.index.set_client(self.client)
         self.index.create(overwrite=True)
 
-    # TODO: Implement multiindex support.
-    def create_col(self, name, vector_size, distance):
-        raise NotImplementedError("Collection/Index creation not supported yet.")
+    def create_col(self, name=None, vector_size=None, distance=None):
+        """
+        Create a new collection (index) in Redis.
+
+        Args:
+            name (str, optional): Name for the collection. Defaults to None, which uses the current collection_name.
+            vector_size (int, optional): Size of the vector embeddings. Defaults to None, which uses the current embedding_model_dims.
+            distance (str, optional): Distance metric to use. Defaults to None, which uses 'cosine'.
+
+        Returns:
+            The created index object.
+        """
+        # Use provided parameters or fall back to instance attributes
+        collection_name = name or self.schema['index']['name']
+        embedding_dims = vector_size or self.embedding_model_dims
+        distance_metric = distance or "cosine"
+
+        # Create a new schema with the specified parameters
+        index_schema = {
+            "name": collection_name,
+            "prefix": f"mem0:{collection_name}",
+        }
+
+        # Copy the default fields and update the vector field with the specified dimensions
+        fields = DEFAULT_FIELDS.copy()
+        fields[-1]["attrs"]["dims"] = embedding_dims
+        fields[-1]["attrs"]["distance_metric"] = distance_metric
+
+        # Create the schema
+        schema = {"index": index_schema, "fields": fields}
+
+        # Create the index
+        index = SearchIndex.from_dict(schema)
+        index.set_client(self.client)
+        index.create(overwrite=True)
+
+        # Update instance attributes if creating a new collection
+        if name:
+            self.schema = schema
+            self.index = index
+
+        return index
 
     def insert(self, vectors: list, payloads: list = None, ids: list = None):
         data = []

@@ -194,6 +233,25 @@ class RedisDB(VectorStoreBase):
     def col_info(self, name):
         return self.index.info()
 
+    def reset(self):
+        """
+        Reset the index by deleting and recreating it.
+        """
+        collection_name = self.schema['index']['name']
+        logger.warning(f"Resetting index {collection_name}...")
+        self.delete_col()
+
+        self.index = SearchIndex.from_dict(self.schema)
+        self.index.set_client(self.client)
+        self.index.create(overwrite=True)
+
+        # or use
+        # self.create_col(collection_name, self.embedding_model_dims)
+
+
+        # Recreate the index with the same parameters
+        self.create_col(collection_name, self.embedding_model_dims)
+
     def list(self, filters: dict = None, limit: int = None) -> list:
         """
         List all recent created memories from the vector store.

@@ -229,3 +229,9 @@ class Supabase(VectorStoreBase):
         records = self.collection.fetch(ids=ids)
 
         return [[OutputData(id=str(record[0]), score=None, payload=record[2]) for record in records]]
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.create_col(self.embedding_model_dims)

@@ -285,3 +285,10 @@ class UpstashVector(VectorStoreBase):
             - Per-namespace vector and pending vector counts
         """
         return self.client.info()
+
+    def reset(self):
+        """
+        Reset the Upstash Vector index.
+        """
+        self.delete_col()
+

@@ -620,3 +620,10 @@ class GoogleMatchingEngine(VectorStoreBase):
         logger.debug("Starting similarity search")
         docs_and_scores = self.similarity_search_with_score(query, k, filter)
         return [doc for doc, _ in docs_and_scores]
+
+    def reset(self):
+        """
+        Reset the Google Matching Engine index.
+        """
+        logger.warning("Reset operation is not supported for Google Matching Engine")
+        pass

@@ -308,3 +308,9 @@ class Weaviate(VectorStoreBase):
             payload["id"] = str(obj.uuid).split("'")[0]
             results.append(OutputData(id=str(obj.uuid).split("'")[0], score=1.0, payload=payload))
         return [results]
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.create_col()