From 9ae23f9c88bc3ccecefc694de3d580924f7e5df7 Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Sat, 29 Mar 2025 13:35:36 +0530 Subject: [PATCH] Add Faiss Support (#2461) --- Makefile | 2 +- docs/components/vectordbs/dbs/faiss.mdx | 72 ++++ docs/components/vectordbs/overview.mdx | 1 + mem0/configs/vector_stores/faiss.py | 38 ++ mem0/utils/factory.py | 1 + mem0/vector_stores/configs.py | 1 + mem0/vector_stores/faiss.py | 464 ++++++++++++++++++++++++ pyproject.toml | 2 +- tests/vector_stores/test_faiss.py | 314 ++++++++++++++++ 9 files changed, 893 insertions(+), 2 deletions(-) create mode 100644 docs/components/vectordbs/dbs/faiss.mdx create mode 100644 mem0/configs/vector_stores/faiss.py create mode 100644 mem0/vector_stores/faiss.py create mode 100644 tests/vector_stores/test_faiss.py diff --git a/Makefile b/Makefile index 3fb584f3..76e757b1 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ install: install_all: poetry install poetry run pip install groq together boto3 litellm ollama chromadb weaviate weaviate-client sentence_transformers vertexai \ - google-generativeai elasticsearch opensearch-py vecs pinecone pinecone-text + google-generativeai elasticsearch opensearch-py vecs pinecone pinecone-text faiss-cpu # Format code with ruff format: diff --git a/docs/components/vectordbs/dbs/faiss.mdx b/docs/components/vectordbs/dbs/faiss.mdx new file mode 100644 index 00000000..19daddab --- /dev/null +++ b/docs/components/vectordbs/dbs/faiss.mdx @@ -0,0 +1,72 @@ +[FAISS](https://github.com/facebookresearch/faiss) is a library for efficient similarity search and clustering of dense vectors. It is designed to work with large-scale datasets and provides a high-performance search engine for vector data. FAISS is optimized for memory usage and search speed, making it an excellent choice for production environments. + +### Usage + +```python +import os +from mem0 import Memory + +os.environ["OPENAI_API_KEY"] = "sk-xx" + +config = { + "vector_store": { + "provider": "faiss", + "config": { + "collection_name": "test", + "path": "/tmp/faiss_memories", + "distance_strategy": "euclidean" + } + } +} + +m = Memory.from_config(config) +messages = [ + {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"}, + {"role": "assistant", "content": "How about a thriller movies? They can be quite engaging."}, + {"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."}, + {"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."} +] +m.add(messages, user_id="alice", metadata={"category": "movies"}) +``` + +### Installation + +To use FAISS in your mem0 project, you need to install the appropriate FAISS package for your environment: + +```bash +# For CPU version +pip install faiss-cpu + +# For GPU version (requires CUDA) +pip install faiss-gpu +``` + +### Config + +Here are the parameters available for configuring FAISS: + +| Parameter | Description | Default Value | +| --- | --- | --- | +| `collection_name` | The name of the collection | `mem0` | +| `path` | Path to store FAISS index and metadata | `/tmp/faiss/` | +| `distance_strategy` | Distance metric strategy to use (options: 'euclidean', 'inner_product', 'cosine') | `euclidean` | +| `normalize_L2` | Whether to normalize L2 vectors (only applicable for euclidean distance) | `False` | + +### Performance Considerations + +FAISS offers several advantages for vector search: + +1. **Efficiency**: FAISS is optimized for memory usage and speed, making it suitable for large-scale applications. +2. **Offline Support**: FAISS works entirely locally, with no need for external servers or API calls. +3. **Storage Options**: Vectors can be stored in-memory for maximum speed or persisted to disk. +4. **Multiple Index Types**: FAISS supports different index types optimized for various use cases (though mem0 currently uses the basic flat index). + +### Distance Strategies + +FAISS in mem0 supports three distance strategies: + +- **euclidean**: L2 distance, suitable for most embedding models +- **inner_product**: Dot product similarity, useful for some specialized embeddings +- **cosine**: Cosine similarity, best for comparing semantic similarity regardless of vector magnitude + +When using `cosine` or `inner_product` with normalized vectors, you may want to set `normalize_L2=True` for better results. diff --git a/docs/components/vectordbs/overview.mdx b/docs/components/vectordbs/overview.mdx index 0e169f10..0af2c5ae 100644 --- a/docs/components/vectordbs/overview.mdx +++ b/docs/components/vectordbs/overview.mdx @@ -27,6 +27,7 @@ See the list of supported vector databases below. + ## Usage diff --git a/mem0/configs/vector_stores/faiss.py b/mem0/configs/vector_stores/faiss.py new file mode 100644 index 00000000..064585c6 --- /dev/null +++ b/mem0/configs/vector_stores/faiss.py @@ -0,0 +1,38 @@ +from typing import Any, Dict, Optional + +from pydantic import BaseModel, Field, model_validator + + +class FAISSConfig(BaseModel): + collection_name: str = Field("mem0", description="Default name for the collection") + path: Optional[str] = Field(None, description="Path to store FAISS index and metadata") + distance_strategy: str = Field( + "euclidean", description="Distance strategy to use. Options: 'euclidean', 'inner_product', 'cosine'" + ) + normalize_L2: bool = Field( + False, description="Whether to normalize L2 vectors (only applicable for euclidean distance)" + ) + + @model_validator(mode="before") + @classmethod + def validate_distance_strategy(cls, values: Dict[str, Any]) -> Dict[str, Any]: + distance_strategy = values.get("distance_strategy") + if distance_strategy and distance_strategy not in ["euclidean", "inner_product", "cosine"]: + raise ValueError("Invalid distance_strategy. Must be one of: 'euclidean', 'inner_product', 'cosine'") + return values + + @model_validator(mode="before") + @classmethod + def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]: + allowed_fields = set(cls.model_fields.keys()) + input_fields = set(values.keys()) + extra_fields = input_fields - allowed_fields + if extra_fields: + raise ValueError( + f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}" + ) + return values + + model_config = { + "arbitrary_types_allowed": True, + } diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py index 0d7a377f..0bb76797 100644 --- a/mem0/utils/factory.py +++ b/mem0/utils/factory.py @@ -76,6 +76,7 @@ class VectorStoreFactory: "opensearch": "mem0.vector_stores.opensearch.OpenSearchDB", "supabase": "mem0.vector_stores.supabase.Supabase", "weaviate": "mem0.vector_stores.weaviate.Weaviate", + "faiss": "mem0.vector_stores.faiss.FAISS", } @classmethod diff --git a/mem0/vector_stores/configs.py b/mem0/vector_stores/configs.py index db7a75d3..649dc2d5 100644 --- a/mem0/vector_stores/configs.py +++ b/mem0/vector_stores/configs.py @@ -23,6 +23,7 @@ class VectorStoreConfig(BaseModel): "opensearch": "OpenSearchConfig", "supabase": "SupabaseConfig", "weaviate": "WeaviateConfig", + "faiss": "FAISSConfig", } @model_validator(mode="after") diff --git a/mem0/vector_stores/faiss.py b/mem0/vector_stores/faiss.py new file mode 100644 index 00000000..b0fc928b --- /dev/null +++ b/mem0/vector_stores/faiss.py @@ -0,0 +1,464 @@ +import logging +import os +import pickle +import uuid +from pathlib import Path +from typing import Dict, List, Optional + +import numpy as np +from pydantic import BaseModel + +try: + import faiss +except ImportError: + raise ImportError( + "Could not import faiss python package. " + "Please install it with `pip install faiss-gpu` (for CUDA supported GPU) " + "or `pip install faiss-cpu` (depending on Python version)." + ) + +from mem0.vector_stores.base import VectorStoreBase + +logger = logging.getLogger(__name__) + + +class OutputData(BaseModel): + id: Optional[str] # memory id + score: Optional[float] # distance + payload: Optional[Dict] # metadata + + +class FAISS(VectorStoreBase): + def __init__( + self, + collection_name: str, + path: Optional[str] = None, + distance_strategy: str = "euclidean", + normalize_L2: bool = False, + ): + """ + Initialize the FAISS vector store. + + Args: + collection_name (str): Name of the collection. + path (str, optional): Path for local FAISS database. Defaults to None. + distance_strategy (str, optional): Distance strategy to use. Options: 'euclidean', 'inner_product', 'cosine'. + Defaults to "euclidean". + normalize_L2 (bool, optional): Whether to normalize L2 vectors. Only applicable for euclidean distance. + Defaults to False. + """ + self.collection_name = collection_name + self.path = path or f"/tmp/faiss/{collection_name}" + self.distance_strategy = distance_strategy + self.normalize_L2 = normalize_L2 + + # Initialize storage structures + self.index = None + self.docstore = {} + self.index_to_id = {} + + # Create directory if it doesn't exist + if self.path: + os.makedirs(os.path.dirname(self.path), exist_ok=True) + + # Try to load existing index if available + index_path = f"{self.path}/{collection_name}.faiss" + docstore_path = f"{self.path}/{collection_name}.pkl" + if os.path.exists(index_path) and os.path.exists(docstore_path): + self._load(index_path, docstore_path) + else: + self.create_col(collection_name) + + def _load(self, index_path: str, docstore_path: str): + """ + Load FAISS index and docstore from disk. + + Args: + index_path (str): Path to FAISS index file. + docstore_path (str): Path to docstore pickle file. + """ + try: + self.index = faiss.read_index(index_path) + with open(docstore_path, "rb") as f: + self.docstore, self.index_to_id = pickle.load(f) + logger.info(f"Loaded FAISS index from {index_path} with {self.index.ntotal} vectors") + except Exception as e: + logger.warning(f"Failed to load FAISS index: {e}") + + self.docstore = {} + self.index_to_id = {} + + def _save(self): + """Save FAISS index and docstore to disk.""" + if not self.path or not self.index: + return + + try: + os.makedirs(self.path, exist_ok=True) + index_path = f"{self.path}/{self.collection_name}.faiss" + docstore_path = f"{self.path}/{self.collection_name}.pkl" + + faiss.write_index(self.index, index_path) + with open(docstore_path, "wb") as f: + pickle.dump((self.docstore, self.index_to_id), f) + logger.info(f"Saved FAISS index to {index_path} with {self.index.ntotal} vectors") + except Exception as e: + logger.warning(f"Failed to save FAISS index: {e}") + + def _parse_output(self, scores, ids, limit=None) -> List[OutputData]: + """ + Parse the output data. + + Args: + scores: Similarity scores from FAISS. + ids: Indices from FAISS. + limit: Maximum number of results to return. + + Returns: + List[OutputData]: Parsed output data. + """ + if limit is None: + limit = len(ids) + + results = [] + for i in range(min(len(ids), limit)): + if ids[i] == -1: # FAISS returns -1 for empty results + continue + + index_id = int(ids[i]) + vector_id = self.index_to_id.get(index_id) + if vector_id is None: + continue + + payload = self.docstore.get(vector_id) + if payload is None: + continue + + payload_copy = payload.copy() + + score = float(scores[i]) + entry = OutputData( + id=vector_id, + score=score, + payload=payload_copy, + ) + results.append(entry) + + return results + + def create_col(self, name: str, vector_size: int = 1536, distance: str = None): + """ + Create a new collection. + + Args: + name (str): Name of the collection. + vector_size (int, optional): Dimensionality of vectors. Defaults to 1536. + distance (str, optional): Distance metric to use. Overrides the distance_strategy + passed during initialization. Defaults to None. + + Returns: + self: The FAISS instance. + """ + distance_strategy = distance or self.distance_strategy + + # Create index based on distance strategy + if distance_strategy.lower() == "inner_product" or distance_strategy.lower() == "cosine": + self.index = faiss.IndexFlatIP(vector_size) + else: + self.index = faiss.IndexFlatL2(vector_size) + + self.collection_name = name + + self._save() + + return self + + def insert( + self, + vectors: List[list], + payloads: Optional[List[Dict]] = None, + ids: Optional[List[str]] = None, + ): + """ + Insert vectors into a collection. + + Args: + vectors (List[list]): List of vectors to insert. + payloads (Optional[List[Dict]], optional): List of payloads corresponding to vectors. Defaults to None. + ids (Optional[List[str]], optional): List of IDs corresponding to vectors. Defaults to None. + """ + if self.index is None: + raise ValueError("Collection not initialized. Call create_col first.") + + if ids is None: + ids = [str(uuid.uuid4()) for _ in range(len(vectors))] + + if payloads is None: + payloads = [{} for _ in range(len(vectors))] + + if len(vectors) != len(ids) or len(vectors) != len(payloads): + raise ValueError("Vectors, payloads, and IDs must have the same length") + + vectors_np = np.array(vectors, dtype=np.float32) + + if self.normalize_L2 and self.distance_strategy.lower() == "euclidean": + faiss.normalize_L2(vectors_np) + + self.index.add(vectors_np) + + starting_idx = len(self.index_to_id) + for i, (vector_id, payload) in enumerate(zip(ids, payloads)): + self.docstore[vector_id] = payload.copy() + self.index_to_id[starting_idx + i] = vector_id + + self._save() + + logger.info(f"Inserted {len(vectors)} vectors into collection {self.collection_name}") + + def search( + self, query: str, vectors: List[list], limit: int = 5, filters: Optional[Dict] = None + ) -> List[OutputData]: + """ + Search for similar vectors. + + Args: + query (str): Query (not used, kept for API compatibility). + vectors (List[list]): List of vectors to search. + limit (int, optional): Number of results to return. Defaults to 5. + filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None. + + Returns: + List[OutputData]: Search results. + """ + if self.index is None: + raise ValueError("Collection not initialized. Call create_col first.") + + query_vectors = np.array(vectors, dtype=np.float32) + + if len(query_vectors.shape) == 1: + query_vectors = query_vectors.reshape(1, -1) + + if self.normalize_L2 and self.distance_strategy.lower() == "euclidean": + faiss.normalize_L2(query_vectors) + + fetch_k = limit * 2 if filters else limit + scores, indices = self.index.search(query_vectors, fetch_k) + + results = self._parse_output(scores[0], indices[0], limit) + + if filters: + filtered_results = [] + for result in results: + if self._apply_filters(result.payload, filters): + filtered_results.append(result) + if len(filtered_results) >= limit: + break + results = filtered_results[:limit] + + return results + + def _apply_filters(self, payload: Dict, filters: Dict) -> bool: + """ + Apply filters to a payload. + + Args: + payload (Dict): Payload to filter. + filters (Dict): Filters to apply. + + Returns: + bool: True if payload passes filters, False otherwise. + """ + if not filters or not payload: + return True + + for key, value in filters.items(): + if key not in payload: + return False + + if isinstance(value, list): + if payload[key] not in value: + return False + elif payload[key] != value: + return False + + return True + + def delete(self, vector_id: str): + """ + Delete a vector by ID. + + Args: + vector_id (str): ID of the vector to delete. + """ + if self.index is None: + raise ValueError("Collection not initialized. Call create_col first.") + + index_to_delete = None + for idx, vid in self.index_to_id.items(): + if vid == vector_id: + index_to_delete = idx + break + + if index_to_delete is not None: + self.docstore.pop(vector_id, None) + self.index_to_id.pop(index_to_delete, None) + + self._save() + + logger.info(f"Deleted vector {vector_id} from collection {self.collection_name}") + else: + logger.warning(f"Vector {vector_id} not found in collection {self.collection_name}") + + def update( + self, + vector_id: str, + vector: Optional[List[float]] = None, + payload: Optional[Dict] = None, + ): + """ + Update a vector and its payload. + + Args: + vector_id (str): ID of the vector to update. + vector (Optional[List[float]], optional): Updated vector. Defaults to None. + payload (Optional[Dict], optional): Updated payload. Defaults to None. + """ + if self.index is None: + raise ValueError("Collection not initialized. Call create_col first.") + + if vector_id not in self.docstore: + raise ValueError(f"Vector {vector_id} not found") + + current_payload = self.docstore[vector_id].copy() + + if payload is not None: + self.docstore[vector_id] = payload.copy() + current_payload = self.docstore[vector_id].copy() + + if vector is not None: + self.delete(vector_id) + self.insert([vector], [current_payload], [vector_id]) + else: + self._save() + + logger.info(f"Updated vector {vector_id} in collection {self.collection_name}") + + def get(self, vector_id: str) -> OutputData: + """ + Retrieve a vector by ID. + + Args: + vector_id (str): ID of the vector to retrieve. + + Returns: + OutputData: Retrieved vector. + """ + if self.index is None: + raise ValueError("Collection not initialized. Call create_col first.") + + if vector_id not in self.docstore: + return None + + payload = self.docstore[vector_id].copy() + + return OutputData( + id=vector_id, + score=None, + payload=payload, + ) + + def list_cols(self) -> List[str]: + """ + List all collections. + + Returns: + List[str]: List of collection names. + """ + if not self.path: + return [self.collection_name] if self.index else [] + + try: + collections = [] + path = Path(self.path).parent + for file in path.glob("*.faiss"): + collections.append(file.stem) + return collections + except Exception as e: + logger.warning(f"Failed to list collections: {e}") + return [self.collection_name] if self.index else [] + + def delete_col(self): + """ + Delete a collection. + """ + if self.path: + try: + index_path = f"{self.path}/{self.collection_name}.faiss" + docstore_path = f"{self.path}/{self.collection_name}.pkl" + + if os.path.exists(index_path): + os.remove(index_path) + if os.path.exists(docstore_path): + os.remove(docstore_path) + + logger.info(f"Deleted collection {self.collection_name}") + except Exception as e: + logger.warning(f"Failed to delete collection: {e}") + + self.index = None + self.docstore = {} + self.index_to_id = {} + + def col_info(self) -> Dict: + """ + Get information about a collection. + + Returns: + Dict: Collection information. + """ + if self.index is None: + return {"name": self.collection_name, "count": 0} + + return { + "name": self.collection_name, + "count": self.index.ntotal, + "dimension": self.index.d, + "distance": self.distance_strategy, + } + + def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]: + """ + List all vectors in a collection. + + Args: + filters (Optional[Dict], optional): Filters to apply to the list. Defaults to None. + limit (int, optional): Number of vectors to return. Defaults to 100. + + Returns: + List[OutputData]: List of vectors. + """ + if self.index is None: + return [] + + results = [] + count = 0 + + for vector_id, payload in self.docstore.items(): + if filters and not self._apply_filters(payload, filters): + continue + + payload_copy = payload.copy() + + results.append( + OutputData( + id=vector_id, + score=None, + payload=payload_copy, + ) + ) + + count += 1 + if count >= limit: + break + + return [results] diff --git a/pyproject.toml b/pyproject.toml index 7b35c85b..3c84e6df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "mem0ai" -version = "0.1.77" +version = "0.1.78" description = "Long-term memory for AI Agents" authors = ["Mem0 "] exclude = [ diff --git a/tests/vector_stores/test_faiss.py b/tests/vector_stores/test_faiss.py new file mode 100644 index 00000000..c4fcd800 --- /dev/null +++ b/tests/vector_stores/test_faiss.py @@ -0,0 +1,314 @@ +import os +import tempfile +from unittest.mock import Mock, patch + +import faiss +import numpy as np +import pytest + +from mem0.vector_stores.faiss import FAISS, OutputData + + +@pytest.fixture +def mock_faiss_index(): + index = Mock(spec=faiss.IndexFlatL2) + index.d = 128 # Dimension of the vectors + index.ntotal = 0 # Number of vectors in the index + return index + + +@pytest.fixture +def faiss_instance(mock_faiss_index): + with tempfile.TemporaryDirectory() as temp_dir: + # Mock the faiss index creation + with patch('faiss.IndexFlatL2', return_value=mock_faiss_index): + # Mock the faiss.write_index function + with patch('faiss.write_index'): + # Create a FAISS instance with a temporary directory + faiss_store = FAISS( + collection_name="test_collection", + path=os.path.join(temp_dir, "test_faiss"), + distance_strategy="euclidean", + ) + # Set up the mock index + faiss_store.index = mock_faiss_index + yield faiss_store + + +def test_create_col(faiss_instance, mock_faiss_index): + # Test creating a collection with euclidean distance + with patch('faiss.IndexFlatL2', return_value=mock_faiss_index) as mock_index_flat_l2: + with patch('faiss.write_index'): + faiss_instance.create_col(name="new_collection", vector_size=256) + mock_index_flat_l2.assert_called_once_with(256) + + # Test creating a collection with inner product distance + with patch('faiss.IndexFlatIP', return_value=mock_faiss_index) as mock_index_flat_ip: + with patch('faiss.write_index'): + faiss_instance.create_col(name="new_collection", vector_size=256, distance="inner_product") + mock_index_flat_ip.assert_called_once_with(256) + + +def test_insert(faiss_instance, mock_faiss_index): + # Prepare test data + vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + payloads = [{"name": "vector1"}, {"name": "vector2"}] + ids = ["id1", "id2"] + + # Mock the numpy array conversion + with patch('numpy.array', return_value=np.array(vectors, dtype=np.float32)) as mock_np_array: + # Mock index.add + mock_faiss_index.add.return_value = None + + # Call insert + faiss_instance.insert(vectors=vectors, payloads=payloads, ids=ids) + + # Verify numpy.array was called + mock_np_array.assert_called_once_with(vectors, dtype=np.float32) + + # Verify index.add was called + mock_faiss_index.add.assert_called_once() + + # Verify docstore and index_to_id were updated + assert faiss_instance.docstore["id1"] == {"name": "vector1"} + assert faiss_instance.docstore["id2"] == {"name": "vector2"} + assert faiss_instance.index_to_id[0] == "id1" + assert faiss_instance.index_to_id[1] == "id2" + + +def test_search(faiss_instance, mock_faiss_index): + # Prepare test data + query_vector = [0.1, 0.2, 0.3] + + # Setup the docstore and index_to_id mapping + faiss_instance.docstore = { + "id1": {"name": "vector1"}, + "id2": {"name": "vector2"} + } + faiss_instance.index_to_id = {0: "id1", 1: "id2"} + + # First, create the mock for the search return values + search_scores = np.array([[0.9, 0.8]]) + search_indices = np.array([[0, 1]]) + mock_faiss_index.search.return_value = (search_scores, search_indices) + + # Then patch numpy.array only for the query vector conversion + with patch('numpy.array') as mock_np_array: + mock_np_array.return_value = np.array(query_vector, dtype=np.float32) + + # Then patch _parse_output to return the expected results + expected_results = [ + OutputData(id="id1", score=0.9, payload={"name": "vector1"}), + OutputData(id="id2", score=0.8, payload={"name": "vector2"}) + ] + + with patch.object(faiss_instance, '_parse_output', return_value=expected_results): + # Call search + results = faiss_instance.search(query="test query", vectors=query_vector, limit=2) + + # Verify numpy.array was called (but we don't check exact call arguments since it's complex) + assert mock_np_array.called + + # Verify index.search was called + mock_faiss_index.search.assert_called_once() + + # Verify results + assert len(results) == 2 + assert results[0].id == "id1" + assert results[0].score == 0.9 + assert results[0].payload == {"name": "vector1"} + assert results[1].id == "id2" + assert results[1].score == 0.8 + assert results[1].payload == {"name": "vector2"} + + +def test_search_with_filters(faiss_instance, mock_faiss_index): + # Prepare test data + query_vector = [0.1, 0.2, 0.3] + + # Setup the docstore and index_to_id mapping + faiss_instance.docstore = { + "id1": {"name": "vector1", "category": "A"}, + "id2": {"name": "vector2", "category": "B"} + } + faiss_instance.index_to_id = {0: "id1", 1: "id2"} + + # First set up the search return values + search_scores = np.array([[0.9, 0.8]]) + search_indices = np.array([[0, 1]]) + mock_faiss_index.search.return_value = (search_scores, search_indices) + + # Patch numpy.array for query vector conversion + with patch('numpy.array') as mock_np_array: + mock_np_array.return_value = np.array(query_vector, dtype=np.float32) + + # Directly mock the _parse_output method to return our expected values + # We're simulating that _parse_output filters to just the first result + all_results = [ + OutputData(id="id1", score=0.9, payload={"name": "vector1", "category": "A"}), + OutputData(id="id2", score=0.8, payload={"name": "vector2", "category": "B"}) + ] + + filtered_results = [all_results[0]] # Just the "category": "A" result + + # Create a side_effect function that returns all results first (for _parse_output) + # then returns filtered results (for the filters) + parse_output_mock = Mock(side_effect=[all_results, filtered_results]) + + # Replace the _apply_filters method to handle our test case + with patch.object(faiss_instance, '_parse_output', return_value=all_results): + with patch.object(faiss_instance, '_apply_filters', side_effect=lambda p, f: p.get("category") == "A"): + # Call search with filters + results = faiss_instance.search( + query="test query", + vectors=query_vector, + limit=2, + filters={"category": "A"} + ) + + # Verify numpy.array was called + assert mock_np_array.called + + # Verify index.search was called + mock_faiss_index.search.assert_called_once() + + # Verify filtered results - since we've mocked everything, + # we should get just the result we want + assert len(results) == 1 + assert results[0].id == "id1" + assert results[0].score == 0.9 + assert results[0].payload == {"name": "vector1", "category": "A"} + + +def test_delete(faiss_instance): + # Setup the docstore and index_to_id mapping + faiss_instance.docstore = { + "id1": {"name": "vector1"}, + "id2": {"name": "vector2"} + } + faiss_instance.index_to_id = {0: "id1", 1: "id2"} + + # Call delete + faiss_instance.delete(vector_id="id1") + + # Verify the vector was removed from docstore and index_to_id + assert "id1" not in faiss_instance.docstore + assert 0 not in faiss_instance.index_to_id + assert "id2" in faiss_instance.docstore + assert 1 in faiss_instance.index_to_id + + +def test_update(faiss_instance, mock_faiss_index): + # Setup the docstore and index_to_id mapping + faiss_instance.docstore = { + "id1": {"name": "vector1"}, + "id2": {"name": "vector2"} + } + faiss_instance.index_to_id = {0: "id1", 1: "id2"} + + # Test updating payload only + faiss_instance.update(vector_id="id1", payload={"name": "updated_vector1"}) + assert faiss_instance.docstore["id1"] == {"name": "updated_vector1"} + + # Test updating vector + # This requires mocking the delete and insert methods + with patch.object(faiss_instance, 'delete') as mock_delete: + with patch.object(faiss_instance, 'insert') as mock_insert: + new_vector = [0.7, 0.8, 0.9] + faiss_instance.update(vector_id="id2", vector=new_vector) + + # Verify delete and insert were called + # Match the actual call signature (positional arg instead of keyword) + mock_delete.assert_called_once_with("id2") + mock_insert.assert_called_once() + + +def test_get(faiss_instance): + # Setup the docstore + faiss_instance.docstore = { + "id1": {"name": "vector1"}, + "id2": {"name": "vector2"} + } + + # Test getting an existing vector + result = faiss_instance.get(vector_id="id1") + assert result.id == "id1" + assert result.payload == {"name": "vector1"} + assert result.score is None + + # Test getting a non-existent vector + result = faiss_instance.get(vector_id="id3") + assert result is None + + +def test_list(faiss_instance): + # Setup the docstore + faiss_instance.docstore = { + "id1": {"name": "vector1", "category": "A"}, + "id2": {"name": "vector2", "category": "B"}, + "id3": {"name": "vector3", "category": "A"} + } + + # Test listing all vectors + results = faiss_instance.list() + # Fix the expected result - the list method returns a list of lists + assert len(results[0]) == 3 + + # Test listing with a limit + results = faiss_instance.list(limit=2) + assert len(results[0]) == 2 + + # Test listing with filters + results = faiss_instance.list(filters={"category": "A"}) + assert len(results[0]) == 2 + for result in results[0]: + assert result.payload["category"] == "A" + + +def test_col_info(faiss_instance, mock_faiss_index): + # Mock index attributes + mock_faiss_index.ntotal = 5 + mock_faiss_index.d = 128 + + # Get collection info + info = faiss_instance.col_info() + + # Verify the returned info + assert info["name"] == "test_collection" + assert info["count"] == 5 + assert info["dimension"] == 128 + assert info["distance"] == "euclidean" + + +def test_delete_col(faiss_instance): + # Mock the os.remove function + with patch('os.remove') as mock_remove: + with patch('os.path.exists', return_value=True): + # Call delete_col + faiss_instance.delete_col() + + # Verify os.remove was called twice (for index and docstore files) + assert mock_remove.call_count == 2 + + # Verify the internal state was reset + assert faiss_instance.index is None + assert faiss_instance.docstore == {} + assert faiss_instance.index_to_id == {} + + +def test_normalize_L2(faiss_instance, mock_faiss_index): + # Setup a FAISS instance with normalize_L2=True + faiss_instance.normalize_L2 = True + + # Prepare test data + vectors = [[0.1, 0.2, 0.3]] + + # Mock numpy array conversion + with patch('numpy.array', return_value=np.array(vectors, dtype=np.float32)) as mock_np_array: + # Mock faiss.normalize_L2 + with patch('faiss.normalize_L2') as mock_normalize: + # Call insert + faiss_instance.insert(vectors=vectors, ids=["id1"]) + + # Verify faiss.normalize_L2 was called + mock_normalize.assert_called_once()