Integrate Supabase VectorDB (#2290)
4	.github/workflows/ci.yml (vendored)

@@ -52,7 +52,7 @@ jobs:
        virtualenvs-in-project: true
    - name: Load cached venv
      id: cached-poetry-dependencies
-     uses: actions/cache@v2
+     uses: actions/cache@v3
      with:
        path: .venv
        key: venv-mem0-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}
@@ -83,7 +83,7 @@ jobs:
        virtualenvs-in-project: true
    - name: Load cached venv
      id: cached-poetry-dependencies
-     uses: actions/cache@v2
+     uses: actions/cache@v3
      with:
        path: .venv
        key: venv-embedchain-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}
2	Makefile

@@ -13,7 +13,7 @@ install:

install_all:
	poetry install
	poetry run pip install groq together boto3 litellm ollama chromadb sentence_transformers vertexai \
-		google-generativeai elasticsearch opensearch-py
+		google-generativeai elasticsearch opensearch-py vecs

# Format code with ruff
format:
@@ -86,6 +86,9 @@ Here's a comprehensive list of all parameters that can be used across different
 | `url` | Full URL for the server |
 | `api_key` | API key for the server |
 | `on_disk` | Enable persistent storage |
+| `connection_string` | PostgreSQL connection string (for Supabase/PGVector) |
+| `index_method` | Vector index method (for Supabase) |
+| `index_measure` | Distance measure for similarity search (for Supabase) |
 </Tab>
 <Tab title="TypeScript">
 | Parameter | Description |
78	docs/components/vectordbs/dbs/supabase.mdx (new file)

@@ -0,0 +1,78 @@

[Supabase](https://supabase.com/) is an open-source Firebase alternative that provides a PostgreSQL database with the pgvector extension for vector similarity search. It offers a powerful and scalable solution for storing and querying vector embeddings.

Create a [Supabase](https://supabase.com/dashboard/projects) account and project, then get your connection string from Project Settings > Database. See the [docs](https://supabase.github.io/vecs/hosting/) for details.

### Usage

```python
import os
from mem0 import Memory

os.environ["OPENAI_API_KEY"] = "sk-xx"

config = {
    "vector_store": {
        "provider": "supabase",
        "config": {
            "connection_string": "postgresql://user:password@host:port/database",
            "collection_name": "memories",
            "index_method": "hnsw",  # Optional: defaults to "auto"
            "index_measure": "cosine_distance"  # Optional: defaults to "cosine_distance"
        }
    }
}

m = Memory.from_config(config)
messages = [
    {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
    {"role": "assistant", "content": "How about a thriller movie? They can be quite engaging."},
    {"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
    {"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
]
m.add(messages, user_id="alice", metadata={"category": "movies"})
```
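Once memories are stored, they can be queried back through the same `Memory` instance. A minimal follow-up sketch (the query text is illustrative; depending on your mem0 version, `search` returns either a plain list or a dict with a `results` key):

```python
# Retrieve memories relevant to a query for the same user.
related = m.search("What movies does Alice like?", user_id="alice")
print(related)
```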
### Config

Here are the parameters available for configuring Supabase:

| Parameter | Description | Default Value |
| --- | --- | --- |
| `connection_string` | PostgreSQL connection string (required) | None |
| `collection_name` | Name for the vector collection | `mem0` |
| `embedding_model_dims` | Dimensions of the embedding model | `1536` |
| `index_method` | Vector index method to use | `auto` |
| `index_measure` | Distance measure for similarity search | `cosine_distance` |
### Index Methods

The following index methods are supported:

- `auto`: Automatically selects the best available index method
- `hnsw`: Hierarchical Navigable Small World graph index (faster search, higher memory usage)
- `ivfflat`: Inverted File Flat index (good balance of speed and memory)

### Distance Measures

Available distance measures for similarity search (a combined configuration sketch follows this list):

- `cosine_distance`: Cosine similarity (recommended for most embedding models)
- `l2_distance`: Euclidean distance
- `l1_distance`: Manhattan distance
- `max_inner_product`: Maximum inner product similarity
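For example, normalized embeddings pair well with `max_inner_product`, while `hnsw` trades memory for query speed. A sketch combining the two (values are illustrative, not defaults):

```python
config = {
    "vector_store": {
        "provider": "supabase",
        "config": {
            "connection_string": "postgresql://user:password@host:port/database",
            "collection_name": "memories",
            "index_method": "hnsw",  # fast queries, higher memory usage
            "index_measure": "max_inner_product"  # suited to normalized vectors
        }
    }
}
```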
### Best Practices

1. **Index Method Selection**:
   - Use `hnsw` for the fastest search performance when memory is not a constraint
   - Use `ivfflat` for a good balance of search speed and memory usage
   - Use `auto` if unsure; it selects the best method based on your data

2. **Distance Measure Selection**:
   - Use `cosine_distance` for most embedding models (OpenAI, Hugging Face, etc.)
   - Use `max_inner_product` if your vectors are normalized
   - Use `l2_distance` or `l1_distance` if working with raw feature vectors

3. **Connection String**:
   - Always use environment variables for sensitive information in the connection string (see the sketch below)
   - Format: `postgresql://user:password@host:port/database`
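A minimal sketch of the environment-variable recommendation, assuming the connection string is exported as `SUPABASE_CONNECTION_STRING` (the variable name is illustrative):

```python
import os

config = {
    "vector_store": {
        "provider": "supabase",
        "config": {
            # Read the connection string from the environment instead of
            # hard-coding credentials in source control.
            "connection_string": os.environ["SUPABASE_CONNECTION_STRING"],
            "collection_name": "memories",
        },
    }
}
```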
@@ -23,6 +23,7 @@ See the list of supported vector databases below.
   <Card title="Redis" href="/components/vectordbs/dbs/redis"></Card>
   <Card title="Elasticsearch" href="/components/vectordbs/dbs/elasticsearch"></Card>
   <Card title="OpenSearch" href="/components/vectordbs/dbs/opensearch"></Card>
+  <Card title="Supabase" href="/components/vectordbs/dbs/supabase"></Card>
 </CardGroup>

 ## Usage
@@ -128,7 +128,8 @@
       "components/vectordbs/dbs/azure_ai_search",
       "components/vectordbs/dbs/redis",
       "components/vectordbs/dbs/elasticsearch",
-      "components/vectordbs/dbs/opensearch"
+      "components/vectordbs/dbs/opensearch",
+      "components/vectordbs/dbs/supabase"
     ]
   }
 ]
44	mem0/configs/vector_stores/supabase.py (new file)

@@ -0,0 +1,44 @@

from enum import Enum
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field, model_validator


class IndexMethod(str, Enum):
    AUTO = "auto"
    HNSW = "hnsw"
    IVFFLAT = "ivfflat"


class IndexMeasure(str, Enum):
    COSINE = "cosine_distance"
    L2 = "l2_distance"
    L1 = "l1_distance"
    MAX_INNER_PRODUCT = "max_inner_product"


class SupabaseConfig(BaseModel):
    connection_string: str = Field(..., description="PostgreSQL connection string")
    collection_name: str = Field("mem0", description="Name for the vector collection")
    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
    index_method: Optional[IndexMethod] = Field(IndexMethod.AUTO, description="Index method to use")
    index_measure: Optional[IndexMeasure] = Field(IndexMeasure.COSINE, description="Distance measure to use")

    @model_validator(mode="before")
    @classmethod
    def check_connection_string(cls, values):
        conn_str = values.get("connection_string")
        if not conn_str or not conn_str.startswith("postgresql://"):
            raise ValueError("A valid PostgreSQL connection string must be provided")
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. "
                f"Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values
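For reference (not part of this commit), a small sketch of how the validators above behave, assuming pydantic v2 semantics where validator `ValueError`s surface as `ValidationError`:

```python
from pydantic import ValidationError

from mem0.configs.vector_stores.supabase import SupabaseConfig

# Valid: the connection string has the required postgresql:// scheme.
cfg = SupabaseConfig(connection_string="postgresql://user:pw@localhost:5432/db")
print(cfg.index_method, cfg.index_measure)  # defaults: auto / cosine_distance

# An invalid scheme is rejected by check_connection_string.
try:
    SupabaseConfig(connection_string="mysql://user:pw@localhost/db")
except ValidationError as e:
    print(e)

# Unknown fields are rejected by validate_extra_fields.
try:
    SupabaseConfig(connection_string="postgresql://user:pw@localhost:5432/db", extra_field=1)
except ValidationError as e:
    print(e)
```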
@@ -70,6 +70,7 @@ class VectorStoreFactory:
         "redis": "mem0.vector_stores.redis.RedisDB",
         "elasticsearch": "mem0.vector_stores.elasticsearch.ElasticsearchDB",
         "opensearch": "mem0.vector_stores.opensearch.OpenSearchDB",
+        "supabase": "mem0.vector_stores.supabase.Supabase",
     }

     @classmethod
@@ -19,6 +19,7 @@ class VectorStoreConfig(BaseModel):
         "redis": "RedisDBConfig",
         "elasticsearch": "ElasticsearchConfig",
         "opensearch": "OpenSearchConfig",
+        "supabase": "SupabaseConfig",
     }

     @model_validator(mode="after")
231	mem0/vector_stores/supabase.py (new file)

@@ -0,0 +1,231 @@
import logging
import uuid
from typing import List, Optional

from pydantic import BaseModel

try:
    import vecs
except ImportError:
    raise ImportError("The 'vecs' library is required. Please install it using 'pip install vecs'.")

from mem0.configs.vector_stores.supabase import IndexMeasure, IndexMethod
from mem0.vector_stores.base import VectorStoreBase

logger = logging.getLogger(__name__)

class OutputData(BaseModel):
    id: Optional[str]
    score: Optional[float]
    payload: Optional[dict]


class Supabase(VectorStoreBase):
    def __init__(
        self,
        connection_string: str,
        collection_name: str,
        embedding_model_dims: int,
        index_method: IndexMethod = IndexMethod.AUTO,
        index_measure: IndexMeasure = IndexMeasure.COSINE,
    ):
        """
        Initialize the Supabase vector store using vecs.

        Args:
            connection_string (str): PostgreSQL connection string
            collection_name (str): Collection name
            embedding_model_dims (int): Dimension of the embedding vector
            index_method (IndexMethod): Index method to use. Defaults to AUTO.
            index_measure (IndexMeasure): Distance measure to use. Defaults to COSINE.
        """
        self.db = vecs.create_client(connection_string)
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.index_method = index_method
        self.index_measure = index_measure

        collections = self.list_cols()
        if collection_name not in collections:
            self.create_col(embedding_model_dims)

    def _preprocess_filters(self, filters: Optional[dict] = None) -> Optional[dict]:
        """
        Preprocess filters to be compatible with vecs.

        Args:
            filters (Dict, optional): Filters to preprocess. Multiple filters will be
                combined with AND logic.
        """
        if filters is None:
            return None

        if len(filters) == 1:
            # For a single filter, keep the simple format
            key, value = next(iter(filters.items()))
            return {key: {"$eq": value}}

        # For multiple filters, use an $and clause
        return {"$and": [{key: {"$eq": value}} for key, value in filters.items()]}

    def create_col(self, embedding_model_dims: Optional[int] = None) -> None:
        """
        Create a new collection with vector support.
        Will also initialize the vector search index.

        Args:
            embedding_model_dims (int, optional): Dimension of the embedding vector.
                If not provided, uses the dimension specified in initialization.
        """
        dims = embedding_model_dims or self.embedding_model_dims
        if not dims:
            raise ValueError(
                "embedding_model_dims must be provided either during initialization or when creating collection"
            )

        logger.info(f"Creating new collection: {self.collection_name}")
        try:
            self.collection = self.db.get_or_create_collection(name=self.collection_name, dimension=dims)
            self.collection.create_index(method=self.index_method.value, measure=self.index_measure.value)
            logger.info(f"Successfully created collection {self.collection_name} with dimension {dims}")
        except Exception as e:
            logger.error(f"Failed to create collection: {str(e)}")
            raise

    def insert(
        self, vectors: List[List[float]], payloads: Optional[List[dict]] = None, ids: Optional[List[str]] = None
    ):
        """
        Insert vectors into the collection.

        Args:
            vectors (List[List[float]]): List of vectors to insert
            payloads (List[Dict], optional): List of payloads corresponding to vectors
            ids (List[str], optional): List of IDs corresponding to vectors
        """
        logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")

        if not ids:
            ids = [str(uuid.uuid4()) for _ in vectors]
        if not payloads:
            payloads = [{} for _ in vectors]

        # vecs expects (id, vector, metadata) triples
        records = list(zip(ids, vectors, payloads))

        self.collection.upsert(records)

    def search(self, query: List[float], limit: int = 5, filters: Optional[dict] = None) -> List[OutputData]:
        """
        Search for similar vectors.

        Args:
            query (List[float]): Query vector
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            List[OutputData]: Search results
        """
        filters = self._preprocess_filters(filters)
        results = self.collection.query(
            data=query, limit=limit, filters=filters, include_metadata=True, include_value=True
        )

        return [OutputData(id=str(result[0]), score=float(result[1]), payload=result[2]) for result in results]

    def delete(self, vector_id: str):
        """
        Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete
        """
        self.collection.delete([(vector_id,)])

    def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[dict] = None):
        """
        Update a vector and/or its payload.

        Args:
            vector_id (str): ID of the vector to update
            vector (List[float], optional): Updated vector
            payload (Dict, optional): Updated payload
        """
        if vector is None:
            # If only updating metadata, we need the existing vector, since vecs
            # upserts full (id, vector, metadata) records. This only works if the
            # stored payload carries a "vector" entry; otherwise the update is skipped.
            existing = self.get(vector_id)
            if existing and existing.payload:
                vector = existing.payload.get("vector", [])

        if vector:
            self.collection.upsert([(vector_id, vector, payload or {})])

    def get(self, vector_id: str) -> Optional[OutputData]:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve

        Returns:
            Optional[OutputData]: Retrieved vector data or None if not found
        """
        result = self.collection.fetch([(vector_id,)])
        if not result:
            return None

        record = result[0]
        return OutputData(id=str(record.id), score=None, payload=record.metadata)

    def list_cols(self) -> List[str]:
        """
        List all collections.

        Returns:
            List[str]: List of collection names
        """
        return self.db.list_collections()

    def delete_col(self):
        """Delete the collection."""
        self.db.delete_collection(self.collection_name)

    def col_info(self) -> dict:
        """
        Get information about the collection.

        Returns:
            Dict: Collection information including name and configuration
        """
        info = self.collection.describe()
        return {
            "name": info.name,
            "count": info.vectors,
            "dimension": info.dimension,
            "index": {"method": info.index_method, "metric": info.distance_metric},
        }

    def list(self, filters: Optional[dict] = None, limit: int = 100) -> List[List[OutputData]]:
        """
        List vectors in the collection.

        Args:
            filters (Dict, optional): Filters to apply
            limit (int, optional): Maximum number of results to return. Defaults to 100.

        Returns:
            List[List[OutputData]]: Listed vectors, wrapped in an outer list
        """
        filters = self._preprocess_filters(filters)
        # Query with a zero vector to enumerate matching ids, then fetch full records
        query = [0] * self.embedding_model_dims
        ids = self.collection.query(
            data=query, limit=limit, filters=filters, include_metadata=True, include_value=False
        )
        ids = [id[0] for id in ids]
        records = self.collection.fetch(ids=ids)

        return [[OutputData(id=str(record[0]), score=None, payload=record[2]) for record in records]]
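For reference (not part of this commit), a minimal sketch of driving the store directly, assuming a reachable Postgres instance with the pgvector extension and the `vecs` package installed:

```python
from mem0.configs.vector_stores.supabase import IndexMeasure, IndexMethod
from mem0.vector_stores.supabase import Supabase

store = Supabase(
    connection_string="postgresql://user:password@localhost:5432/postgres",
    collection_name="demo",
    embedding_model_dims=3,  # tiny dimension, illustrative only
    index_method=IndexMethod.HNSW,
    index_measure=IndexMeasure.COSINE,
)

store.insert(
    vectors=[[0.1, 0.2, 0.3], [0.9, 0.1, 0.4]],
    payloads=[{"name": "a"}, {"name": "b"}],
    ids=["id-a", "id-b"],
)

# Filters are translated internally to vecs' {"field": {"$eq": value}} form.
hits = store.search(query=[0.1, 0.2, 0.3], limit=1, filters={"name": "a"})
print(hits[0].id, hits[0].score, hits[0].payload)
```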
178	tests/vector_stores/test_supabase.py (new file)

@@ -0,0 +1,178 @@
from unittest.mock import Mock, patch

import pytest

from mem0.configs.vector_stores.supabase import IndexMeasure, IndexMethod
from mem0.vector_stores.supabase import Supabase


@pytest.fixture
def mock_vecs_client():
    with patch("vecs.create_client") as mock_client:
        yield mock_client


@pytest.fixture
def mock_collection():
    collection = Mock()
    collection.name = "test_collection"
    collection.vectors = 100
    collection.dimension = 1536
    collection.index_method = "hnsw"
    collection.distance_metric = "cosine_distance"
    collection.describe.return_value = collection
    return collection


@pytest.fixture
def supabase_instance(mock_vecs_client, mock_collection):
    # Set up the mock client to return our mock collection
    mock_vecs_client.return_value.get_or_create_collection.return_value = mock_collection
    mock_vecs_client.return_value.list_collections.return_value = ["test_collection"]

    instance = Supabase(
        connection_string="postgresql://user:password@localhost:5432/test",
        collection_name="test_collection",
        embedding_model_dims=1536,
        index_method=IndexMethod.HNSW,
        index_measure=IndexMeasure.COSINE,
    )

    # Manually set the collection attribute since we're mocking the initialization
    instance.collection = mock_collection
    return instance

def test_create_col(supabase_instance, mock_vecs_client, mock_collection):
    supabase_instance.create_col(1536)

    mock_vecs_client.return_value.get_or_create_collection.assert_called_with(
        name="test_collection", dimension=1536
    )
    mock_collection.create_index.assert_called_with(method="hnsw", measure="cosine_distance")


def test_insert_vectors(supabase_instance, mock_collection):
    vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
    payloads = [{"name": "vector1"}, {"name": "vector2"}]
    ids = ["id1", "id2"]

    supabase_instance.insert(vectors=vectors, payloads=payloads, ids=ids)

    expected_records = [
        ("id1", [0.1, 0.2, 0.3], {"name": "vector1"}),
        ("id2", [0.4, 0.5, 0.6], {"name": "vector2"}),
    ]
    mock_collection.upsert.assert_called_once_with(expected_records)


def test_search_vectors(supabase_instance, mock_collection):
    mock_results = [
        ("id1", 0.9, {"name": "vector1"}),
        ("id2", 0.8, {"name": "vector2"}),
    ]
    mock_collection.query.return_value = mock_results

    query = [0.1, 0.2, 0.3]
    filters = {"category": "test"}
    results = supabase_instance.search(query=query, limit=2, filters=filters)

    mock_collection.query.assert_called_once_with(
        data=query,
        limit=2,
        filters={"category": {"$eq": "test"}},
        include_metadata=True,
        include_value=True,
    )

    assert len(results) == 2
    assert results[0].id == "id1"
    assert results[0].score == 0.9
    assert results[0].payload == {"name": "vector1"}


def test_delete_vector(supabase_instance, mock_collection):
    vector_id = "id1"
    supabase_instance.delete(vector_id=vector_id)
    mock_collection.delete.assert_called_once_with([("id1",)])


def test_update_vector(supabase_instance, mock_collection):
    vector_id = "id1"
    new_vector = [0.7, 0.8, 0.9]
    new_payload = {"name": "updated_vector"}

    supabase_instance.update(vector_id=vector_id, vector=new_vector, payload=new_payload)
    mock_collection.upsert.assert_called_once_with([("id1", new_vector, new_payload)])


def test_get_vector(supabase_instance, mock_collection):
    # Create a Mock object to represent the record
    mock_record = Mock()
    mock_record.id = "id1"
    mock_record.metadata = {"name": "vector1"}
    mock_record.values = [0.1, 0.2, 0.3]

    # Set the fetch return value to a list containing our mock record
    mock_collection.fetch.return_value = [mock_record]

    result = supabase_instance.get(vector_id="id1")

    mock_collection.fetch.assert_called_once_with([("id1",)])
    assert result.id == "id1"
    assert result.payload == {"name": "vector1"}


def test_list_vectors(supabase_instance, mock_collection):
    mock_query_results = [("id1", 0.9, {}), ("id2", 0.8, {})]
    mock_fetch_results = [
        ("id1", [0.1, 0.2, 0.3], {"name": "vector1"}),
        ("id2", [0.4, 0.5, 0.6], {"name": "vector2"}),
    ]

    mock_collection.query.return_value = mock_query_results
    mock_collection.fetch.return_value = mock_fetch_results

    results = supabase_instance.list(limit=2, filters={"category": "test"})

    assert len(results[0]) == 2
    assert results[0][0].id == "id1"
    assert results[0][0].payload == {"name": "vector1"}
    assert results[0][1].id == "id2"
    assert results[0][1].payload == {"name": "vector2"}


def test_col_info(supabase_instance, mock_collection):
    info = supabase_instance.col_info()

    assert info == {
        "name": "test_collection",
        "count": 100,
        "dimension": 1536,
        "index": {"method": "hnsw", "metric": "cosine_distance"},
    }


def test_preprocess_filters(supabase_instance):
    # Test a single filter
    single_filter = {"category": "test"}
    assert supabase_instance._preprocess_filters(single_filter) == {"category": {"$eq": "test"}}

    # Test multiple filters
    multi_filter = {"category": "test", "type": "document"}
    assert supabase_instance._preprocess_filters(multi_filter) == {
        "$and": [
            {"category": {"$eq": "test"}},
            {"type": {"$eq": "document"}},
        ]
    }

    # Test None filters
    assert supabase_instance._preprocess_filters(None) is None