Files
t6_mem0/mem0/vector_stores/supabase.py
2025-05-22 01:17:29 +05:30

238 lines
8.2 KiB
Python

import logging
import uuid
from typing import List, Optional
from pydantic import BaseModel
try:
import vecs
except ImportError:
raise ImportError("The 'vecs' library is required. Please install it using 'pip install vecs'.")
from mem0.configs.vector_stores.supabase import IndexMeasure, IndexMethod
from mem0.vector_stores.base import VectorStoreBase
logger = logging.getLogger(__name__)
class OutputData(BaseModel):
    """Uniform result record returned by the vector store's read operations.

    Every field defaults to ``None`` so a record can be built from partial
    data; without explicit defaults, ``Optional[...]`` fields are still
    *required* under pydantic v2.
    """

    # Identifier of the stored vector.
    id: Optional[str] = None
    # Score reported by a similarity search; None for plain lookups/listings.
    score: Optional[float] = None
    # Metadata stored alongside the vector.
    payload: Optional[dict] = None
class Supabase(VectorStoreBase):
    """Vector store backed by a ``vecs`` collection (Supabase / PostgreSQL)."""

    def __init__(
        self,
        connection_string: str,
        collection_name: str,
        embedding_model_dims: int,
        index_method: IndexMethod = IndexMethod.AUTO,
        index_measure: IndexMeasure = IndexMeasure.COSINE,
    ):
        """
        Initialize the Supabase vector store using vecs.

        Args:
            connection_string (str): PostgreSQL connection string
            collection_name (str): Collection name
            embedding_model_dims (int): Dimension of the embedding vector
            index_method (IndexMethod): Index method to use. Defaults to AUTO.
            index_measure (IndexMeasure): Distance measure to use. Defaults to COSINE.
        """
        self.db = vecs.create_client(connection_string)
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.index_method = index_method
        self.index_measure = index_measure

        if collection_name in self.list_cols():
            # create_col() is the only other place that binds self.collection,
            # so the "already exists" path has to bind it as well; otherwise
            # every later operation would fail with AttributeError.
            self.collection = self.db.get_or_create_collection(
                name=collection_name, dimension=embedding_model_dims
            )
        else:
            self.create_col(embedding_model_dims)

    def _preprocess_filters(self, filters: Optional[dict] = None) -> Optional[dict]:
        """
        Preprocess filters to be compatible with vecs.

        Args:
            filters (Dict, optional): Mapping of metadata key to required value.
                Multiple filters are combined with AND logic.

        Returns:
            Optional[dict]: ``None`` when there is nothing to filter on, a single
            ``{key: {"$eq": value}}`` clause for one filter, or an ``$and``
            clause wrapping one equality clause per key.
        """
        # An empty mapping is treated like "no filter" rather than being turned
        # into an empty {"$and": []} clause.
        if not filters:
            return None

        clauses = [{key: {"$eq": value}} for key, value in filters.items()]
        if len(clauses) == 1:
            # For a single filter, keep the simple format.
            return clauses[0]
        return {"$and": clauses}

    def create_col(self, embedding_model_dims: Optional[int] = None) -> None:
        """
        Create a new collection with vector support.

        Will also initialize the vector search index.

        Args:
            embedding_model_dims (int, optional): Dimension of the embedding vector.
                If not provided, uses the dimension specified in initialization.

        Raises:
            ValueError: If no dimension is available from either source.
            Exception: Any error raised by vecs is logged and re-raised unchanged.
        """
        dims = embedding_model_dims or self.embedding_model_dims
        if not dims:
            raise ValueError(
                "embedding_model_dims must be provided either during initialization or when creating collection"
            )

        logger.info(f"Creating new collection: {self.collection_name}")
        try:
            self.collection = self.db.get_or_create_collection(name=self.collection_name, dimension=dims)
            self.collection.create_index(method=self.index_method.value, measure=self.index_measure.value)
            logger.info(f"Successfully created collection {self.collection_name} with dimension {dims}")
        except Exception as e:
            logger.error(f"Failed to create collection: {str(e)}")
            raise

    def insert(
        self, vectors: List[List[float]], payloads: Optional[List[dict]] = None, ids: Optional[List[str]] = None
    ):
        """
        Insert vectors into the collection.

        Args:
            vectors (List[List[float]]): List of vectors to insert
            payloads (List[Dict], optional): List of payloads corresponding to vectors.
                Defaults to an empty dict per vector.
            ids (List[str], optional): List of IDs corresponding to vectors.
                Defaults to a fresh UUID4 string per vector.

        Raises:
            ValueError: If ``ids`` or ``payloads`` is given with a length that
                differs from ``vectors``.
        """
        logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")

        if not ids:
            ids = [str(uuid.uuid4()) for _ in vectors]
        if not payloads:
            payloads = [{} for _ in vectors]

        # zip() would silently drop the surplus entries of the longer inputs,
        # losing data without any signal, so reject mismatched lengths up front.
        if not (len(ids) == len(vectors) == len(payloads)):
            raise ValueError(
                f"Length mismatch: got {len(vectors)} vectors, {len(ids)} ids and {len(payloads)} payloads"
            )

        records = [
            (record_id, vector, payload) for record_id, vector, payload in zip(ids, vectors, payloads)
        ]
        self.collection.upsert(records)

    def search(
        self, query: str, vectors: List[float], limit: int = 5, filters: Optional[dict] = None
    ) -> List[OutputData]:
        """
        Search for similar vectors.

        Args:
            query (str): Query text. Not used by this backend; only ``vectors``
                is sent to the collection.
            vectors (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            List[OutputData]: Search results. ``score`` is the value vecs returns
            for ``include_value=True`` (presumably the distance under the
            configured measure - confirm against the vecs documentation).
        """
        filters = self._preprocess_filters(filters)
        results = self.collection.query(
            data=vectors, limit=limit, filters=filters, include_metadata=True, include_value=True
        )
        # Each result is indexed positionally as (id, value, metadata).
        return [OutputData(id=str(result[0]), score=float(result[1]), payload=result[2]) for result in results]

    def delete(self, vector_id: str):
        """
        Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete
        """
        # Pass a plain list of id strings, the same form list() hands to fetch().
        self.collection.delete([vector_id])

    def update(self, vector_id: str, vector: Optional[List[float]] = None, payload: Optional[dict] = None):
        """
        Update a vector and/or its payload.

        Whichever of ``vector`` / ``payload`` is omitted is taken from the
        currently stored record, so a partial update does not discard the
        other half.

        Args:
            vector_id (str): ID of the vector to update
            vector (List[float], optional): Updated vector
            payload (Dict, optional): Updated payload
        """
        if vector is None or payload is None:
            rows = self.collection.fetch([vector_id])
            # Fetched records are indexed positionally as (id, vector, metadata).
            existing = rows[0] if rows else None
            if vector is None and existing is not None:
                vector = existing[1]
            if payload is None:
                payload = (existing[2] if existing is not None else None) or {}

        # Explicit None/length checks instead of truthiness: a stored vector may
        # come back as an array-like whose truth value is ambiguous.
        if vector is None or len(vector) == 0:
            logger.warning(f"Skipping update for {vector_id}: no vector supplied and none stored")
            return

        self.collection.upsert([(vector_id, vector, payload)])

    def get(self, vector_id: str) -> Optional[OutputData]:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve

        Returns:
            Optional[OutputData]: Retrieved vector data or None if not found
        """
        result = self.collection.fetch([vector_id])
        if not result:
            # Honour the Optional[OutputData] contract (None, not an empty list).
            return None

        record = result[0]
        # Positional access (id, vector, metadata), matching list().
        return OutputData(id=str(record[0]), score=None, payload=record[2])

    def list_cols(self) -> List[str]:
        """
        List all collections.

        Returns:
            List[str]: List of collection names. Entries returned by the client
            that expose a ``name`` attribute are reduced to that name; plain
            strings pass through unchanged.
        """
        return [getattr(collection, "name", collection) for collection in self.db.list_collections()]

    def delete_col(self):
        """Delete the collection."""
        self.db.delete_collection(self.collection_name)

    def col_info(self) -> dict:
        """
        Get information about the collection.

        Returns:
            Dict: Collection information including name and configuration
        """
        # NOTE(review): relies on the collection object exposing describe() with
        # name/vectors/dimension/index_method/distance_metric attributes -
        # confirm against the installed vecs version.
        info = self.collection.describe()
        return {
            "name": info.name,
            "count": info.vectors,
            "dimension": info.dimension,
            "index": {"method": info.index_method, "metric": info.distance_metric},
        }

    def list(self, filters: Optional[dict] = None, limit: int = 100) -> List[List[OutputData]]:
        """
        List vectors in the collection.

        Args:
            filters (Dict, optional): Filters to apply
            limit (int, optional): Maximum number of results to return. Defaults to 100.

        Returns:
            List[List[OutputData]]: A single-element outer list whose only item
            is the list of records (presumably so callers can index ``[0]`` -
            verify against callers before flattening).
        """
        filters = self._preprocess_filters(filters)

        # A zero vector is used purely as a probe to enumerate ids; the
        # similarity value is not requested (include_value=False).
        probe = [0] * self.embedding_model_dims
        matches = self.collection.query(
            data=probe, limit=limit, filters=filters, include_metadata=True, include_value=False
        )
        record_ids = [match[0] for match in matches]
        if not record_ids:
            # Nothing matched: skip the second round trip, keep the nested shape.
            return [[]]

        records = self.collection.fetch(ids=record_ids)
        return [[OutputData(id=str(record[0]), score=None, payload=record[2]) for record in records]]

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.collection_name}...")
        self.delete_col()
        self.create_col(self.embedding_model_dims)