Integrate Supabase VectorDB (#2290)
This commit is contained in:
4
.github/workflows/ci.yml
vendored
4
.github/workflows/ci.yml
vendored
@@ -52,7 +52,7 @@ jobs:
|
|||||||
virtualenvs-in-project: true
|
virtualenvs-in-project: true
|
||||||
- name: Load cached venv
|
- name: Load cached venv
|
||||||
id: cached-poetry-dependencies
|
id: cached-poetry-dependencies
|
||||||
uses: actions/cache@v2
|
uses: actions/cache@v3
|
||||||
with:
|
with:
|
||||||
path: .venv
|
path: .venv
|
||||||
key: venv-mem0-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}
|
key: venv-mem0-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}
|
||||||
@@ -83,7 +83,7 @@ jobs:
|
|||||||
virtualenvs-in-project: true
|
virtualenvs-in-project: true
|
||||||
- name: Load cached venv
|
- name: Load cached venv
|
||||||
id: cached-poetry-dependencies
|
id: cached-poetry-dependencies
|
||||||
uses: actions/cache@v2
|
uses: actions/cache@v3
|
||||||
with:
|
with:
|
||||||
path: .venv
|
path: .venv
|
||||||
key: venv-embedchain-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}
|
key: venv-embedchain-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}
|
||||||
|
|||||||
2
Makefile
2
Makefile
@@ -13,7 +13,7 @@ install:
|
|||||||
install_all:
|
install_all:
|
||||||
poetry install
|
poetry install
|
||||||
poetry run pip install groq together boto3 litellm ollama chromadb sentence_transformers vertexai \
|
poetry run pip install groq together boto3 litellm ollama chromadb sentence_transformers vertexai \
|
||||||
google-generativeai elasticsearch opensearch-py
|
google-generativeai elasticsearch opensearch-py vecs
|
||||||
|
|
||||||
# Format code with ruff
|
# Format code with ruff
|
||||||
format:
|
format:
|
||||||
|
|||||||
@@ -86,6 +86,9 @@ Here's a comprehensive list of all parameters that can be used across different
|
|||||||
| `url` | Full URL for the server |
|
| `url` | Full URL for the server |
|
||||||
| `api_key` | API key for the server |
|
| `api_key` | API key for the server |
|
||||||
| `on_disk` | Enable persistent storage |
|
| `on_disk` | Enable persistent storage |
|
||||||
|
| `connection_string` | PostgreSQL connection string (for Supabase/PGVector) |
|
||||||
|
| `index_method` | Vector index method (for Supabase) |
|
||||||
|
| `index_measure` | Distance measure for similarity search (for Supabase) |
|
||||||
</Tab>
|
</Tab>
|
||||||
<Tab title="TypeScript">
|
<Tab title="TypeScript">
|
||||||
| Parameter | Description |
|
| Parameter | Description |
|
||||||
|
|||||||
78
docs/components/vectordbs/dbs/supabase.mdx
Normal file
78
docs/components/vectordbs/dbs/supabase.mdx
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
[Supabase](https://supabase.com/) is an open-source Firebase alternative that provides a PostgreSQL database with pgvector extension for vector similarity search. It offers a powerful and scalable solution for storing and querying vector embeddings.
|
||||||
|
|
||||||
|
Create a [Supabase](https://supabase.com/dashboard/projects) account and project, then get your connection string from Project Settings > Database. See the [docs](https://supabase.github.io/vecs/hosting/) for details.
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
from mem0 import Memory
|
||||||
|
|
||||||
|
os.environ["OPENAI_API_KEY"] = "sk-xx"
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"vector_store": {
|
||||||
|
"provider": "supabase",
|
||||||
|
"config": {
|
||||||
|
"connection_string": "postgresql://user:password@host:port/database",
|
||||||
|
"collection_name": "memories",
|
||||||
|
"index_method": "hnsw", # Optional: defaults to "auto"
|
||||||
|
"index_measure": "cosine_distance" # Optional: defaults to "cosine_distance"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m = Memory.from_config(config)
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
|
||||||
|
{"role": "assistant", "content": "How about a thriller movies? They can be quite engaging."},
|
||||||
|
{"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
|
||||||
|
{"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
|
||||||
|
]
|
||||||
|
m.add(messages, user_id="alice", metadata={"category": "movies"})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Config
|
||||||
|
|
||||||
|
Here are the parameters available for configuring Supabase:
|
||||||
|
|
||||||
|
| Parameter | Description | Default Value |
|
||||||
|
| --- | --- | --- |
|
||||||
|
| `connection_string` | PostgreSQL connection string (required) | None |
|
||||||
|
| `collection_name` | Name for the vector collection | `mem0` |
|
||||||
|
| `embedding_model_dims` | Dimensions of the embedding model | `1536` |
|
||||||
|
| `index_method` | Vector index method to use | `auto` |
|
||||||
|
| `index_measure` | Distance measure for similarity search | `cosine_distance` |
|
||||||
|
|
||||||
|
### Index Methods
|
||||||
|
|
||||||
|
The following index methods are supported:
|
||||||
|
|
||||||
|
- `auto`: Automatically selects the best available index method
|
||||||
|
- `hnsw`: Hierarchical Navigable Small World graph index (faster search, more memory usage)
|
||||||
|
- `ivfflat`: Inverted File Flat index (good balance of speed and memory)
|
||||||
|
|
||||||
|
### Distance Measures
|
||||||
|
|
||||||
|
Available distance measures for similarity search:
|
||||||
|
|
||||||
|
- `cosine_distance`: Cosine similarity (recommended for most embedding models)
|
||||||
|
- `l2_distance`: Euclidean distance
|
||||||
|
- `l1_distance`: Manhattan distance
|
||||||
|
- `max_inner_product`: Maximum inner product similarity
|
||||||
|
|
||||||
|
### Best Practices
|
||||||
|
|
||||||
|
1. **Index Method Selection**:
|
||||||
|
- Use `hnsw` for fastest search performance when memory is not a constraint
|
||||||
|
- Use `ivfflat` for a good balance of search speed and memory usage
|
||||||
|
- Use `auto` if unsure, it will select the best method based on your data
|
||||||
|
|
||||||
|
2. **Distance Measure Selection**:
|
||||||
|
- Use `cosine_distance` for most embedding models (OpenAI, Hugging Face, etc.)
|
||||||
|
- Use `max_inner_product` if your vectors are normalized
|
||||||
|
- Use `l2_distance` or `l1_distance` if working with raw feature vectors
|
||||||
|
|
||||||
|
3. **Connection String**:
|
||||||
|
- Always use environment variables for sensitive information in the connection string
|
||||||
|
- Format: `postgresql://user:password@host:port/database`
|
||||||
@@ -23,6 +23,7 @@ See the list of supported vector databases below.
|
|||||||
<Card title="Redis" href="/components/vectordbs/dbs/redis"></Card>
|
<Card title="Redis" href="/components/vectordbs/dbs/redis"></Card>
|
||||||
<Card title="Elasticsearch" href="/components/vectordbs/dbs/elasticsearch"></Card>
|
<Card title="Elasticsearch" href="/components/vectordbs/dbs/elasticsearch"></Card>
|
||||||
<Card title="OpenSearch" href="/components/vectordbs/dbs/opensearch"></Card>
|
<Card title="OpenSearch" href="/components/vectordbs/dbs/opensearch"></Card>
|
||||||
|
<Card title="Supabase" href="/components/vectordbs/dbs/supabase"></Card>
|
||||||
</CardGroup>
|
</CardGroup>
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|||||||
@@ -128,7 +128,8 @@
|
|||||||
"components/vectordbs/dbs/azure_ai_search",
|
"components/vectordbs/dbs/azure_ai_search",
|
||||||
"components/vectordbs/dbs/redis",
|
"components/vectordbs/dbs/redis",
|
||||||
"components/vectordbs/dbs/elasticsearch",
|
"components/vectordbs/dbs/elasticsearch",
|
||||||
"components/vectordbs/dbs/opensearch"
|
"components/vectordbs/dbs/opensearch",
|
||||||
|
"components/vectordbs/dbs/supabase"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
44
mem0/configs/vector_stores/supabase.py
Normal file
44
mem0/configs/vector_stores/supabase.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
from typing import Any, Dict, Optional
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
|
||||||
|
class IndexMethod(str, Enum):
|
||||||
|
AUTO = "auto"
|
||||||
|
HNSW = "hnsw"
|
||||||
|
IVFFLAT = "ivfflat"
|
||||||
|
|
||||||
|
|
||||||
|
class IndexMeasure(str, Enum):
|
||||||
|
COSINE = "cosine_distance"
|
||||||
|
L2 = "l2_distance"
|
||||||
|
L1 = "l1_distance"
|
||||||
|
MAX_INNER_PRODUCT = "max_inner_product"
|
||||||
|
|
||||||
|
|
||||||
|
class SupabaseConfig(BaseModel):
|
||||||
|
connection_string: str = Field(..., description="PostgreSQL connection string")
|
||||||
|
collection_name: str = Field("mem0", description="Name for the vector collection")
|
||||||
|
embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
|
||||||
|
index_method: Optional[IndexMethod] = Field(IndexMethod.AUTO, description="Index method to use")
|
||||||
|
index_measure: Optional[IndexMeasure] = Field(IndexMeasure.COSINE, description="Distance measure to use")
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
def check_connection_string(cls, values):
|
||||||
|
conn_str = values.get("connection_string")
|
||||||
|
if not conn_str or not conn_str.startswith("postgresql://"):
|
||||||
|
raise ValueError("A valid PostgreSQL connection string must be provided")
|
||||||
|
return values
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
allowed_fields = set(cls.model_fields.keys())
|
||||||
|
input_fields = set(values.keys())
|
||||||
|
extra_fields = input_fields - allowed_fields
|
||||||
|
if extra_fields:
|
||||||
|
raise ValueError(
|
||||||
|
f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
|
||||||
|
)
|
||||||
|
return values
|
||||||
@@ -70,6 +70,7 @@ class VectorStoreFactory:
|
|||||||
"redis": "mem0.vector_stores.redis.RedisDB",
|
"redis": "mem0.vector_stores.redis.RedisDB",
|
||||||
"elasticsearch": "mem0.vector_stores.elasticsearch.ElasticsearchDB",
|
"elasticsearch": "mem0.vector_stores.elasticsearch.ElasticsearchDB",
|
||||||
"opensearch": "mem0.vector_stores.opensearch.OpenSearchDB",
|
"opensearch": "mem0.vector_stores.opensearch.OpenSearchDB",
|
||||||
|
"supabase": "mem0.vector_stores.supabase.Supabase",
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ class VectorStoreConfig(BaseModel):
|
|||||||
"redis": "RedisDBConfig",
|
"redis": "RedisDBConfig",
|
||||||
"elasticsearch": "ElasticsearchConfig",
|
"elasticsearch": "ElasticsearchConfig",
|
||||||
"opensearch": "OpenSearchConfig",
|
"opensearch": "OpenSearchConfig",
|
||||||
|
"supabase": "SupabaseConfig",
|
||||||
}
|
}
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
|
|||||||
231
mem0/vector_stores/supabase.py
Normal file
231
mem0/vector_stores/supabase.py
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
try:
|
||||||
|
import vecs
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("The 'vecs' library is required. Please install it using 'pip install vecs'.")
|
||||||
|
|
||||||
|
from mem0.vector_stores.base import VectorStoreBase
|
||||||
|
from mem0.configs.vector_stores.supabase import IndexMethod, IndexMeasure
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OutputData(BaseModel):
|
||||||
|
id: Optional[str]
|
||||||
|
score: Optional[float]
|
||||||
|
payload: Optional[dict]
|
||||||
|
|
||||||
|
|
||||||
|
class Supabase(VectorStoreBase):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
connection_string: str,
|
||||||
|
collection_name: str,
|
||||||
|
embedding_model_dims: int,
|
||||||
|
index_method: IndexMethod = IndexMethod.AUTO,
|
||||||
|
index_measure: IndexMeasure = IndexMeasure.COSINE,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the Supabase vector store using vecs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection_string (str): PostgreSQL connection string
|
||||||
|
collection_name (str): Collection name
|
||||||
|
embedding_model_dims (int): Dimension of the embedding vector
|
||||||
|
index_method (IndexMethod): Index method to use. Defaults to AUTO.
|
||||||
|
index_measure (IndexMeasure): Distance measure to use. Defaults to COSINE.
|
||||||
|
"""
|
||||||
|
self.db = vecs.create_client(connection_string)
|
||||||
|
self.collection_name = collection_name
|
||||||
|
self.embedding_model_dims = embedding_model_dims
|
||||||
|
self.index_method = index_method
|
||||||
|
self.index_measure = index_measure
|
||||||
|
|
||||||
|
collections = self.list_cols()
|
||||||
|
if collection_name not in collections:
|
||||||
|
self.create_col(embedding_model_dims)
|
||||||
|
|
||||||
|
def _preprocess_filters(self, filters: Optional[dict] = None) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
Preprocess filters to be compatible with vecs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filters (Dict, optional): Filters to preprocess. Multiple filters will be
|
||||||
|
combined with AND logic.
|
||||||
|
"""
|
||||||
|
if filters is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if len(filters) == 1:
|
||||||
|
# For single filter, keep the simple format
|
||||||
|
key, value = next(iter(filters.items()))
|
||||||
|
return {key: {"$eq": value}}
|
||||||
|
|
||||||
|
# For multiple filters, use $and clause
|
||||||
|
return {"$and": [{key: {"$eq": value}} for key, value in filters.items()]}
|
||||||
|
|
||||||
|
def create_col(self, embedding_model_dims: Optional[int] = None) -> None:
|
||||||
|
"""
|
||||||
|
Create a new collection with vector support.
|
||||||
|
Will also initialize vector search index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding_model_dims (int, optional): Dimension of the embedding vector.
|
||||||
|
If not provided, uses the dimension specified in initialization.
|
||||||
|
"""
|
||||||
|
dims = embedding_model_dims or self.embedding_model_dims
|
||||||
|
if not dims:
|
||||||
|
raise ValueError(
|
||||||
|
"embedding_model_dims must be provided either during initialization or when creating collection"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Creating new collection: {self.collection_name}")
|
||||||
|
try:
|
||||||
|
self.collection = self.db.get_or_create_collection(name=self.collection_name, dimension=dims)
|
||||||
|
self.collection.create_index(method=self.index_method.value, measure=self.index_measure.value)
|
||||||
|
logger.info(f"Successfully created collection {self.collection_name} with dimension {dims}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create collection: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def insert(
|
||||||
|
self, vectors: List[List[float]], payloads: Optional[List[dict]] = None, ids: Optional[List[str]] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Insert vectors into the collection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vectors (List[List[float]]): List of vectors to insert
|
||||||
|
payloads (List[Dict], optional): List of payloads corresponding to vectors
|
||||||
|
ids (List[str], optional): List of IDs corresponding to vectors
|
||||||
|
"""
|
||||||
|
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
|
||||||
|
|
||||||
|
if not ids:
|
||||||
|
ids = [str(uuid.uuid4()) for _ in vectors]
|
||||||
|
if not payloads:
|
||||||
|
payloads = [{} for _ in vectors]
|
||||||
|
|
||||||
|
records = [(id, vector, payload) for id, vector, payload in zip(ids, vectors, payloads)]
|
||||||
|
print(records)
|
||||||
|
|
||||||
|
self.collection.upsert(records)
|
||||||
|
|
||||||
|
def search(self, query: List[float], limit: int = 5, filters: Optional[dict] = None) -> List[OutputData]:
|
||||||
|
"""
|
||||||
|
Search for similar vectors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (List[float]): Query vector
|
||||||
|
limit (int, optional): Number of results to return. Defaults to 5.
|
||||||
|
filters (Dict, optional): Filters to apply to the search. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[OutputData]: Search results
|
||||||
|
"""
|
||||||
|
filters = self._preprocess_filters(filters)
|
||||||
|
print(filters)
|
||||||
|
results = self.collection.query(
|
||||||
|
data=query, limit=limit, filters=filters, include_metadata=True, include_value=True
|
||||||
|
)
|
||||||
|
print(results)
|
||||||
|
|
||||||
|
return [OutputData(id=str(result[0]), score=float(result[1]), payload=result[2]) for result in results]
|
||||||
|
|
||||||
|
def delete(self, vector_id: str):
|
||||||
|
"""
|
||||||
|
Delete a vector by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vector_id (str): ID of the vector to delete
|
||||||
|
"""
|
||||||
|
self.collection.delete([(vector_id,)])
|
||||||
|
|
||||||
|
def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[dict] = None):
|
||||||
|
"""
|
||||||
|
Update a vector and/or its payload.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vector_id (str): ID of the vector to update
|
||||||
|
vector (List[float], optional): Updated vector
|
||||||
|
payload (Dict, optional): Updated payload
|
||||||
|
"""
|
||||||
|
if vector is None:
|
||||||
|
# If only updating metadata, we need to get the existing vector
|
||||||
|
existing = self.get(vector_id)
|
||||||
|
if existing and existing.payload:
|
||||||
|
vector = existing.payload.get("vector", [])
|
||||||
|
|
||||||
|
if vector:
|
||||||
|
self.collection.upsert([(vector_id, vector, payload or {})])
|
||||||
|
|
||||||
|
def get(self, vector_id: str) -> Optional[OutputData]:
|
||||||
|
"""
|
||||||
|
Retrieve a vector by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vector_id (str): ID of the vector to retrieve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[OutputData]: Retrieved vector data or None if not found
|
||||||
|
"""
|
||||||
|
result = self.collection.fetch([(vector_id,)])
|
||||||
|
if not result:
|
||||||
|
return []
|
||||||
|
|
||||||
|
record = result[0]
|
||||||
|
return OutputData(id=str(record.id), score=None, payload=record.metadata)
|
||||||
|
|
||||||
|
def list_cols(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
List all collections.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: List of collection names
|
||||||
|
"""
|
||||||
|
return self.db.list_collections()
|
||||||
|
|
||||||
|
def delete_col(self):
|
||||||
|
"""Delete the collection."""
|
||||||
|
self.db.delete_collection(self.collection_name)
|
||||||
|
|
||||||
|
def col_info(self) -> dict:
|
||||||
|
"""
|
||||||
|
Get information about the collection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: Collection information including name and configuration
|
||||||
|
"""
|
||||||
|
info = self.collection.describe()
|
||||||
|
return {
|
||||||
|
"name": info.name,
|
||||||
|
"count": info.vectors,
|
||||||
|
"dimension": info.dimension,
|
||||||
|
"index": {"method": info.index_method, "metric": info.distance_metric},
|
||||||
|
}
|
||||||
|
|
||||||
|
def list(self, filters: Optional[dict] = None, limit: int = 100) -> List[OutputData]:
|
||||||
|
"""
|
||||||
|
List vectors in the collection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filters (Dict, optional): Filters to apply
|
||||||
|
limit (int, optional): Maximum number of results to return. Defaults to 100.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[OutputData]: List of vectors
|
||||||
|
"""
|
||||||
|
filters = self._preprocess_filters(filters)
|
||||||
|
query = [0] * self.embedding_model_dims
|
||||||
|
ids = self.collection.query(
|
||||||
|
data=query, limit=limit, filters=filters, include_metadata=True, include_value=False
|
||||||
|
)
|
||||||
|
ids = [id[0] for id in ids]
|
||||||
|
records = self.collection.fetch(ids=ids)
|
||||||
|
|
||||||
|
return [[OutputData(id=str(record[0]), score=None, payload=record[2]) for record in records]]
|
||||||
178
tests/vector_stores/test_supabase.py
Normal file
178
tests/vector_stores/test_supabase.py
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from mem0.configs.vector_stores.supabase import IndexMeasure, IndexMethod
|
||||||
|
from mem0.vector_stores.supabase import Supabase
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_vecs_client():
|
||||||
|
with patch("vecs.create_client") as mock_client:
|
||||||
|
yield mock_client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_collection():
|
||||||
|
collection = Mock()
|
||||||
|
collection.name = "test_collection"
|
||||||
|
collection.vectors = 100
|
||||||
|
collection.dimension = 1536
|
||||||
|
collection.index_method = "hnsw"
|
||||||
|
collection.distance_metric = "cosine_distance"
|
||||||
|
collection.describe.return_value = collection
|
||||||
|
return collection
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def supabase_instance(mock_vecs_client, mock_collection):
|
||||||
|
# Set up the mock client to return our mock collection
|
||||||
|
mock_vecs_client.return_value.get_or_create_collection.return_value = mock_collection
|
||||||
|
mock_vecs_client.return_value.list_collections.return_value = ["test_collection"]
|
||||||
|
|
||||||
|
instance = Supabase(
|
||||||
|
connection_string="postgresql://user:password@localhost:5432/test",
|
||||||
|
collection_name="test_collection",
|
||||||
|
embedding_model_dims=1536,
|
||||||
|
index_method=IndexMethod.HNSW,
|
||||||
|
index_measure=IndexMeasure.COSINE,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Manually set the collection attribute since we're mocking the initialization
|
||||||
|
instance.collection = mock_collection
|
||||||
|
return instance
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_col(supabase_instance, mock_vecs_client, mock_collection):
|
||||||
|
supabase_instance.create_col(1536)
|
||||||
|
|
||||||
|
mock_vecs_client.return_value.get_or_create_collection.assert_called_with(
|
||||||
|
name="test_collection",
|
||||||
|
dimension=1536
|
||||||
|
)
|
||||||
|
mock_collection.create_index.assert_called_with(
|
||||||
|
method="hnsw",
|
||||||
|
measure="cosine_distance"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_insert_vectors(supabase_instance, mock_collection):
|
||||||
|
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||||
|
payloads = [{"name": "vector1"}, {"name": "vector2"}]
|
||||||
|
ids = ["id1", "id2"]
|
||||||
|
|
||||||
|
supabase_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
|
||||||
|
|
||||||
|
expected_records = [
|
||||||
|
("id1", [0.1, 0.2, 0.3], {"name": "vector1"}),
|
||||||
|
("id2", [0.4, 0.5, 0.6], {"name": "vector2"})
|
||||||
|
]
|
||||||
|
mock_collection.upsert.assert_called_once_with(expected_records)
|
||||||
|
|
||||||
|
|
||||||
|
def test_search_vectors(supabase_instance, mock_collection):
|
||||||
|
mock_results = [
|
||||||
|
("id1", 0.9, {"name": "vector1"}),
|
||||||
|
("id2", 0.8, {"name": "vector2"})
|
||||||
|
]
|
||||||
|
mock_collection.query.return_value = mock_results
|
||||||
|
|
||||||
|
query = [0.1, 0.2, 0.3]
|
||||||
|
filters = {"category": "test"}
|
||||||
|
results = supabase_instance.search(query=query, limit=2, filters=filters)
|
||||||
|
|
||||||
|
mock_collection.query.assert_called_once_with(
|
||||||
|
data=query,
|
||||||
|
limit=2,
|
||||||
|
filters={"category": {"$eq": "test"}},
|
||||||
|
include_metadata=True,
|
||||||
|
include_value=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(results) == 2
|
||||||
|
assert results[0].id == "id1"
|
||||||
|
assert results[0].score == 0.9
|
||||||
|
assert results[0].payload == {"name": "vector1"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_vector(supabase_instance, mock_collection):
|
||||||
|
vector_id = "id1"
|
||||||
|
supabase_instance.delete(vector_id=vector_id)
|
||||||
|
mock_collection.delete.assert_called_once_with([("id1",)])
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_vector(supabase_instance, mock_collection):
|
||||||
|
vector_id = "id1"
|
||||||
|
new_vector = [0.7, 0.8, 0.9]
|
||||||
|
new_payload = {"name": "updated_vector"}
|
||||||
|
|
||||||
|
supabase_instance.update(vector_id=vector_id, vector=new_vector, payload=new_payload)
|
||||||
|
mock_collection.upsert.assert_called_once_with([("id1", new_vector, new_payload)])
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_vector(supabase_instance, mock_collection):
|
||||||
|
# Create a Mock object to represent the record
|
||||||
|
mock_record = Mock()
|
||||||
|
mock_record.id = "id1"
|
||||||
|
mock_record.metadata = {"name": "vector1"}
|
||||||
|
mock_record.values = [0.1, 0.2, 0.3]
|
||||||
|
|
||||||
|
# Set the fetch return value to a list containing our mock record
|
||||||
|
mock_collection.fetch.return_value = [mock_record]
|
||||||
|
|
||||||
|
result = supabase_instance.get(vector_id="id1")
|
||||||
|
|
||||||
|
mock_collection.fetch.assert_called_once_with([("id1",)])
|
||||||
|
assert result.id == "id1"
|
||||||
|
assert result.payload == {"name": "vector1"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_vectors(supabase_instance, mock_collection):
|
||||||
|
mock_query_results = [("id1", 0.9, {}), ("id2", 0.8, {})]
|
||||||
|
mock_fetch_results = [
|
||||||
|
("id1", [0.1, 0.2, 0.3], {"name": "vector1"}),
|
||||||
|
("id2", [0.4, 0.5, 0.6], {"name": "vector2"})
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_collection.query.return_value = mock_query_results
|
||||||
|
mock_collection.fetch.return_value = mock_fetch_results
|
||||||
|
|
||||||
|
results = supabase_instance.list(limit=2, filters={"category": "test"})
|
||||||
|
|
||||||
|
assert len(results[0]) == 2
|
||||||
|
assert results[0][0].id == "id1"
|
||||||
|
assert results[0][0].payload == {"name": "vector1"}
|
||||||
|
assert results[0][1].id == "id2"
|
||||||
|
assert results[0][1].payload == {"name": "vector2"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_col_info(supabase_instance, mock_collection):
|
||||||
|
info = supabase_instance.col_info()
|
||||||
|
|
||||||
|
assert info == {
|
||||||
|
"name": "test_collection",
|
||||||
|
"count": 100,
|
||||||
|
"dimension": 1536,
|
||||||
|
"index": {
|
||||||
|
"method": "hnsw",
|
||||||
|
"metric": "cosine_distance"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_preprocess_filters(supabase_instance):
|
||||||
|
# Test single filter
|
||||||
|
single_filter = {"category": "test"}
|
||||||
|
assert supabase_instance._preprocess_filters(single_filter) == {"category": {"$eq": "test"}}
|
||||||
|
|
||||||
|
# Test multiple filters
|
||||||
|
multi_filter = {"category": "test", "type": "document"}
|
||||||
|
assert supabase_instance._preprocess_filters(multi_filter) == {
|
||||||
|
"$and": [
|
||||||
|
{"category": {"$eq": "test"}},
|
||||||
|
{"type": {"$eq": "document"}}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test None filters
|
||||||
|
assert supabase_instance._preprocess_filters(None) is None
|
||||||
Reference in New Issue
Block a user