diff --git a/mem0/memory/main.py b/mem0/memory/main.py
index 5b35a40f..9ff40f08 100644
--- a/mem0/memory/main.py
+++ b/mem0/memory/main.py
@@ -1,10 +1,10 @@
-import os
 import asyncio
 import concurrent
 import gc
 import hashlib
 import json
 import logging
+import os
 import uuid
 import warnings
 from datetime import datetime
@@ -20,7 +20,7 @@ from mem0.configs.prompts import (
     get_update_memory_messages,
 )
 from mem0.memory.base import MemoryBase
-from mem0.memory.setup import setup_config, mem0_dir
+from mem0.memory.setup import mem0_dir, setup_config
 from mem0.memory.storage import SQLiteManager
 from mem0.memory.telemetry import capture_event
 from mem0.memory.utils import (
@@ -67,7 +67,7 @@ class Memory(MemoryBase):
             self.graph = MemoryGraph(self.config)
             self.enable_graph = True
 
-        self.config.vector_store.config.collection_name = "mem0_migrations"
+        self.config.vector_store.config.collection_name = "mem0-migrations"
         if self.config.vector_store.provider in ["faiss", "qdrant"]:
             provider_path = f"migrations_{self.config.vector_store.provider}"
             self.config.vector_store.config.path = os.path.join(mem0_dir, provider_path)
@@ -765,13 +765,6 @@ class Memory(MemoryBase):
         Recreates the vector store with a new client
         """
         logger.warning("Resetting all memories")
-        self.vector_store.delete_col()
-
-        gc.collect()
-
-        # Close the client if it has a close method
-        if hasattr(self.vector_store, 'client') and hasattr(self.vector_store.client, 'close'):
-            self.vector_store.client.close()
 
         # Close the old connection if possible
         if hasattr(self.db, 'connection') and self.db.connection:
@@ -780,10 +773,14 @@
 
         self.db = SQLiteManager(self.config.history_db_path)
 
-        # Create a new vector store with the same configuration
-        self.vector_store = VectorStoreFactory.create(
-            self.config.vector_store.provider, self.config.vector_store.config
-        )
+        if hasattr(self.vector_store, 'reset'):
+            self.vector_store = VectorStoreFactory.reset(self.vector_store)
+        else:
+            logger.warning("Vector store does not support reset; deleting and recreating the collection instead.")
+            self.vector_store.delete_col()
+            self.vector_store = VectorStoreFactory.create(
+                self.config.vector_store.provider, self.config.vector_store.config
+            )
         capture_event("mem0.reset", self, {"sync_type": "sync"})
 
     def chat(self, query):
diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py
index 8d48c3ae..03fc0d2c 100644
--- a/mem0/utils/factory.py
+++ b/mem0/utils/factory.py
@@ -97,3 +97,9 @@ class VectorStoreFactory:
             return vector_store_instance(**config)
         else:
             raise ValueError(f"Unsupported VectorStore provider: {provider_name}")
+
+    @classmethod
+    def reset(cls, instance):
+        instance.reset()
+        return instance
+
diff --git a/mem0/vector_stores/azure_ai_search.py b/mem0/vector_stores/azure_ai_search.py
index 99c53454..efcd528a 100644
--- a/mem0/vector_stores/azure_ai_search.py
+++ b/mem0/vector_stores/azure_ai_search.py
@@ -65,6 +65,8 @@ class AzureAISearch(VectorStoreBase):
             hybrid_search (bool): Whether to use hybrid search. Default is False.
             vector_filter_mode (Optional[str]): Mode for vector filtering. Default is "preFilter".
""" + self.service_name = service_name + self.api_key = api_key self.index_name = collection_name self.collection_name = collection_name self.embedding_model_dims = embedding_model_dims @@ -341,3 +343,38 @@ class AzureAISearch(VectorStoreBase): """Close the search client when the object is deleted.""" self.search_client.close() self.index_client.close() + + def reset(self): + """Reset the index by deleting and recreating it.""" + logger.warning(f"Resetting index {self.index_name}...") + + try: + # Close the existing clients + self.search_client.close() + self.index_client.close() + + # Delete the collection + self.delete_col() + + # Reinitialize the clients + service_endpoint = f"https://{self.service_name}.search.windows.net" + self.search_client = SearchClient( + endpoint=service_endpoint, + index_name=self.index_name, + credential=AzureKeyCredential(self.api_key), + ) + self.index_client = SearchIndexClient( + endpoint=service_endpoint, + credential=AzureKeyCredential(self.api_key), + ) + + # Add user agent + self.search_client._client._config.user_agent_policy.add_user_agent("mem0") + self.index_client._client._config.user_agent_policy.add_user_agent("mem0") + + # Create the collection + self.create_col() + except Exception as e: + logger.error(f"Error resetting index {self.index_name}: {e}") + raise + diff --git a/mem0/vector_stores/base.py b/mem0/vector_stores/base.py index 4f55d109..98f3503b 100644 --- a/mem0/vector_stores/base.py +++ b/mem0/vector_stores/base.py @@ -51,3 +51,8 @@ class VectorStoreBase(ABC): def list(self, filters=None, limit=None): """List all memories.""" pass + + @abstractmethod + def reset(self): + """Reset by delete the collection and recreate it.""" + pass diff --git a/mem0/vector_stores/chroma.py b/mem0/vector_stores/chroma.py index 696a3047..ae4b03f7 100644 --- a/mem0/vector_stores/chroma.py +++ b/mem0/vector_stores/chroma.py @@ -221,3 +221,9 @@ class ChromaDB(VectorStoreBase): """ results = self.collection.get(where=filters, limit=limit) return [self._parse_output(results)] + + def reset(self): + """Reset the index by deleting and recreating it.""" + logger.warning(f"Resetting index {self.collection_name}...") + self.delete_col() + self.collection = self.create_col(self.collection_name) diff --git a/mem0/vector_stores/elasticsearch.py b/mem0/vector_stores/elasticsearch.py index 7b7cc2b4..4d733a45 100644 --- a/mem0/vector_stores/elasticsearch.py +++ b/mem0/vector_stores/elasticsearch.py @@ -222,3 +222,9 @@ class ElasticsearchDB(VectorStoreBase): ) return [results] + + def reset(self): + """Reset the index by deleting and recreating it.""" + logger.warning(f"Resetting index {self.collection_name}...") + self.delete_col() + self.create_index() diff --git a/mem0/vector_stores/faiss.py b/mem0/vector_stores/faiss.py index 623a3bd7..c042738c 100644 --- a/mem0/vector_stores/faiss.py +++ b/mem0/vector_stores/faiss.py @@ -465,3 +465,9 @@ class FAISS(VectorStoreBase): break return [results] + + def reset(self): + """Reset the index by deleting and recreating it.""" + logger.warning(f"Resetting index {self.collection_name}...") + self.delete_col() + self.create_col(self.collection_name) diff --git a/mem0/vector_stores/langchain.py b/mem0/vector_stores/langchain.py index ad9a991d..aac04f06 100644 --- a/mem0/vector_stores/langchain.py +++ b/mem0/vector_stores/langchain.py @@ -1,3 +1,4 @@ +import logging from typing import Dict, List, Optional from pydantic import BaseModel @@ -11,6 +12,7 @@ except ImportError: from mem0.vector_stores.base import VectorStoreBase +logger = 
 
 class OutputData(BaseModel):
     id: Optional[str]  # memory id
@@ -135,7 +137,14 @@ class Langchain(VectorStoreBase):
         """
         Delete a collection.
         """
-        self.client.delete(ids=None)
+        logger.warning("Deleting collection")
+        if hasattr(self.client, "delete_collection"):
+            self.client.delete_collection()
+        elif hasattr(self.client, "reset_collection"):
+            self.client.reset_collection()
+        else:
+            # Fallback to the generic delete method
+            self.client.delete(ids=None)
 
     def col_info(self):
         """
@@ -147,5 +156,27 @@
         """
         List all vectors in a collection.
         """
-        # This would require implementation-specific access to the underlying store
-        raise NotImplementedError("Listing all vectors not directly supported by LangChain vectorstores")
+        try:
+            if hasattr(self.client, "_collection") and hasattr(self.client._collection, "get"):
+                # Convert mem0 filters to Chroma where clause if needed
+                where_clause = None
+                if filters and "user_id" in filters:
+                    where_clause = {"user_id": filters["user_id"]}
+
+                result = self.client._collection.get(
+                    where=where_clause,
+                    limit=limit
+                )
+
+                # Convert the result to the expected format
+                if result and isinstance(result, dict):
+                    return [self._parse_output(result)]
+            return []
+        except Exception as e:
+            logger.error(f"Error listing vectors from Chroma: {e}")
+            return []
+
+    def reset(self):
+        """Reset the collection by deleting it; recreation is left to the underlying client."""
+        logger.warning(f"Resetting collection: {self.collection_name}")
+        self.delete_col()
diff --git a/mem0/vector_stores/milvus.py b/mem0/vector_stores/milvus.py
index ff48e306..775006ff 100644
--- a/mem0/vector_stores/milvus.py
+++ b/mem0/vector_stores/milvus.py
@@ -237,3 +237,9 @@ class MilvusDB(VectorStoreBase):
             obj = OutputData(id=data.get("id"), score=None, payload=data.get("metadata"))
             memories.append(obj)
         return [memories]
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.create_col(self.collection_name, self.embedding_model_dims, self.metric_type)
diff --git a/mem0/vector_stores/opensearch.py b/mem0/vector_stores/opensearch.py
index 18b0063c..72d39976 100644
--- a/mem0/vector_stores/opensearch.py
+++ b/mem0/vector_stores/opensearch.py
@@ -199,3 +199,9 @@ class OpenSearchDB(VectorStoreBase):
                 for hit in response["hits"]["hits"]
             ]
         ]
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.create_index()
diff --git a/mem0/vector_stores/pgvector.py b/mem0/vector_stores/pgvector.py
index 85ff3e14..cc33077f 100644
--- a/mem0/vector_stores/pgvector.py
+++ b/mem0/vector_stores/pgvector.py
@@ -286,3 +286,9 @@ class PGVector(VectorStoreBase):
         self.cur.close()
         if hasattr(self, "conn"):
             self.conn.close()
+
+    def reset(self):
+        """Reset the index by deleting and recreating it."""
+        logger.warning(f"Resetting index {self.collection_name}...")
+        self.delete_col()
+        self.create_col(self.embedding_model_dims)
diff --git a/mem0/vector_stores/pinecone.py b/mem0/vector_stores/pinecone.py
index 1cb0be5a..ff350bb4 100644
--- a/mem0/vector_stores/pinecone.py
+++ b/mem0/vector_stores/pinecone.py
@@ -369,5 +369,6 @@ class PineconeDB(VectorStoreBase):
         """
         Reset the index by deleting and recreating it.
""" + logger.warning(f"Resetting index {self.collection_name}...") self.delete_col() self.create_col(self.embedding_model_dims, self.metric) diff --git a/mem0/vector_stores/qdrant.py b/mem0/vector_stores/qdrant.py index d50054ca..3703878c 100644 --- a/mem0/vector_stores/qdrant.py +++ b/mem0/vector_stores/qdrant.py @@ -67,6 +67,7 @@ class Qdrant(VectorStoreBase): self.collection_name = collection_name self.embedding_model_dims = embedding_model_dims + self.on_disk = on_disk self.create_col(embedding_model_dims, on_disk) def create_col(self, vector_size: int, on_disk: bool, distance: Distance = Distance.COSINE): @@ -231,3 +232,9 @@ class Qdrant(VectorStoreBase): with_vectors=False, ) return result + + def reset(self): + """Reset the index by deleting and recreating it.""" + logger.warning(f"Resetting index {self.collection_name}...") + self.delete_col() + self.create_col(self.embedding_model_dims, self.on_disk) diff --git a/mem0/vector_stores/redis.py b/mem0/vector_stores/redis.py index 25cdcbb5..293d69ec 100644 --- a/mem0/vector_stores/redis.py +++ b/mem0/vector_stores/redis.py @@ -75,9 +75,48 @@ class RedisDB(VectorStoreBase): self.index.set_client(self.client) self.index.create(overwrite=True) - # TODO: Implement multiindex support. - def create_col(self, name, vector_size, distance): - raise NotImplementedError("Collection/Index creation not supported yet.") + def create_col(self, name=None, vector_size=None, distance=None): + """ + Create a new collection (index) in Redis. + + Args: + name (str, optional): Name for the collection. Defaults to None, which uses the current collection_name. + vector_size (int, optional): Size of the vector embeddings. Defaults to None, which uses the current embedding_model_dims. + distance (str, optional): Distance metric to use. Defaults to None, which uses 'cosine'. + + Returns: + The created index object. + """ + # Use provided parameters or fall back to instance attributes + collection_name = name or self.schema['index']['name'] + embedding_dims = vector_size or self.embedding_model_dims + distance_metric = distance or "cosine" + + # Create a new schema with the specified parameters + index_schema = { + "name": collection_name, + "prefix": f"mem0:{collection_name}", + } + + # Copy the default fields and update the vector field with the specified dimensions + fields = DEFAULT_FIELDS.copy() + fields[-1]["attrs"]["dims"] = embedding_dims + fields[-1]["attrs"]["distance_metric"] = distance_metric + + # Create the schema + schema = {"index": index_schema, "fields": fields} + + # Create the index + index = SearchIndex.from_dict(schema) + index.set_client(self.client) + index.create(overwrite=True) + + # Update instance attributes if creating a new collection + if name: + self.schema = schema + self.index = index + + return index def insert(self, vectors: list, payloads: list = None, ids: list = None): data = [] @@ -194,6 +233,25 @@ class RedisDB(VectorStoreBase): def col_info(self, name): return self.index.info() + def reset(self): + """ + Reset the index by deleting and recreating it. 
+ """ + collection_name = self.schema['index']['name'] + logger.warning(f"Resetting index {collection_name}...") + self.delete_col() + + self.index = SearchIndex.from_dict(self.schema) + self.index.set_client(self.client) + self.index.create(overwrite=True) + + #or use + #self.create_col(collection_name, self.embedding_model_dims) + + + # Recreate the index with the same parameters + self.create_col(collection_name, self.embedding_model_dims) + def list(self, filters: dict = None, limit: int = None) -> list: """ List all recent created memories from the vector store. diff --git a/mem0/vector_stores/supabase.py b/mem0/vector_stores/supabase.py index 65ea3684..9d0053d1 100644 --- a/mem0/vector_stores/supabase.py +++ b/mem0/vector_stores/supabase.py @@ -229,3 +229,9 @@ class Supabase(VectorStoreBase): records = self.collection.fetch(ids=ids) return [[OutputData(id=str(record[0]), score=None, payload=record[2]) for record in records]] + + def reset(self): + """Reset the index by deleting and recreating it.""" + logger.warning(f"Resetting index {self.collection_name}...") + self.delete_col() + self.create_col(self.embedding_model_dims) diff --git a/mem0/vector_stores/upstash_vector.py b/mem0/vector_stores/upstash_vector.py index 66a010db..6d9b6a06 100644 --- a/mem0/vector_stores/upstash_vector.py +++ b/mem0/vector_stores/upstash_vector.py @@ -285,3 +285,10 @@ class UpstashVector(VectorStoreBase): - Per-namespace vector and pending vector counts """ return self.client.info() + + def reset(self): + """ + Reset the Upstash Vector index. + """ + self.delete_col() + diff --git a/mem0/vector_stores/vertex_ai_vector_search.py b/mem0/vector_stores/vertex_ai_vector_search.py index 6f526584..39aa9923 100644 --- a/mem0/vector_stores/vertex_ai_vector_search.py +++ b/mem0/vector_stores/vertex_ai_vector_search.py @@ -620,3 +620,10 @@ class GoogleMatchingEngine(VectorStoreBase): logger.debug("Starting similarity search") docs_and_scores = self.similarity_search_with_score(query, k, filter) return [doc for doc, _ in docs_and_scores] + + def reset(self): + """ + Reset the Google Matching Engine index. + """ + logger.warning("Reset operation is not supported for Google Matching Engine") + pass diff --git a/mem0/vector_stores/weaviate.py b/mem0/vector_stores/weaviate.py index cdf20dee..b759780b 100644 --- a/mem0/vector_stores/weaviate.py +++ b/mem0/vector_stores/weaviate.py @@ -308,3 +308,9 @@ class Weaviate(VectorStoreBase): payload["id"] = str(obj.uuid).split("'")[0] results.append(OutputData(id=str(obj.uuid).split("'")[0], score=1.0, payload=payload)) return [results] + + def reset(self): + """Reset the index by deleting and recreating it.""" + logger.warning(f"Resetting index {self.collection_name}...") + self.delete_col() + self.create_col()