Integrate Supabase VectorDB (#2290)

Dev Khant
2025-03-03 23:16:24 +05:30
committed by GitHub
parent 2556c5fe88
commit 8452dd598f
11 changed files with 542 additions and 4 deletions

View File

@@ -52,7 +52,7 @@ jobs:
virtualenvs-in-project: true
- name: Load cached venv
id: cached-poetry-dependencies
uses: actions/cache@v2
uses: actions/cache@v3
with:
path: .venv
key: venv-mem0-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}
@@ -83,7 +83,7 @@ jobs:
virtualenvs-in-project: true
- name: Load cached venv
id: cached-poetry-dependencies
uses: actions/cache@v2
uses: actions/cache@v3
with:
path: .venv
key: venv-embedchain-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}

View File

@@ -13,7 +13,7 @@ install:
install_all:
poetry install
poetry run pip install groq together boto3 litellm ollama chromadb sentence_transformers vertexai \
google-generativeai elasticsearch opensearch-py
google-generativeai elasticsearch opensearch-py vecs
# Format code with ruff
format:

View File

@@ -86,6 +86,9 @@ Here's a comprehensive list of all parameters that can be used across different
| `url` | Full URL for the server |
| `api_key` | API key for the server |
| `on_disk` | Enable persistent storage |
| `connection_string` | PostgreSQL connection string (for Supabase/PGVector) |
| `index_method` | Vector index method (for Supabase) |
| `index_measure` | Distance measure for similarity search (for Supabase) |
</Tab>
<Tab title="TypeScript">
| Parameter | Description |

View File

@@ -0,0 +1,78 @@
[Supabase](https://supabase.com/) is an open-source Firebase alternative that provides a PostgreSQL database with the pgvector extension for vector similarity search. It offers a powerful and scalable solution for storing and querying vector embeddings.
Create a [Supabase](https://supabase.com/dashboard/projects) account and project, then get your connection string from Project Settings > Database. See the [docs](https://supabase.github.io/vecs/hosting/) for details.
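Under the hood the integration uses Supabase's [vecs](https://supabase.github.io/vecs/) client, so install it alongside mem0 with `pip install vecs`.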
### Usage
```python
import os
from mem0 import Memory
os.environ["OPENAI_API_KEY"] = "sk-xx"
config = {
"vector_store": {
"provider": "supabase",
"config": {
"connection_string": "postgresql://user:password@host:port/database",
"collection_name": "memories",
"index_method": "hnsw", # Optional: defaults to "auto"
"index_measure": "cosine_distance" # Optional: defaults to "cosine_distance"
}
}
}
m = Memory.from_config(config)
messages = [
{"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
{"role": "assistant", "content": "How about a thriller movies? They can be quite engaging."},
{"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
{"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
]
m.add(messages, user_id="alice", metadata={"category": "movies"})
```
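Once memories are stored they can be queried back semantically. A minimal sketch, reusing the `m` instance from above (the exact return shape varies across mem0 versions, so it is simply printed here):

```python
related = m.search("What movies should I recommend to alice?", user_id="alice")
print(related)  # matching memories with similarity scores
```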
### Config
Here are the parameters available for configuring Supabase:
| Parameter | Description | Default Value |
| --- | --- | --- |
| `connection_string` | PostgreSQL connection string (required) | None |
| `collection_name` | Name for the vector collection | `mem0` |
| `embedding_model_dims` | Dimensions of the embedding model | `1536` |
| `index_method` | Vector index method to use | `auto` |
| `index_measure` | Distance measure for similarity search | `cosine_distance` |
### Index Methods
The following index methods are supported:
- `auto`: Automatically selects the best available index method
- `hnsw`: Hierarchical Navigable Small World graph index (faster search, more memory usage)
- `ivfflat`: Inverted File Flat index (good balance of speed and memory)
### Distance Measures
Available distance measures for similarity search:
- `cosine_distance`: Cosine distance (recommended for most embedding models)
- `l2_distance`: Euclidean (L2) distance
- `l1_distance`: Manhattan (L1) distance
- `max_inner_product`: Inner-product similarity (best suited to normalized vectors)
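Both settings are plain config values. As an illustration (not a recommendation), the following config selects an IVFFlat index with Euclidean distance; everything else matches the usage example above:

```python
config = {
    "vector_store": {
        "provider": "supabase",
        "config": {
            "connection_string": "postgresql://user:password@host:port/database",
            "collection_name": "memories",
            "index_method": "ivfflat",
            "index_measure": "l2_distance",
        }
    }
}
```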
### Best Practices
1. **Index Method Selection**:
- Use `hnsw` for fastest search performance when memory is not a constraint
- Use `ivfflat` for a good balance of search speed and memory usage
- Use `auto` if unsure; it will select the best method based on your data
2. **Distance Measure Selection**:
- Use `cosine_distance` for most embedding models (OpenAI, Hugging Face, etc.)
- Use `max_inner_product` if your vectors are normalized
- Use `l2_distance` or `l1_distance` if working with raw feature vectors
3. **Connection String**:
- Never hard-code credentials; load them from environment variables (see the sketch below)
- Format: `postgresql://user:password@host:port/database`
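A minimal sketch of that last point; the environment variable names are illustrative, not something mem0 or Supabase defines:

```python
import os

# Hypothetical variable names; use whatever your deployment provides.
user = os.environ["SUPABASE_DB_USER"]
password = os.environ["SUPABASE_DB_PASSWORD"]
host = os.environ["SUPABASE_DB_HOST"]
port = os.environ.get("SUPABASE_DB_PORT", "5432")
database = os.environ.get("SUPABASE_DB_NAME", "postgres")

connection_string = f"postgresql://{user}:{password}@{host}:{port}/{database}"
```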

View File

@@ -23,6 +23,7 @@ See the list of supported vector databases below.
<Card title="Redis" href="/components/vectordbs/dbs/redis"></Card>
<Card title="Elasticsearch" href="/components/vectordbs/dbs/elasticsearch"></Card>
<Card title="OpenSearch" href="/components/vectordbs/dbs/opensearch"></Card>
<Card title="Supabase" href="/components/vectordbs/dbs/supabase"></Card>
</CardGroup>
## Usage

View File

@@ -128,7 +128,8 @@
"components/vectordbs/dbs/azure_ai_search",
"components/vectordbs/dbs/redis",
"components/vectordbs/dbs/elasticsearch",
"components/vectordbs/dbs/opensearch"
"components/vectordbs/dbs/opensearch",
"components/vectordbs/dbs/supabase"
]
}
]

View File

@@ -0,0 +1,44 @@
from typing import Any, Dict, Optional
from enum import Enum
from pydantic import BaseModel, Field, model_validator
class IndexMethod(str, Enum):
AUTO = "auto"
HNSW = "hnsw"
IVFFLAT = "ivfflat"
class IndexMeasure(str, Enum):
COSINE = "cosine_distance"
L2 = "l2_distance"
L1 = "l1_distance"
MAX_INNER_PRODUCT = "max_inner_product"
class SupabaseConfig(BaseModel):
connection_string: str = Field(..., description="PostgreSQL connection string")
collection_name: str = Field("mem0", description="Name for the vector collection")
embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
index_method: Optional[IndexMethod] = Field(IndexMethod.AUTO, description="Index method to use")
index_measure: Optional[IndexMeasure] = Field(IndexMeasure.COSINE, description="Distance measure to use")
    @model_validator(mode="before")
    @classmethod
    def check_connection_string(cls, values):
conn_str = values.get("connection_string")
if not conn_str or not conn_str.startswith("postgresql://"):
raise ValueError("A valid PostgreSQL connection string must be provided")
return values
@model_validator(mode="before")
@classmethod
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
allowed_fields = set(cls.model_fields.keys())
input_fields = set(values.keys())
extra_fields = input_fields - allowed_fields
if extra_fields:
raise ValueError(
f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
)
return values
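# Illustrative usage (not part of this module): the validators above reject
# malformed connection strings and unknown fields, e.g.
#   SupabaseConfig(connection_string="postgresql://u:p@localhost:5432/db")  # ok
#   SupabaseConfig(connection_string="localhost:5432/db")                   # ValueError
#   SupabaseConfig(connection_string="postgresql://u:p@h:5432/db", foo=1)   # ValueError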

View File

@@ -70,6 +70,7 @@ class VectorStoreFactory:
"redis": "mem0.vector_stores.redis.RedisDB",
"elasticsearch": "mem0.vector_stores.elasticsearch.ElasticsearchDB",
"opensearch": "mem0.vector_stores.opensearch.OpenSearchDB",
"supabase": "mem0.vector_stores.supabase.Supabase",
}
@classmethod

View File

@@ -19,6 +19,7 @@ class VectorStoreConfig(BaseModel):
"redis": "RedisDBConfig",
"elasticsearch": "ElasticsearchConfig",
"opensearch": "OpenSearchConfig",
"supabase": "SupabaseConfig",
}
@model_validator(mode="after")

View File

@@ -0,0 +1,231 @@
import logging
import uuid
from typing import List, Optional
from pydantic import BaseModel
try:
import vecs
except ImportError:
raise ImportError("The 'vecs' library is required. Please install it using 'pip install vecs'.")
from mem0.vector_stores.base import VectorStoreBase
from mem0.configs.vector_stores.supabase import IndexMethod, IndexMeasure
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
    id: Optional[str] = None
    score: Optional[float] = None
    payload: Optional[dict] = None
class Supabase(VectorStoreBase):
def __init__(
self,
connection_string: str,
collection_name: str,
embedding_model_dims: int,
index_method: IndexMethod = IndexMethod.AUTO,
index_measure: IndexMeasure = IndexMeasure.COSINE,
):
"""
Initialize the Supabase vector store using vecs.
Args:
connection_string (str): PostgreSQL connection string
collection_name (str): Collection name
embedding_model_dims (int): Dimension of the embedding vector
index_method (IndexMethod): Index method to use. Defaults to AUTO.
index_measure (IndexMeasure): Distance measure to use. Defaults to COSINE.
"""
self.db = vecs.create_client(connection_string)
self.collection_name = collection_name
self.embedding_model_dims = embedding_model_dims
self.index_method = index_method
self.index_measure = index_measure
        collections = self.list_cols()
        if collection_name not in collections:
            self.create_col(embedding_model_dims)
        else:
            # Reattach to the existing collection so self.collection is always set
            self.collection = self.db.get_or_create_collection(name=collection_name, dimension=embedding_model_dims)
def _preprocess_filters(self, filters: Optional[dict] = None) -> Optional[dict]:
"""
Preprocess filters to be compatible with vecs.
Args:
filters (Dict, optional): Filters to preprocess. Multiple filters will be
combined with AND logic.
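        Example (illustrative):
            {"user_id": "alice"} -> {"user_id": {"$eq": "alice"}}
            {"user_id": "alice", "category": "movies"} ->
                {"$and": [{"user_id": {"$eq": "alice"}}, {"category": {"$eq": "movies"}}]}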
"""
if filters is None:
return None
if len(filters) == 1:
# For single filter, keep the simple format
key, value = next(iter(filters.items()))
return {key: {"$eq": value}}
# For multiple filters, use $and clause
return {"$and": [{key: {"$eq": value}} for key, value in filters.items()]}
def create_col(self, embedding_model_dims: Optional[int] = None) -> None:
"""
Create a new collection with vector support.
Will also initialize vector search index.
Args:
embedding_model_dims (int, optional): Dimension of the embedding vector.
If not provided, uses the dimension specified in initialization.
"""
dims = embedding_model_dims or self.embedding_model_dims
if not dims:
raise ValueError(
"embedding_model_dims must be provided either during initialization or when creating collection"
)
logger.info(f"Creating new collection: {self.collection_name}")
try:
self.collection = self.db.get_or_create_collection(name=self.collection_name, dimension=dims)
self.collection.create_index(method=self.index_method.value, measure=self.index_measure.value)
logger.info(f"Successfully created collection {self.collection_name} with dimension {dims}")
except Exception as e:
logger.error(f"Failed to create collection: {str(e)}")
raise
def insert(
self, vectors: List[List[float]], payloads: Optional[List[dict]] = None, ids: Optional[List[str]] = None
):
"""
Insert vectors into the collection.
Args:
vectors (List[List[float]]): List of vectors to insert
payloads (List[Dict], optional): List of payloads corresponding to vectors
ids (List[str], optional): List of IDs corresponding to vectors
"""
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
if not ids:
ids = [str(uuid.uuid4()) for _ in vectors]
if not payloads:
payloads = [{} for _ in vectors]
        records = list(zip(ids, vectors, payloads))
        self.collection.upsert(records)
def search(self, query: List[float], limit: int = 5, filters: Optional[dict] = None) -> List[OutputData]:
"""
Search for similar vectors.
Args:
query (List[float]): Query vector
limit (int, optional): Number of results to return. Defaults to 5.
filters (Dict, optional): Filters to apply to the search. Defaults to None.
Returns:
List[OutputData]: Search results
"""
        filters = self._preprocess_filters(filters)
        results = self.collection.query(
            data=query, limit=limit, filters=filters, include_metadata=True, include_value=True
        )
        return [OutputData(id=str(result[0]), score=float(result[1]), payload=result[2]) for result in results]
def delete(self, vector_id: str):
"""
Delete a vector by ID.
Args:
vector_id (str): ID of the vector to delete
"""
        self.collection.delete(ids=[vector_id])
def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[dict] = None):
"""
Update a vector and/or its payload.
Args:
vector_id (str): ID of the vector to update
vector (List[float], optional): Updated vector
payload (Dict, optional): Updated payload
"""
        if vector is None:
            # Metadata-only update: fetch the stored vector so the upsert preserves it
            records = self.collection.fetch(ids=[vector_id])
            if not records:
                return
            vector = records[0][1]  # vecs records are (id, vector, metadata) tuples
        self.collection.upsert([(vector_id, vector, payload or {})])
def get(self, vector_id: str) -> Optional[OutputData]:
"""
Retrieve a vector by ID.
Args:
vector_id (str): ID of the vector to retrieve
Returns:
Optional[OutputData]: Retrieved vector data or None if not found
"""
        result = self.collection.fetch(ids=[vector_id])
        if not result:
            return None
        record = result[0]  # (id, vector, metadata) tuple
        return OutputData(id=str(record[0]), score=None, payload=record[2])
def list_cols(self) -> List[str]:
"""
List all collections.
Returns:
List[str]: List of collection names
"""
return self.db.list_collections()
def delete_col(self):
"""Delete the collection."""
self.db.delete_collection(self.collection_name)
def col_info(self) -> dict:
"""
Get information about the collection.
Returns:
Dict: Collection information including name and configuration
"""
info = self.collection.describe()
return {
"name": info.name,
"count": info.vectors,
"dimension": info.dimension,
"index": {"method": info.index_method, "metric": info.distance_metric},
}
def list(self, filters: Optional[dict] = None, limit: int = 100) -> List[OutputData]:
"""
List vectors in the collection.
Args:
filters (Dict, optional): Filters to apply
limit (int, optional): Maximum number of results to return. Defaults to 100.
Returns:
            List[List[OutputData]]: a single-element list wrapping the matching records
        """
        filters = self._preprocess_filters(filters)
        # vecs exposes no scan API, so query with a zero vector to page through records
        query = [0.0] * self.embedding_model_dims
        results = self.collection.query(
            data=query, limit=limit, filters=filters, include_metadata=True, include_value=False
        )
        ids = [result[0] for result in results]
        records = self.collection.fetch(ids=ids)
        return [[OutputData(id=str(record[0]), score=None, payload=record[2]) for record in records]]
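# Illustrative direct usage (mem0 normally constructs this class via VectorStoreFactory):
#
#   store = Supabase(
#       connection_string="postgresql://user:password@localhost:5432/postgres",
#       collection_name="mem0",
#       embedding_model_dims=1536,
#   )
#   store.insert([[0.0] * 1536], payloads=[{"user_id": "alice"}])
#   hits = store.search([0.0] * 1536, limit=1)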

View File

@@ -0,0 +1,178 @@
from unittest.mock import Mock, patch
import pytest
from mem0.configs.vector_stores.supabase import IndexMeasure, IndexMethod
from mem0.vector_stores.supabase import Supabase
@pytest.fixture
def mock_vecs_client():
with patch("vecs.create_client") as mock_client:
yield mock_client
@pytest.fixture
def mock_collection():
collection = Mock()
collection.name = "test_collection"
collection.vectors = 100
collection.dimension = 1536
collection.index_method = "hnsw"
collection.distance_metric = "cosine_distance"
collection.describe.return_value = collection
return collection
@pytest.fixture
def supabase_instance(mock_vecs_client, mock_collection):
# Set up the mock client to return our mock collection
mock_vecs_client.return_value.get_or_create_collection.return_value = mock_collection
mock_vecs_client.return_value.list_collections.return_value = ["test_collection"]
instance = Supabase(
connection_string="postgresql://user:password@localhost:5432/test",
collection_name="test_collection",
embedding_model_dims=1536,
index_method=IndexMethod.HNSW,
index_measure=IndexMeasure.COSINE,
)
    # Ensure the instance points at our mock collection regardless of the init path
instance.collection = mock_collection
return instance
def test_create_col(supabase_instance, mock_vecs_client, mock_collection):
supabase_instance.create_col(1536)
mock_vecs_client.return_value.get_or_create_collection.assert_called_with(
name="test_collection",
dimension=1536
)
mock_collection.create_index.assert_called_with(
method="hnsw",
measure="cosine_distance"
)
def test_insert_vectors(supabase_instance, mock_collection):
vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
payloads = [{"name": "vector1"}, {"name": "vector2"}]
ids = ["id1", "id2"]
supabase_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
expected_records = [
("id1", [0.1, 0.2, 0.3], {"name": "vector1"}),
("id2", [0.4, 0.5, 0.6], {"name": "vector2"})
]
mock_collection.upsert.assert_called_once_with(expected_records)
def test_search_vectors(supabase_instance, mock_collection):
mock_results = [
("id1", 0.9, {"name": "vector1"}),
("id2", 0.8, {"name": "vector2"})
]
mock_collection.query.return_value = mock_results
query = [0.1, 0.2, 0.3]
filters = {"category": "test"}
results = supabase_instance.search(query=query, limit=2, filters=filters)
mock_collection.query.assert_called_once_with(
data=query,
limit=2,
filters={"category": {"$eq": "test"}},
include_metadata=True,
include_value=True
)
assert len(results) == 2
assert results[0].id == "id1"
assert results[0].score == 0.9
assert results[0].payload == {"name": "vector1"}
def test_delete_vector(supabase_instance, mock_collection):
vector_id = "id1"
supabase_instance.delete(vector_id=vector_id)
    mock_collection.delete.assert_called_once_with(ids=["id1"])
def test_update_vector(supabase_instance, mock_collection):
vector_id = "id1"
new_vector = [0.7, 0.8, 0.9]
new_payload = {"name": "updated_vector"}
supabase_instance.update(vector_id=vector_id, vector=new_vector, payload=new_payload)
mock_collection.upsert.assert_called_once_with([("id1", new_vector, new_payload)])
def test_get_vector(supabase_instance, mock_collection):
    # vecs fetch returns (id, vector, metadata) tuples
    mock_collection.fetch.return_value = [("id1", [0.1, 0.2, 0.3], {"name": "vector1"})]
    result = supabase_instance.get(vector_id="id1")
    mock_collection.fetch.assert_called_once_with(ids=["id1"])
assert result.id == "id1"
assert result.payload == {"name": "vector1"}
def test_list_vectors(supabase_instance, mock_collection):
mock_query_results = [("id1", 0.9, {}), ("id2", 0.8, {})]
mock_fetch_results = [
("id1", [0.1, 0.2, 0.3], {"name": "vector1"}),
("id2", [0.4, 0.5, 0.6], {"name": "vector2"})
]
mock_collection.query.return_value = mock_query_results
mock_collection.fetch.return_value = mock_fetch_results
results = supabase_instance.list(limit=2, filters={"category": "test"})
assert len(results[0]) == 2
assert results[0][0].id == "id1"
assert results[0][0].payload == {"name": "vector1"}
assert results[0][1].id == "id2"
assert results[0][1].payload == {"name": "vector2"}
def test_col_info(supabase_instance, mock_collection):
info = supabase_instance.col_info()
assert info == {
"name": "test_collection",
"count": 100,
"dimension": 1536,
"index": {
"method": "hnsw",
"metric": "cosine_distance"
}
}
def test_preprocess_filters(supabase_instance):
# Test single filter
single_filter = {"category": "test"}
assert supabase_instance._preprocess_filters(single_filter) == {"category": {"$eq": "test"}}
# Test multiple filters
multi_filter = {"category": "test", "type": "document"}
assert supabase_instance._preprocess_filters(multi_filter) == {
"$and": [
{"category": {"$eq": "test"}},
{"type": {"$eq": "document"}}
]
}
# Test None filters
assert supabase_instance._preprocess_filters(None) is None