Add Faiss Support (#2461)
Makefile

@@ -13,7 +13,7 @@ install:
 install_all:
 	poetry install
 	poetry run pip install groq together boto3 litellm ollama chromadb weaviate weaviate-client sentence_transformers vertexai \
-		google-generativeai elasticsearch opensearch-py vecs pinecone pinecone-text
+		google-generativeai elasticsearch opensearch-py vecs pinecone pinecone-text faiss-cpu

 # Format code with ruff
 format:

docs/components/vectordbs/dbs/faiss.mdx (new file, 72 lines)

@@ -0,0 +1,72 @@

[FAISS](https://github.com/facebookresearch/faiss) is a library for efficient similarity search and clustering of dense vectors. It is designed to work with large-scale datasets and provides a high-performance search engine for vector data. FAISS is optimized for memory usage and search speed, making it an excellent choice for production environments.

### Usage

```python
import os

from mem0 import Memory

os.environ["OPENAI_API_KEY"] = "sk-xx"

config = {
    "vector_store": {
        "provider": "faiss",
        "config": {
            "collection_name": "test",
            "path": "/tmp/faiss_memories",
            "distance_strategy": "euclidean"
        }
    }
}

m = Memory.from_config(config)
messages = [
    {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
    {"role": "assistant", "content": "How about thriller movies? They can be quite engaging."},
    {"role": "user", "content": "I'm not a big fan of thriller movies, but I love sci-fi movies."},
    {"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
]
m.add(messages, user_id="alice", metadata={"category": "movies"})
```

### Installation

To use FAISS in your mem0 project, install the FAISS package appropriate for your environment:

```bash
# For CPU version
pip install faiss-cpu

# For GPU version (requires CUDA)
pip install faiss-gpu
```
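
A quick way to confirm the package is importable (a minimal check, assuming a standard install):

```bash
python -c "import faiss; print(faiss.__version__)"
```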

### Config

Here are the parameters available for configuring FAISS:

| Parameter | Description | Default Value |
| --- | --- | --- |
| `collection_name` | The name of the collection | `mem0` |
| `path` | Path to store the FAISS index and metadata | `/tmp/faiss/<collection_name>` |
| `distance_strategy` | Distance metric to use (options: `euclidean`, `inner_product`, `cosine`) | `euclidean` |
| `normalize_L2` | Whether to L2-normalize vectors (only applied with the `euclidean` distance strategy) | `False` |
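
For example, a configuration that sets all four parameters explicitly (a sketch; the values shown mirror the defaults above and are illustrative):

```python
config = {
    "vector_store": {
        "provider": "faiss",
        "config": {
            "collection_name": "mem0",
            "path": "/tmp/faiss/mem0",
            "distance_strategy": "euclidean",
            "normalize_L2": False,
        }
    }
}
```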

### Performance Considerations

FAISS offers several advantages for vector search:

1. **Efficiency**: FAISS is optimized for memory usage and speed, making it suitable for large-scale applications.
2. **Offline Support**: FAISS works entirely locally, with no need for external servers or API calls.
3. **Storage Options**: Vectors can be stored in memory for maximum speed or persisted to disk.
4. **Multiple Index Types**: FAISS supports different index types optimized for various use cases (though mem0 currently uses the basic flat index).

### Distance Strategies

FAISS in mem0 supports three distance strategies:

- **euclidean**: L2 distance, suitable for most embedding models
- **inner_product**: Dot product similarity, useful for some specialized embeddings
- **cosine**: Cosine similarity, best for comparing semantic similarity regardless of vector magnitude

When using `cosine` or `inner_product`, scores are raw inner products, so cosine similarity assumes your embedding vectors are already unit-length. Note that the `normalize_L2` option is only applied when `distance_strategy` is `euclidean`.
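
For instance, to use cosine similarity with an embedder that emits unit-length vectors (a sketch; the collection name and path are illustrative):

```python
config = {
    "vector_store": {
        "provider": "faiss",
        "config": {
            "collection_name": "semantic_memories",
            "path": "/tmp/faiss_semantic",
            "distance_strategy": "cosine",
        }
    }
}
```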
@@ -27,6 +27,7 @@ See the list of supported vector databases below.
 <Card title="Supabase" href="/components/vectordbs/dbs/supabase"></Card>
 <Card title="Vertex AI Vector Search" href="/components/vectordbs/dbs/vertex_ai_vector_search"></Card>
 <Card title="Weaviate" href="/components/vectordbs/dbs/weaviate"></Card>
+<Card title="FAISS" href="/components/vectordbs/dbs/faiss"></Card>
 </CardGroup>

 ## Usage

mem0/configs/vector_stores/faiss.py (new file, 38 lines)

@@ -0,0 +1,38 @@

from typing import Any, Dict, Optional

from pydantic import BaseModel, Field, model_validator


class FAISSConfig(BaseModel):
    collection_name: str = Field("mem0", description="Default name for the collection")
    path: Optional[str] = Field(None, description="Path to store FAISS index and metadata")
    distance_strategy: str = Field(
        "euclidean", description="Distance strategy to use. Options: 'euclidean', 'inner_product', 'cosine'"
    )
    normalize_L2: bool = Field(
        False, description="Whether to normalize L2 vectors (only applicable for euclidean distance)"
    )

    @model_validator(mode="before")
    @classmethod
    def validate_distance_strategy(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        distance_strategy = values.get("distance_strategy")
        if distance_strategy and distance_strategy not in ["euclidean", "inner_product", "cosine"]:
            raise ValueError("Invalid distance_strategy. Must be one of: 'euclidean', 'inner_product', 'cosine'")
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

    model_config = {
        "arbitrary_types_allowed": True,
    }
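
As a quick illustration of the two validators above (a sketch; run in a REPL with mem0 installed — pydantic surfaces the raised ValueError as a validation error):

```python
from mem0.configs.vector_stores.faiss import FAISSConfig

# Valid: all fields are recognized and the strategy is allowed
cfg = FAISSConfig(collection_name="test", distance_strategy="cosine")

# validate_distance_strategy rejects unknown strategies
try:
    FAISSConfig(distance_strategy="manhattan")
except ValueError as e:
    print(e)

# validate_extra_fields rejects fields the model does not declare
try:
    FAISSConfig(host="localhost")
except ValueError as e:
    print(e)
```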
@@ -76,6 +76,7 @@ class VectorStoreFactory:
         "opensearch": "mem0.vector_stores.opensearch.OpenSearchDB",
         "supabase": "mem0.vector_stores.supabase.Supabase",
         "weaviate": "mem0.vector_stores.weaviate.Weaviate",
+        "faiss": "mem0.vector_stores.faiss.FAISS",
     }

     @classmethod
@@ -23,6 +23,7 @@ class VectorStoreConfig(BaseModel):
         "opensearch": "OpenSearchConfig",
         "supabase": "SupabaseConfig",
         "weaviate": "WeaviateConfig",
+        "faiss": "FAISSConfig",
     }

     @model_validator(mode="after")

mem0/vector_stores/faiss.py (new file, 464 lines)

@@ -0,0 +1,464 @@

import logging
import os
import pickle
import uuid
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
from pydantic import BaseModel

try:
    import faiss
except ImportError:
    raise ImportError(
        "Could not import faiss python package. "
        "Please install it with `pip install faiss-gpu` (for CUDA supported GPU) "
        "or `pip install faiss-cpu` (depending on Python version)."
    )

from mem0.vector_stores.base import VectorStoreBase

logger = logging.getLogger(__name__)


class OutputData(BaseModel):
    id: Optional[str]  # memory id
    score: Optional[float]  # distance
    payload: Optional[Dict]  # metadata


class FAISS(VectorStoreBase):
    def __init__(
        self,
        collection_name: str,
        path: Optional[str] = None,
        distance_strategy: str = "euclidean",
        normalize_L2: bool = False,
    ):
        """
        Initialize the FAISS vector store.

        Args:
            collection_name (str): Name of the collection.
            path (str, optional): Path for local FAISS database. Defaults to None.
            distance_strategy (str, optional): Distance strategy to use. Options: 'euclidean', 'inner_product', 'cosine'.
                Defaults to "euclidean".
            normalize_L2 (bool, optional): Whether to normalize L2 vectors. Only applicable for euclidean distance.
                Defaults to False.
        """
        self.collection_name = collection_name
        self.path = path or f"/tmp/faiss/{collection_name}"
        self.distance_strategy = distance_strategy
        self.normalize_L2 = normalize_L2

        # Initialize storage structures
        self.index = None
        self.docstore = {}
        self.index_to_id = {}

        # Create directory if it doesn't exist
        if self.path:
            os.makedirs(os.path.dirname(self.path), exist_ok=True)

        # Try to load existing index if available
        index_path = f"{self.path}/{collection_name}.faiss"
        docstore_path = f"{self.path}/{collection_name}.pkl"
        if os.path.exists(index_path) and os.path.exists(docstore_path):
            self._load(index_path, docstore_path)
        else:
            self.create_col(collection_name)

    def _load(self, index_path: str, docstore_path: str):
        """
        Load FAISS index and docstore from disk.

        Args:
            index_path (str): Path to FAISS index file.
            docstore_path (str): Path to docstore pickle file.
        """
        try:
            self.index = faiss.read_index(index_path)
            with open(docstore_path, "rb") as f:
                self.docstore, self.index_to_id = pickle.load(f)
            logger.info(f"Loaded FAISS index from {index_path} with {self.index.ntotal} vectors")
        except Exception as e:
            logger.warning(f"Failed to load FAISS index: {e}")
            self.docstore = {}
            self.index_to_id = {}

    def _save(self):
        """Save FAISS index and docstore to disk."""
        if not self.path or not self.index:
            return

        try:
            os.makedirs(self.path, exist_ok=True)
            index_path = f"{self.path}/{self.collection_name}.faiss"
            docstore_path = f"{self.path}/{self.collection_name}.pkl"

            faiss.write_index(self.index, index_path)
            with open(docstore_path, "wb") as f:
                pickle.dump((self.docstore, self.index_to_id), f)
            logger.info(f"Saved FAISS index to {index_path} with {self.index.ntotal} vectors")
        except Exception as e:
            logger.warning(f"Failed to save FAISS index: {e}")

    def _parse_output(self, scores, ids, limit=None) -> List[OutputData]:
        """
        Parse the output data.

        Args:
            scores: Similarity scores from FAISS.
            ids: Indices from FAISS.
            limit: Maximum number of results to return.

        Returns:
            List[OutputData]: Parsed output data.
        """
        if limit is None:
            limit = len(ids)

        results = []
        for i in range(min(len(ids), limit)):
            if ids[i] == -1:  # FAISS returns -1 for empty results
                continue

            index_id = int(ids[i])
            vector_id = self.index_to_id.get(index_id)
            if vector_id is None:
                continue

            payload = self.docstore.get(vector_id)
            if payload is None:
                continue

            payload_copy = payload.copy()

            score = float(scores[i])
            entry = OutputData(
                id=vector_id,
                score=score,
                payload=payload_copy,
            )
            results.append(entry)

        return results

    def create_col(self, name: str, vector_size: int = 1536, distance: str = None):
        """
        Create a new collection.

        Args:
            name (str): Name of the collection.
            vector_size (int, optional): Dimensionality of vectors. Defaults to 1536.
            distance (str, optional): Distance metric to use. Overrides the distance_strategy
                passed during initialization. Defaults to None.

        Returns:
            self: The FAISS instance.
        """
        distance_strategy = distance or self.distance_strategy

        # Create index based on distance strategy
        if distance_strategy.lower() == "inner_product" or distance_strategy.lower() == "cosine":
            self.index = faiss.IndexFlatIP(vector_size)
        else:
            self.index = faiss.IndexFlatL2(vector_size)

        self.collection_name = name

        self._save()

        return self

    def insert(
        self,
        vectors: List[list],
        payloads: Optional[List[Dict]] = None,
        ids: Optional[List[str]] = None,
    ):
        """
        Insert vectors into a collection.

        Args:
            vectors (List[list]): List of vectors to insert.
            payloads (Optional[List[Dict]], optional): List of payloads corresponding to vectors. Defaults to None.
            ids (Optional[List[str]], optional): List of IDs corresponding to vectors. Defaults to None.
        """
        if self.index is None:
            raise ValueError("Collection not initialized. Call create_col first.")

        if ids is None:
            ids = [str(uuid.uuid4()) for _ in range(len(vectors))]

        if payloads is None:
            payloads = [{} for _ in range(len(vectors))]

        if len(vectors) != len(ids) or len(vectors) != len(payloads):
            raise ValueError("Vectors, payloads, and IDs must have the same length")

        vectors_np = np.array(vectors, dtype=np.float32)

        if self.normalize_L2 and self.distance_strategy.lower() == "euclidean":
            faiss.normalize_L2(vectors_np)

        self.index.add(vectors_np)

        starting_idx = len(self.index_to_id)
        for i, (vector_id, payload) in enumerate(zip(ids, payloads)):
            self.docstore[vector_id] = payload.copy()
            self.index_to_id[starting_idx + i] = vector_id

        self._save()

        logger.info(f"Inserted {len(vectors)} vectors into collection {self.collection_name}")

    def search(
        self, query: str, vectors: List[list], limit: int = 5, filters: Optional[Dict] = None
    ) -> List[OutputData]:
        """
        Search for similar vectors.

        Args:
            query (str): Query (not used, kept for API compatibility).
            vectors (List[list]): List of vectors to search.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None.

        Returns:
            List[OutputData]: Search results.
        """
        if self.index is None:
            raise ValueError("Collection not initialized. Call create_col first.")

        query_vectors = np.array(vectors, dtype=np.float32)

        if len(query_vectors.shape) == 1:
            query_vectors = query_vectors.reshape(1, -1)

        if self.normalize_L2 and self.distance_strategy.lower() == "euclidean":
            faiss.normalize_L2(query_vectors)

        fetch_k = limit * 2 if filters else limit
        scores, indices = self.index.search(query_vectors, fetch_k)

        results = self._parse_output(scores[0], indices[0], limit)

        if filters:
            filtered_results = []
            for result in results:
                if self._apply_filters(result.payload, filters):
                    filtered_results.append(result)
                    if len(filtered_results) >= limit:
                        break
            results = filtered_results[:limit]

        return results

    def _apply_filters(self, payload: Dict, filters: Dict) -> bool:
        """
        Apply filters to a payload.

        Args:
            payload (Dict): Payload to filter.
            filters (Dict): Filters to apply.

        Returns:
            bool: True if payload passes filters, False otherwise.
        """
        if not filters or not payload:
            return True

        for key, value in filters.items():
            if key not in payload:
                return False

            if isinstance(value, list):
                if payload[key] not in value:
                    return False
            elif payload[key] != value:
                return False

        return True

    def delete(self, vector_id: str):
        """
        Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        if self.index is None:
            raise ValueError("Collection not initialized. Call create_col first.")

        index_to_delete = None
        for idx, vid in self.index_to_id.items():
            if vid == vector_id:
                index_to_delete = idx
                break

        if index_to_delete is not None:
            self.docstore.pop(vector_id, None)
            self.index_to_id.pop(index_to_delete, None)

            self._save()

            logger.info(f"Deleted vector {vector_id} from collection {self.collection_name}")
        else:
            logger.warning(f"Vector {vector_id} not found in collection {self.collection_name}")

    def update(
        self,
        vector_id: str,
        vector: Optional[List[float]] = None,
        payload: Optional[Dict] = None,
    ):
        """
        Update a vector and its payload.

        Args:
            vector_id (str): ID of the vector to update.
            vector (Optional[List[float]], optional): Updated vector. Defaults to None.
            payload (Optional[Dict], optional): Updated payload. Defaults to None.
        """
        if self.index is None:
            raise ValueError("Collection not initialized. Call create_col first.")

        if vector_id not in self.docstore:
            raise ValueError(f"Vector {vector_id} not found")

        current_payload = self.docstore[vector_id].copy()

        if payload is not None:
            self.docstore[vector_id] = payload.copy()
            current_payload = self.docstore[vector_id].copy()

        if vector is not None:
            self.delete(vector_id)
            self.insert([vector], [current_payload], [vector_id])
        else:
            self._save()

        logger.info(f"Updated vector {vector_id} in collection {self.collection_name}")

    def get(self, vector_id: str) -> OutputData:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            OutputData: Retrieved vector.
        """
        if self.index is None:
            raise ValueError("Collection not initialized. Call create_col first.")

        if vector_id not in self.docstore:
            return None

        payload = self.docstore[vector_id].copy()

        return OutputData(
            id=vector_id,
            score=None,
            payload=payload,
        )

    def list_cols(self) -> List[str]:
        """
        List all collections.

        Returns:
            List[str]: List of collection names.
        """
        if not self.path:
            return [self.collection_name] if self.index else []

        try:
            collections = []
            path = Path(self.path).parent
            for file in path.glob("*.faiss"):
                collections.append(file.stem)
            return collections
        except Exception as e:
            logger.warning(f"Failed to list collections: {e}")
            return [self.collection_name] if self.index else []

    def delete_col(self):
        """
        Delete a collection.
        """
        if self.path:
            try:
                index_path = f"{self.path}/{self.collection_name}.faiss"
                docstore_path = f"{self.path}/{self.collection_name}.pkl"

                if os.path.exists(index_path):
                    os.remove(index_path)
                if os.path.exists(docstore_path):
                    os.remove(docstore_path)

                logger.info(f"Deleted collection {self.collection_name}")
            except Exception as e:
                logger.warning(f"Failed to delete collection: {e}")

        self.index = None
        self.docstore = {}
        self.index_to_id = {}

    def col_info(self) -> Dict:
        """
        Get information about a collection.

        Returns:
            Dict: Collection information.
        """
        if self.index is None:
            return {"name": self.collection_name, "count": 0}

        return {
            "name": self.collection_name,
            "count": self.index.ntotal,
            "dimension": self.index.d,
            "distance": self.distance_strategy,
        }

    def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[OutputData]:
        """
        List all vectors in a collection.

        Args:
            filters (Optional[Dict], optional): Filters to apply to the list. Defaults to None.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            List[OutputData]: List of vectors.
        """
        if self.index is None:
            return []

        results = []
        count = 0

        for vector_id, payload in self.docstore.items():
            if filters and not self._apply_filters(payload, filters):
                continue

            payload_copy = payload.copy()

            results.append(
                OutputData(
                    id=vector_id,
                    score=None,
                    payload=payload_copy,
                )
            )

            count += 1
            if count >= limit:
                break

        return [results]
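
For orientation, here is a minimal sketch of driving this class directly, outside of `Memory` (assuming `faiss-cpu` is installed; the 8-dimensional random vectors and IDs are illustrative):

```python
import random

from mem0.vector_stores.faiss import FAISS

store = FAISS(collection_name="demo", path="/tmp/faiss_demo", distance_strategy="euclidean")
store.create_col("demo", vector_size=8)  # flat L2 index over 8-dim vectors

vectors = [[random.random() for _ in range(8)] for _ in range(3)]
store.insert(vectors, payloads=[{"n": i} for i in range(3)], ids=["a", "b", "c"])

# `query` is unused by this backend; the embedding vector drives the search
hits = store.search(query="", vectors=vectors[0], limit=2)
for hit in hits:
    print(hit.id, hit.score, hit.payload)
```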
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mem0ai"
-version = "0.1.77"
+version = "0.1.78"
 description = "Long-term memory for AI Agents"
 authors = ["Mem0 <founders@mem0.ai>"]
 exclude = [

tests/vector_stores/test_faiss.py (new file, 314 lines)

@@ -0,0 +1,314 @@

import os
import tempfile
from unittest.mock import Mock, patch

import faiss
import numpy as np
import pytest

from mem0.vector_stores.faiss import FAISS, OutputData


@pytest.fixture
def mock_faiss_index():
    index = Mock(spec=faiss.IndexFlatL2)
    index.d = 128  # Dimension of the vectors
    index.ntotal = 0  # Number of vectors in the index
    return index


@pytest.fixture
def faiss_instance(mock_faiss_index):
    with tempfile.TemporaryDirectory() as temp_dir:
        # Mock the faiss index creation
        with patch('faiss.IndexFlatL2', return_value=mock_faiss_index):
            # Mock the faiss.write_index function
            with patch('faiss.write_index'):
                # Create a FAISS instance with a temporary directory
                faiss_store = FAISS(
                    collection_name="test_collection",
                    path=os.path.join(temp_dir, "test_faiss"),
                    distance_strategy="euclidean",
                )
                # Set up the mock index
                faiss_store.index = mock_faiss_index
                yield faiss_store


def test_create_col(faiss_instance, mock_faiss_index):
    # Test creating a collection with euclidean distance
    with patch('faiss.IndexFlatL2', return_value=mock_faiss_index) as mock_index_flat_l2:
        with patch('faiss.write_index'):
            faiss_instance.create_col(name="new_collection", vector_size=256)
            mock_index_flat_l2.assert_called_once_with(256)

    # Test creating a collection with inner product distance
    with patch('faiss.IndexFlatIP', return_value=mock_faiss_index) as mock_index_flat_ip:
        with patch('faiss.write_index'):
            faiss_instance.create_col(name="new_collection", vector_size=256, distance="inner_product")
            mock_index_flat_ip.assert_called_once_with(256)


def test_insert(faiss_instance, mock_faiss_index):
    # Prepare test data
    vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
    payloads = [{"name": "vector1"}, {"name": "vector2"}]
    ids = ["id1", "id2"]

    # Mock the numpy array conversion
    with patch('numpy.array', return_value=np.array(vectors, dtype=np.float32)) as mock_np_array:
        # Mock index.add
        mock_faiss_index.add.return_value = None

        # Call insert
        faiss_instance.insert(vectors=vectors, payloads=payloads, ids=ids)

        # Verify numpy.array was called
        mock_np_array.assert_called_once_with(vectors, dtype=np.float32)

        # Verify index.add was called
        mock_faiss_index.add.assert_called_once()

        # Verify docstore and index_to_id were updated
        assert faiss_instance.docstore["id1"] == {"name": "vector1"}
        assert faiss_instance.docstore["id2"] == {"name": "vector2"}
        assert faiss_instance.index_to_id[0] == "id1"
        assert faiss_instance.index_to_id[1] == "id2"


def test_search(faiss_instance, mock_faiss_index):
    # Prepare test data
    query_vector = [0.1, 0.2, 0.3]

    # Setup the docstore and index_to_id mapping
    faiss_instance.docstore = {
        "id1": {"name": "vector1"},
        "id2": {"name": "vector2"}
    }
    faiss_instance.index_to_id = {0: "id1", 1: "id2"}

    # First, create the mock for the search return values
    search_scores = np.array([[0.9, 0.8]])
    search_indices = np.array([[0, 1]])
    mock_faiss_index.search.return_value = (search_scores, search_indices)

    # Then patch numpy.array only for the query vector conversion
    with patch('numpy.array') as mock_np_array:
        mock_np_array.return_value = np.array(query_vector, dtype=np.float32)

        # Then patch _parse_output to return the expected results
        expected_results = [
            OutputData(id="id1", score=0.9, payload={"name": "vector1"}),
            OutputData(id="id2", score=0.8, payload={"name": "vector2"})
        ]

        with patch.object(faiss_instance, '_parse_output', return_value=expected_results):
            # Call search
            results = faiss_instance.search(query="test query", vectors=query_vector, limit=2)

            # Verify numpy.array was called (but we don't check exact call arguments since it's complex)
            assert mock_np_array.called

            # Verify index.search was called
            mock_faiss_index.search.assert_called_once()

            # Verify results
            assert len(results) == 2
            assert results[0].id == "id1"
            assert results[0].score == 0.9
            assert results[0].payload == {"name": "vector1"}
            assert results[1].id == "id2"
            assert results[1].score == 0.8
            assert results[1].payload == {"name": "vector2"}


def test_search_with_filters(faiss_instance, mock_faiss_index):
    # Prepare test data
    query_vector = [0.1, 0.2, 0.3]

    # Setup the docstore and index_to_id mapping
    faiss_instance.docstore = {
        "id1": {"name": "vector1", "category": "A"},
        "id2": {"name": "vector2", "category": "B"}
    }
    faiss_instance.index_to_id = {0: "id1", 1: "id2"}

    # First set up the search return values
    search_scores = np.array([[0.9, 0.8]])
    search_indices = np.array([[0, 1]])
    mock_faiss_index.search.return_value = (search_scores, search_indices)

    # Patch numpy.array for query vector conversion
    with patch('numpy.array') as mock_np_array:
        mock_np_array.return_value = np.array(query_vector, dtype=np.float32)

        # Directly mock the _parse_output method to return our expected values
        # We're simulating that _parse_output filters to just the first result
        all_results = [
            OutputData(id="id1", score=0.9, payload={"name": "vector1", "category": "A"}),
            OutputData(id="id2", score=0.8, payload={"name": "vector2", "category": "B"})
        ]

        filtered_results = [all_results[0]]  # Just the "category": "A" result

        # Create a side_effect function that returns all results first (for _parse_output)
        # then returns filtered results (for the filters)
        parse_output_mock = Mock(side_effect=[all_results, filtered_results])

        # Replace the _apply_filters method to handle our test case
        with patch.object(faiss_instance, '_parse_output', return_value=all_results):
            with patch.object(faiss_instance, '_apply_filters', side_effect=lambda p, f: p.get("category") == "A"):
                # Call search with filters
                results = faiss_instance.search(
                    query="test query",
                    vectors=query_vector,
                    limit=2,
                    filters={"category": "A"}
                )

                # Verify numpy.array was called
                assert mock_np_array.called

                # Verify index.search was called
                mock_faiss_index.search.assert_called_once()

                # Verify filtered results - since we've mocked everything,
                # we should get just the result we want
                assert len(results) == 1
                assert results[0].id == "id1"
                assert results[0].score == 0.9
                assert results[0].payload == {"name": "vector1", "category": "A"}


def test_delete(faiss_instance):
    # Setup the docstore and index_to_id mapping
    faiss_instance.docstore = {
        "id1": {"name": "vector1"},
        "id2": {"name": "vector2"}
    }
    faiss_instance.index_to_id = {0: "id1", 1: "id2"}

    # Call delete
    faiss_instance.delete(vector_id="id1")

    # Verify the vector was removed from docstore and index_to_id
    assert "id1" not in faiss_instance.docstore
    assert 0 not in faiss_instance.index_to_id
    assert "id2" in faiss_instance.docstore
    assert 1 in faiss_instance.index_to_id


def test_update(faiss_instance, mock_faiss_index):
    # Setup the docstore and index_to_id mapping
    faiss_instance.docstore = {
        "id1": {"name": "vector1"},
        "id2": {"name": "vector2"}
    }
    faiss_instance.index_to_id = {0: "id1", 1: "id2"}

    # Test updating payload only
    faiss_instance.update(vector_id="id1", payload={"name": "updated_vector1"})
    assert faiss_instance.docstore["id1"] == {"name": "updated_vector1"}

    # Test updating vector
    # This requires mocking the delete and insert methods
    with patch.object(faiss_instance, 'delete') as mock_delete:
        with patch.object(faiss_instance, 'insert') as mock_insert:
            new_vector = [0.7, 0.8, 0.9]
            faiss_instance.update(vector_id="id2", vector=new_vector)

            # Verify delete and insert were called
            # Match the actual call signature (positional arg instead of keyword)
            mock_delete.assert_called_once_with("id2")
            mock_insert.assert_called_once()


def test_get(faiss_instance):
    # Setup the docstore
    faiss_instance.docstore = {
        "id1": {"name": "vector1"},
        "id2": {"name": "vector2"}
    }

    # Test getting an existing vector
    result = faiss_instance.get(vector_id="id1")
    assert result.id == "id1"
    assert result.payload == {"name": "vector1"}
    assert result.score is None

    # Test getting a non-existent vector
    result = faiss_instance.get(vector_id="id3")
    assert result is None


def test_list(faiss_instance):
    # Setup the docstore
    faiss_instance.docstore = {
        "id1": {"name": "vector1", "category": "A"},
        "id2": {"name": "vector2", "category": "B"},
        "id3": {"name": "vector3", "category": "A"}
    }

    # Test listing all vectors
    results = faiss_instance.list()
    # Fix the expected result - the list method returns a list of lists
    assert len(results[0]) == 3

    # Test listing with a limit
    results = faiss_instance.list(limit=2)
    assert len(results[0]) == 2

    # Test listing with filters
    results = faiss_instance.list(filters={"category": "A"})
    assert len(results[0]) == 2
    for result in results[0]:
        assert result.payload["category"] == "A"


def test_col_info(faiss_instance, mock_faiss_index):
    # Mock index attributes
    mock_faiss_index.ntotal = 5
    mock_faiss_index.d = 128

    # Get collection info
    info = faiss_instance.col_info()

    # Verify the returned info
    assert info["name"] == "test_collection"
    assert info["count"] == 5
    assert info["dimension"] == 128
    assert info["distance"] == "euclidean"


def test_delete_col(faiss_instance):
    # Mock the os.remove function
    with patch('os.remove') as mock_remove:
        with patch('os.path.exists', return_value=True):
            # Call delete_col
            faiss_instance.delete_col()

            # Verify os.remove was called twice (for index and docstore files)
            assert mock_remove.call_count == 2

    # Verify the internal state was reset
    assert faiss_instance.index is None
    assert faiss_instance.docstore == {}
    assert faiss_instance.index_to_id == {}


def test_normalize_L2(faiss_instance, mock_faiss_index):
    # Setup a FAISS instance with normalize_L2=True
    faiss_instance.normalize_L2 = True

    # Prepare test data
    vectors = [[0.1, 0.2, 0.3]]

    # Mock numpy array conversion
    with patch('numpy.array', return_value=np.array(vectors, dtype=np.float32)) as mock_np_array:
        # Mock faiss.normalize_L2
        with patch('faiss.normalize_L2') as mock_normalize:
            # Call insert
            faiss_instance.insert(vectors=vectors, ids=["id1"])

            # Verify faiss.normalize_L2 was called
            mock_normalize.assert_called_once()
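
To run these tests locally (assuming `pytest` and `faiss-cpu` are installed):

```bash
pytest tests/vector_stores/test_faiss.py -q
```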