diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 8a8fe615..34c8e9c4 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -52,7 +52,7 @@ jobs:
           virtualenvs-in-project: true
       - name: Load cached venv
         id: cached-poetry-dependencies
-        uses: actions/cache@v2
+        uses: actions/cache@v3
         with:
           path: .venv
           key: venv-mem0-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}
@@ -83,7 +83,7 @@ jobs:
           virtualenvs-in-project: true
      - name: Load cached venv
        id: cached-poetry-dependencies
-        uses: actions/cache@v2
+        uses: actions/cache@v3
        with:
          path: .venv
          key: venv-embedchain-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}
diff --git a/Makefile b/Makefile
index 5acac2c6..dae34e4a 100644
--- a/Makefile
+++ b/Makefile
@@ -13,7 +13,7 @@ install:
 install_all:
 	poetry install
 	poetry run pip install groq together boto3 litellm ollama chromadb sentence_transformers vertexai \
-		google-generativeai elasticsearch opensearch-py
+		google-generativeai elasticsearch opensearch-py vecs
 
 # Format code with ruff
 format:
diff --git a/docs/components/vectordbs/config.mdx b/docs/components/vectordbs/config.mdx
index 61275a21..5daf4048 100644
--- a/docs/components/vectordbs/config.mdx
+++ b/docs/components/vectordbs/config.mdx
@@ -86,6 +86,9 @@ Here's a comprehensive list of all parameters that can be used across different
 | `url` | Full URL for the server |
 | `api_key` | API key for the server |
 | `on_disk` | Enable persistent storage |
+| `connection_string` | PostgreSQL connection string (for Supabase/PGVector) |
+| `index_method` | Vector index method (for Supabase) |
+| `index_measure` | Distance measure for similarity search (for Supabase) |
 
 | Parameter | Description |
diff --git a/docs/components/vectordbs/dbs/supabase.mdx b/docs/components/vectordbs/dbs/supabase.mdx
new file mode 100644
index 00000000..87fbe44e
--- /dev/null
+++ b/docs/components/vectordbs/dbs/supabase.mdx
@@ -0,0 +1,78 @@
+[Supabase](https://supabase.com/) is an open-source Firebase alternative that provides a PostgreSQL database with the pgvector extension for vector similarity search. It offers a powerful and scalable solution for storing and querying vector embeddings.
+
+Create a [Supabase](https://supabase.com/dashboard/projects) account and project, then get your connection string from Project Settings > Database. See the [docs](https://supabase.github.io/vecs/hosting/) for details.
+
+### Usage
+
+```python
+import os
+from mem0 import Memory
+
+os.environ["OPENAI_API_KEY"] = "sk-xx"
+
+config = {
+    "vector_store": {
+        "provider": "supabase",
+        "config": {
+            "connection_string": "postgresql://user:password@host:port/database",
+            "collection_name": "memories",
+            "index_method": "hnsw",  # Optional: defaults to "auto"
+            "index_measure": "cosine_distance"  # Optional: defaults to "cosine_distance"
+        }
+    }
+}
+
+m = Memory.from_config(config)
+messages = [
+    {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
+    {"role": "assistant", "content": "How about thriller movies? They can be quite engaging."},
+    {"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
+    {"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
+]
+m.add(messages, user_id="alice", metadata={"category": "movies"})
+```
+
+### Config
+
+Here are the parameters available for configuring Supabase:
+
+| Parameter | Description | Default Value |
+| --- | --- | --- |
+| `connection_string` | PostgreSQL connection string (required) | None |
+| `collection_name` | Name for the vector collection | `mem0` |
+| `embedding_model_dims` | Dimensions of the embedding model | `1536` |
+| `index_method` | Vector index method to use | `auto` |
+| `index_measure` | Distance measure for similarity search | `cosine_distance` |
+
+### Index Methods
+
+The following index methods are supported:
+
+- `auto`: Automatically selects the best available index method
+- `hnsw`: Hierarchical Navigable Small World graph index (faster search, higher memory usage)
+- `ivfflat`: Inverted File Flat index (good balance of speed and memory)
+
+### Distance Measures
+
+Available distance measures for similarity search:
+
+- `cosine_distance`: Cosine similarity (recommended for most embedding models)
+- `l2_distance`: Euclidean distance
+- `l1_distance`: Manhattan distance
+- `max_inner_product`: Maximum inner product similarity
+
+### Best Practices
+
+1. **Index Method Selection**:
+   - Use `hnsw` for the fastest search performance when memory is not a constraint
+   - Use `ivfflat` for a good balance of search speed and memory usage
+   - Use `auto` if unsure; it selects the best method based on your data
+
+2. **Distance Measure Selection**:
+   - Use `cosine_distance` for most embedding models (OpenAI, Hugging Face, etc.)
+   - Use `max_inner_product` if your vectors are normalized
+   - Use `l2_distance` or `l1_distance` if working with raw feature vectors
+
+3. **Connection String**:
+   - Always use environment variables for sensitive information in the connection string (see the sketch below)
+   - Format: `postgresql://user:password@host:port/database`
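To follow the connection-string guidance above without hard-coding credentials, here is a minimal sketch; the `SUPABASE_CONNECTION_STRING` variable name is an illustrative choice, not something mem0 requires:

```python
import os

from mem0 import Memory

# Assumes the variable was exported beforehand, e.g.
# export SUPABASE_CONNECTION_STRING="postgresql://user:password@host:port/database"
config = {
    "vector_store": {
        "provider": "supabase",
        "config": {
            "connection_string": os.environ["SUPABASE_CONNECTION_STRING"],
            "collection_name": "memories",
        },
    }
}

m = Memory.from_config(config)
```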
diff --git a/docs/components/vectordbs/overview.mdx b/docs/components/vectordbs/overview.mdx
index 773ffb97..37e35c5d 100644
--- a/docs/components/vectordbs/overview.mdx
+++ b/docs/components/vectordbs/overview.mdx
@@ -23,6 +23,7 @@ See the list of supported vector databases below.
+
 
 ## Usage
diff --git a/docs/docs.json b/docs/docs.json
index c88082f2..78384401 100644
--- a/docs/docs.json
+++ b/docs/docs.json
@@ -128,7 +128,8 @@
               "components/vectordbs/dbs/azure_ai_search",
               "components/vectordbs/dbs/redis",
               "components/vectordbs/dbs/elasticsearch",
-              "components/vectordbs/dbs/opensearch"
+              "components/vectordbs/dbs/opensearch",
+              "components/vectordbs/dbs/supabase"
             ]
           }
         ]
diff --git a/mem0/configs/vector_stores/supabase.py b/mem0/configs/vector_stores/supabase.py
new file mode 100644
index 00000000..7a43afc8
--- /dev/null
+++ b/mem0/configs/vector_stores/supabase.py
@@ -0,0 +1,44 @@
+from enum import Enum
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel, Field, model_validator
+
+
+class IndexMethod(str, Enum):
+    AUTO = "auto"
+    HNSW = "hnsw"
+    IVFFLAT = "ivfflat"
+
+
+class IndexMeasure(str, Enum):
+    COSINE = "cosine_distance"
+    L2 = "l2_distance"
+    L1 = "l1_distance"
+    MAX_INNER_PRODUCT = "max_inner_product"
+
+
+class SupabaseConfig(BaseModel):
+    connection_string: str = Field(..., description="PostgreSQL connection string")
+    collection_name: str = Field("mem0", description="Name for the vector collection")
+    embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model")
+    index_method: Optional[IndexMethod] = Field(IndexMethod.AUTO, description="Index method to use")
+    index_measure: Optional[IndexMeasure] = Field(IndexMeasure.COSINE, description="Distance measure to use")
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_connection_string(cls, values):
+        conn_str = values.get("connection_string")
+        if not conn_str or not conn_str.startswith("postgresql://"):
+            raise ValueError("A valid PostgreSQL connection string must be provided")
+        return values
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        allowed_fields = set(cls.model_fields.keys())
+        input_fields = set(values.keys())
+        extra_fields = input_fields - allowed_fields
+        if extra_fields:
+            raise ValueError(
+                f"Extra fields not allowed: {', '.join(extra_fields)}. Please input only the following fields: {', '.join(allowed_fields)}"
+            )
+        return values
diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py
index 82af19d5..ca2e051e 100644
--- a/mem0/utils/factory.py
+++ b/mem0/utils/factory.py
@@ -70,6 +70,7 @@ class VectorStoreFactory:
         "redis": "mem0.vector_stores.redis.RedisDB",
         "elasticsearch": "mem0.vector_stores.elasticsearch.ElasticsearchDB",
         "opensearch": "mem0.vector_stores.opensearch.OpenSearchDB",
+        "supabase": "mem0.vector_stores.supabase.Supabase",
     }
 
     @classmethod
diff --git a/mem0/vector_stores/configs.py b/mem0/vector_stores/configs.py
index b6d1c86a..9271fb58 100644
--- a/mem0/vector_stores/configs.py
+++ b/mem0/vector_stores/configs.py
@@ -19,6 +19,7 @@ class VectorStoreConfig(BaseModel):
         "redis": "RedisDBConfig",
         "elasticsearch": "ElasticsearchConfig",
         "opensearch": "OpenSearchConfig",
+        "supabase": "SupabaseConfig",
     }
 
     @model_validator(mode="after")
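A quick sketch of how this config behaves at its boundaries; the `tls` key is a deliberately bogus field used only to trip the extra-fields validator:

```python
from pydantic import ValidationError

from mem0.configs.vector_stores.supabase import SupabaseConfig

# Defaults kick in for everything except the required connection string
cfg = SupabaseConfig(connection_string="postgresql://user:password@localhost:5432/postgres")
assert cfg.collection_name == "mem0"
assert cfg.index_method.value == "auto"
assert cfg.index_measure.value == "cosine_distance"

# Non-postgresql:// schemes are rejected by check_connection_string
try:
    SupabaseConfig(connection_string="mysql://user:password@localhost:3306/db")
except ValidationError as e:
    print(e)  # "A valid PostgreSQL connection string must be provided"

# Unknown keys are rejected by validate_extra_fields
try:
    SupabaseConfig(connection_string="postgresql://u:p@h:5432/db", tls=True)
except ValidationError as e:
    print(e)  # "Extra fields not allowed: tls. ..."
```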
diff --git a/mem0/vector_stores/supabase.py b/mem0/vector_stores/supabase.py
new file mode 100644
index 00000000..765c2194
--- /dev/null
+++ b/mem0/vector_stores/supabase.py
@@ -0,0 +1,231 @@
+import logging
+import uuid
+from typing import List, Optional
+
+from pydantic import BaseModel
+
+try:
+    import vecs
+except ImportError:
+    raise ImportError("The 'vecs' library is required. Please install it using 'pip install vecs'.")
+
+from mem0.configs.vector_stores.supabase import IndexMeasure, IndexMethod
+from mem0.vector_stores.base import VectorStoreBase
+
+logger = logging.getLogger(__name__)
+
+
+class OutputData(BaseModel):
+    id: Optional[str]
+    score: Optional[float]
+    payload: Optional[dict]
+
+
+class Supabase(VectorStoreBase):
+    def __init__(
+        self,
+        connection_string: str,
+        collection_name: str,
+        embedding_model_dims: int,
+        index_method: IndexMethod = IndexMethod.AUTO,
+        index_measure: IndexMeasure = IndexMeasure.COSINE,
+    ):
+        """
+        Initialize the Supabase vector store using vecs.
+
+        Args:
+            connection_string (str): PostgreSQL connection string
+            collection_name (str): Collection name
+            embedding_model_dims (int): Dimension of the embedding vector
+            index_method (IndexMethod): Index method to use. Defaults to AUTO.
+            index_measure (IndexMeasure): Distance measure to use. Defaults to COSINE.
+        """
+        self.db = vecs.create_client(connection_string)
+        self.collection_name = collection_name
+        self.embedding_model_dims = embedding_model_dims
+        self.index_method = index_method
+        self.index_measure = index_measure
+
+        collections = self.list_cols()
+        if collection_name not in collections:
+            self.create_col(embedding_model_dims)
+        else:
+            # Bind to the existing collection; without this branch,
+            # self.collection would be unset whenever the collection already exists.
+            self.collection = self.db.get_or_create_collection(name=collection_name, dimension=embedding_model_dims)
+
+    def _preprocess_filters(self, filters: Optional[dict] = None) -> Optional[dict]:
+        """
+        Preprocess filters to be compatible with vecs.
+
+        Args:
+            filters (Dict, optional): Filters to preprocess. Multiple filters will be
+                combined with AND logic.
+        """
+        if filters is None:
+            return None
+
+        if len(filters) == 1:
+            # For a single filter, keep the simple format
+            key, value = next(iter(filters.items()))
+            return {key: {"$eq": value}}
+
+        # For multiple filters, use an $and clause
+        return {"$and": [{key: {"$eq": value}} for key, value in filters.items()]}
+    def create_col(self, embedding_model_dims: Optional[int] = None) -> None:
+        """
+        Create a new collection with vector support.
+        Will also initialize the vector search index.
+
+        Args:
+            embedding_model_dims (int, optional): Dimension of the embedding vector.
+                If not provided, uses the dimension specified in initialization.
+        """
+        dims = embedding_model_dims or self.embedding_model_dims
+        if not dims:
+            raise ValueError(
+                "embedding_model_dims must be provided either during initialization or when creating collection"
+            )
+
+        logger.info(f"Creating new collection: {self.collection_name}")
+        try:
+            self.collection = self.db.get_or_create_collection(name=self.collection_name, dimension=dims)
+            self.collection.create_index(method=self.index_method.value, measure=self.index_measure.value)
+            logger.info(f"Successfully created collection {self.collection_name} with dimension {dims}")
+        except Exception as e:
+            logger.error(f"Failed to create collection: {str(e)}")
+            raise
+
+    def insert(
+        self, vectors: List[List[float]], payloads: Optional[List[dict]] = None, ids: Optional[List[str]] = None
+    ):
+        """
+        Insert vectors into the collection.
+
+        Args:
+            vectors (List[List[float]]): List of vectors to insert
+            payloads (List[Dict], optional): List of payloads corresponding to vectors
+            ids (List[str], optional): List of IDs corresponding to vectors
+        """
+        logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
+
+        if not ids:
+            ids = [str(uuid.uuid4()) for _ in vectors]
+        if not payloads:
+            payloads = [{} for _ in vectors]
+
+        # vecs expects records as (id, vector, metadata) tuples
+        records = [(id, vector, payload) for id, vector, payload in zip(ids, vectors, payloads)]
+        self.collection.upsert(records)
+
+    def search(self, query: List[float], limit: int = 5, filters: Optional[dict] = None) -> List[OutputData]:
+        """
+        Search for similar vectors.
+
+        Args:
+            query (List[float]): Query vector
+            limit (int, optional): Number of results to return. Defaults to 5.
+            filters (Dict, optional): Filters to apply to the search. Defaults to None.
+
+        Returns:
+            List[OutputData]: Search results
+        """
+        filters = self._preprocess_filters(filters)
+        results = self.collection.query(
+            data=query, limit=limit, filters=filters, include_metadata=True, include_value=True
+        )
+
+        # With include_value and include_metadata set, vecs returns (id, distance, metadata) tuples
+        return [OutputData(id=str(result[0]), score=float(result[1]), payload=result[2]) for result in results]
+
+    def delete(self, vector_id: str):
+        """
+        Delete a vector by ID.
+
+        Args:
+            vector_id (str): ID of the vector to delete
+        """
+        self.collection.delete([vector_id])
+
+    def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[dict] = None):
+        """
+        Update a vector and/or its payload.
+
+        Args:
+            vector_id (str): ID of the vector to update
+            vector (List[float], optional): Updated vector
+            payload (Dict, optional): Updated payload
+        """
+        if vector is None:
+            # Metadata-only update: fetch the stored vector so the upsert preserves it
+            existing = self.collection.fetch([vector_id])
+            if not existing:
+                return
+            vector = existing[0][1]
+
+        self.collection.upsert([(vector_id, vector, payload or {})])
+
+    def get(self, vector_id: str) -> Optional[OutputData]:
+        """
+        Retrieve a vector by ID.
+
+        Args:
+            vector_id (str): ID of the vector to retrieve
+
+        Returns:
+            Optional[OutputData]: Retrieved vector data or None if not found
+        """
+        result = self.collection.fetch([vector_id])
+        if not result:
+            return None
+
+        # vecs fetch returns (id, vector, metadata) tuples
+        record = result[0]
+        return OutputData(id=str(record[0]), score=None, payload=record[2])
+    def list_cols(self) -> List[str]:
+        """
+        List all collections.
+
+        Returns:
+            List[str]: List of collection names
+        """
+        return self.db.list_collections()
+
+    def delete_col(self):
+        """Delete the collection."""
+        self.db.delete_collection(self.collection_name)
+
+    def col_info(self) -> dict:
+        """
+        Get information about the collection.
+
+        Returns:
+            Dict: Collection information including name and configuration
+        """
+        info = self.collection.describe()
+        return {
+            "name": info.name,
+            "count": info.vectors,
+            "dimension": info.dimension,
+            "index": {"method": info.index_method, "metric": info.distance_metric},
+        }
+
+    def list(self, filters: Optional[dict] = None, limit: int = 100) -> List[OutputData]:
+        """
+        List vectors in the collection.
+
+        Args:
+            filters (Dict, optional): Filters to apply
+            limit (int, optional): Maximum number of results to return. Defaults to 100.
+
+        Returns:
+            List[OutputData]: List of vectors, wrapped in an outer list as mem0's callers expect
+        """
+        filters = self._preprocess_filters(filters)
+        # vecs has no listing primitive, so query with a zero vector to collect ids,
+        # then fetch the full records for those ids
+        query = [0] * self.embedding_model_dims
+        ids = self.collection.query(
+            data=query, limit=limit, filters=filters, include_metadata=True, include_value=False
+        )
+        ids = [id[0] for id in ids]
+        records = self.collection.fetch(ids=ids)
+
+        # mem0's Memory class indexes into results[0], hence the nested list
+        return [[OutputData(id=str(record[0]), score=None, payload=record[2]) for record in records]]
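Since vecs has no native "list everything" call, `list()` above enumerates records by querying with a zero vector and fetching the matched ids; results are therefore ordered by distance to the origin, which is fine for enumeration up to `limit`. A minimal consumption sketch, with a placeholder connection string and the nested return shape made explicit:

```python
import os

from mem0.vector_stores.supabase import Supabase

# Placeholder credentials; assumes a reachable Postgres instance with pgvector
store = Supabase(
    connection_string=os.environ["SUPABASE_CONNECTION_STRING"],
    collection_name="memories",
    embedding_model_dims=1536,
)

rows = store.list(filters={"user_id": "alice"}, limit=10)

# list() wraps its results in an outer list, so iterate rows[0]
for row in rows[0]:
    print(row.id, row.payload)
```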
diff --git a/tests/vector_stores/test_supabase.py b/tests/vector_stores/test_supabase.py
new file mode 100644
index 00000000..bbdb468f
--- /dev/null
+++ b/tests/vector_stores/test_supabase.py
@@ -0,0 +1,178 @@
+from unittest.mock import Mock, patch
+
+import pytest
+
+from mem0.configs.vector_stores.supabase import IndexMeasure, IndexMethod
+from mem0.vector_stores.supabase import Supabase
+
+
+@pytest.fixture
+def mock_vecs_client():
+    with patch("vecs.create_client") as mock_client:
+        yield mock_client
+
+
+@pytest.fixture
+def mock_collection():
+    collection = Mock()
+    collection.name = "test_collection"
+    collection.vectors = 100
+    collection.dimension = 1536
+    collection.index_method = "hnsw"
+    collection.distance_metric = "cosine_distance"
+    collection.describe.return_value = collection
+    return collection
+
+
+@pytest.fixture
+def supabase_instance(mock_vecs_client, mock_collection):
+    # Set up the mock client to return our mock collection
+    mock_vecs_client.return_value.get_or_create_collection.return_value = mock_collection
+    mock_vecs_client.return_value.list_collections.return_value = ["test_collection"]
+
+    instance = Supabase(
+        connection_string="postgresql://user:password@localhost:5432/test",
+        collection_name="test_collection",
+        embedding_model_dims=1536,
+        index_method=IndexMethod.HNSW,
+        index_measure=IndexMeasure.COSINE,
+    )
+
+    # Pin the collection attribute to our mock collection
+    instance.collection = mock_collection
+    return instance
+
+
+def test_create_col(supabase_instance, mock_vecs_client, mock_collection):
+    supabase_instance.create_col(1536)
+
+    mock_vecs_client.return_value.get_or_create_collection.assert_called_with(name="test_collection", dimension=1536)
+    mock_collection.create_index.assert_called_with(method="hnsw", measure="cosine_distance")
+
+
+def test_insert_vectors(supabase_instance, mock_collection):
+    vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+    payloads = [{"name": "vector1"}, {"name": "vector2"}]
+    ids = ["id1", "id2"]
+
+    supabase_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
+
+    expected_records = [
+        ("id1", [0.1, 0.2, 0.3], {"name": "vector1"}),
+        ("id2", [0.4, 0.5, 0.6], {"name": "vector2"}),
+    ]
+    mock_collection.upsert.assert_called_once_with(expected_records)
+
+
+def test_search_vectors(supabase_instance, mock_collection):
+    mock_results = [
+        ("id1", 0.9, {"name": "vector1"}),
+        ("id2", 0.8, {"name": "vector2"}),
+    ]
+    mock_collection.query.return_value = mock_results
+
+    query = [0.1, 0.2, 0.3]
+    filters = {"category": "test"}
+    results = supabase_instance.search(query=query, limit=2, filters=filters)
+
+    mock_collection.query.assert_called_once_with(
+        data=query,
+        limit=2,
+        filters={"category": {"$eq": "test"}},
+        include_metadata=True,
+        include_value=True,
+    )
+
+    assert len(results) == 2
+    assert results[0].id == "id1"
+    assert results[0].score == 0.9
+    assert results[0].payload == {"name": "vector1"}
+
+
+def test_delete_vector(supabase_instance, mock_collection):
+    vector_id = "id1"
+    supabase_instance.delete(vector_id=vector_id)
+
+    mock_collection.delete.assert_called_once_with(["id1"])
+
+
+def test_update_vector(supabase_instance, mock_collection):
+    vector_id = "id1"
+    new_vector = [0.7, 0.8, 0.9]
+    new_payload = {"name": "updated_vector"}
+
+    supabase_instance.update(vector_id=vector_id, vector=new_vector, payload=new_payload)
+    mock_collection.upsert.assert_called_once_with([("id1", new_vector, new_payload)])
+
+
+def test_get_vector(supabase_instance, mock_collection):
+    # vecs fetch returns (id, vector, metadata) tuples
+    mock_collection.fetch.return_value = [("id1", [0.1, 0.2, 0.3], {"name": "vector1"})]
+
+    result = supabase_instance.get(vector_id="id1")
+
+    mock_collection.fetch.assert_called_once_with(["id1"])
+    assert result.id == "id1"
+    assert result.payload == {"name": "vector1"}
+
+
+def test_list_vectors(supabase_instance, mock_collection):
+    mock_query_results = [("id1", 0.9, {}), ("id2", 0.8, {})]
+    mock_fetch_results = [
+        ("id1", [0.1, 0.2, 0.3], {"name": "vector1"}),
+        ("id2", [0.4, 0.5, 0.6], {"name": "vector2"}),
+    ]
+
+    mock_collection.query.return_value = mock_query_results
+    mock_collection.fetch.return_value = mock_fetch_results
+
+    results = supabase_instance.list(limit=2, filters={"category": "test"})
+
+    assert len(results[0]) == 2
+    assert results[0][0].id == "id1"
+    assert results[0][0].payload == {"name": "vector1"}
+    assert results[0][1].id == "id2"
+    assert results[0][1].payload == {"name": "vector2"}
+
+
+def test_col_info(supabase_instance, mock_collection):
+    info = supabase_instance.col_info()
+
+    assert info == {
+        "name": "test_collection",
+        "count": 100,
+        "dimension": 1536,
+        "index": {"method": "hnsw", "metric": "cosine_distance"},
+    }
+
+
+def test_preprocess_filters(supabase_instance):
+    # Test a single filter
+    single_filter = {"category": "test"}
+    assert supabase_instance._preprocess_filters(single_filter) == {"category": {"$eq": "test"}}
+
+    # Test multiple filters
+    multi_filter = {"category": "test", "type": "document"}
+    assert supabase_instance._preprocess_filters(multi_filter) == {
+        "$and": [
+            {"category": {"$eq": "test"}},
+            {"type": {"$eq": "document"}},
+        ]
+    }
+
+    # Test None filters
+    assert supabase_instance._preprocess_filters(None) is None
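The unit tests above run entirely against mocks. For a manual smoke test of the same call paths against a real database, a sketch along these lines should work, assuming `SUPABASE_CONNECTION_STRING` points at a Postgres instance with pgvector enabled; the collection name, ids, and payloads are illustrative only:

```python
import os

from mem0.vector_stores.supabase import Supabase

store = Supabase(
    connection_string=os.environ["SUPABASE_CONNECTION_STRING"],
    collection_name="smoke_test",
    embedding_model_dims=3,  # tiny dimension, demo only
)

# Insert two records, then exercise search, update, get, and delete
store.insert(
    vectors=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
    payloads=[{"name": "vector1"}, {"name": "vector2"}],
    ids=["id1", "id2"],
)

hits = store.search(query=[0.1, 0.2, 0.3], limit=1, filters={"name": "vector1"})
assert hits and hits[0].id == "id1"

store.update("id1", payload={"name": "renamed"})
assert store.get("id1").payload == {"name": "renamed"}

store.delete("id2")
store.delete_col()  # clean up the demo collection
```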