From 91abc03880b7fbbeb5c52efb97755178c436d409 Mon Sep 17 00:00:01 2001
From: ytkimirti
Date: Wed, 9 Apr 2025 07:36:07 +0300
Subject: [PATCH] Add Upstash Vector support (#2493)

---
 Makefile                                     |   3 +-
 docs/components/vectordbs/config.mdx         |   2 +-
 .../vectordbs/dbs/upstash-vector.mdx         |  70 ++++
 docs/components/vectordbs/overview.mdx       |   1 +
 mem0/configs/vector_stores/upstash_vector.py |  36 ++
 mem0/embeddings/mock.py                      |  11 +
 mem0/memory/main.py                          |  48 ++-
 mem0/utils/factory.py                        |   7 +-
 mem0/vector_stores/configs.py                |   3 +-
 mem0/vector_stores/upstash_vector.py         | 287 +++++++++++++
 tests/vector_stores/test_upstash_vector.py   | 384 ++++++++++++++++++
 11 files changed, 840 insertions(+), 12 deletions(-)
 create mode 100644 docs/components/vectordbs/dbs/upstash-vector.mdx
 create mode 100644 mem0/configs/vector_stores/upstash_vector.py
 create mode 100644 mem0/embeddings/mock.py
 create mode 100644 mem0/vector_stores/upstash_vector.py
 create mode 100644 tests/vector_stores/test_upstash_vector.py

diff --git a/Makefile b/Makefile
index aeac2993..c2189736 100644
--- a/Makefile
+++ b/Makefile
@@ -13,7 +13,8 @@ install:
 install_all:
 	poetry install
 	poetry run pip install groq together boto3 litellm ollama chromadb weaviate weaviate-client sentence_transformers vertexai \
-	google-generativeai elasticsearch opensearch-py vecs pinecone pinecone-text faiss-cpu langchain-community
+	google-generativeai elasticsearch opensearch-py vecs pinecone pinecone-text faiss-cpu langchain-community \
+	upstash-vector
 
 # Format code with ruff
 format:
diff --git a/docs/components/vectordbs/config.mdx b/docs/components/vectordbs/config.mdx
index 4ce16fae..a36e5579 100644
--- a/docs/components/vectordbs/config.mdx
+++ b/docs/components/vectordbs/config.mdx
@@ -8,7 +8,7 @@ iconType: "solid"
 The `config` is defined as an object with two main keys:
 
 - `vector_store`: Specifies the vector database provider and its configuration
-  - `provider`: The name of the vector database (e.g., "chroma", "pgvector", "qdrant", "milvus","azure_ai_search", "vertex_ai_vector_search")
+  - `provider`: The name of the vector database (e.g., "chroma", "pgvector", "qdrant", "milvus", "upstash_vector", "azure_ai_search", "vertex_ai_vector_search")
   - `config`: A nested dictionary containing provider-specific settings
diff --git a/docs/components/vectordbs/dbs/upstash-vector.mdx b/docs/components/vectordbs/dbs/upstash-vector.mdx
new file mode 100644
index 00000000..c4536d90
--- /dev/null
+++ b/docs/components/vectordbs/dbs/upstash-vector.mdx
@@ -0,0 +1,70 @@
+[Upstash Vector](https://upstash.com/docs/vector) is a serverless vector database with built-in embedding models. The integration requires the `upstash-vector` package (`pip install upstash-vector`).
+
+### Usage with Upstash embeddings
+
+Set `enable_embeddings` to `True` to use Upstash's built-in embedding models: mem0 then sends raw text and Upstash vectorizes it for you, so no external embedder is needed.
+
+```python
+import os
+from mem0 import Memory
+
+os.environ["UPSTASH_VECTOR_REST_URL"] = "..."
+os.environ["UPSTASH_VECTOR_REST_TOKEN"] = "..."
+
+config = {
+    "vector_store": {
+        "provider": "upstash_vector",
+        "config": {
+            "enable_embeddings": True,
+        },
+    }
+}
+
+m = Memory.from_config(config)
+m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
+```
+
+<Note>
+  Setting `enable_embeddings` to `True` will bypass any external embedding provider you have configured.
+</Note>
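+
+With `enable_embeddings` turned on, search queries are embedded by Upstash as well, so retrieval goes through the standard mem0 API unchanged. A minimal sketch (illustrative query):
+
+```python
+related_memories = m.search("What does Alice do on weekends?", user_id="alice")
+```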
+
+### Usage with external embedding providers
+
+```python
+import os
+from mem0 import Memory
+
+os.environ["OPENAI_API_KEY"] = "..."
+os.environ["UPSTASH_VECTOR_REST_URL"] = "..."
+os.environ["UPSTASH_VECTOR_REST_TOKEN"] = "..."
+
+config = {
+    "vector_store": {
+        "provider": "upstash_vector",
+    },
+    "embedder": {
+        "provider": "openai",
+        "config": {
+            "model": "text-embedding-3-large"
+        },
+    }
+}
+
+m = Memory.from_config(config)
+m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
+```
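+
+Listing and deleting memories likewise work as with any other vector store backend. A brief sketch, reusing the `m` instance from above:
+
+```python
+all_memories = m.get_all(user_id="alice")
+m.delete_all(user_id="alice")  # removes Alice's memories from the namespace
+```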
+
+### Config
+
+Here are the parameters available for configuring Upstash Vector:
+
+| Parameter           | Description                        | Default Value |
+| ------------------- | ---------------------------------- | ------------- |
+| `url`               | URL for the Upstash Vector index   | `None`        |
+| `token`             | Token for the Upstash Vector index | `None`        |
+| `client`            | An `upstash_vector.Index` instance | `None`        |
+| `collection_name`   | The default namespace used         | `"mem0"`      |
+| `enable_embeddings` | Whether to use Upstash embeddings  | `False`       |
+
+<Note>
+  When `url` and `token` are not provided, the `UPSTASH_VECTOR_REST_URL` and
+  `UPSTASH_VECTOR_REST_TOKEN` environment variables are used.
+</Note>
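+
+Credentials and the namespace can also be passed explicitly instead of through environment variables (placeholder values shown):
+
+```python
+config = {
+    "vector_store": {
+        "provider": "upstash_vector",
+        "config": {
+            "url": "https://<your-index>.upstash.io",
+            "token": "<UPSTASH_VECTOR_REST_TOKEN>",
+            "collection_name": "alice-memories",
+        },
+    }
+}
+
+m = Memory.from_config(config)
+```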
diff --git a/docs/components/vectordbs/overview.mdx b/docs/components/vectordbs/overview.mdx
index b9b451b5..3e60d34d 100644
--- a/docs/components/vectordbs/overview.mdx
+++ b/docs/components/vectordbs/overview.mdx
@@ -18,6 +18,7 @@ See the list of supported vector databases below.
+
diff --git a/mem0/configs/vector_stores/upstash_vector.py b/mem0/configs/vector_stores/upstash_vector.py
new file mode 100644
index 00000000..b7c4d14f
--- /dev/null
+++ b/mem0/configs/vector_stores/upstash_vector.py
@@ -0,0 +1,36 @@
+import os
+from typing import Any, ClassVar, Dict, Optional
+
+from pydantic import BaseModel, Field, model_validator
+
+try:
+    from upstash_vector import Index
+except ImportError:
+    raise ImportError("The 'upstash_vector' library is required. Please install it using 'pip install upstash_vector'.")
+
+
+class UpstashVectorConfig(BaseModel):
+    Index: ClassVar[type] = Index
+
+    url: Optional[str] = Field(None, description="URL for Upstash Vector index")
+    token: Optional[str] = Field(None, description="Token for Upstash Vector index")
+    client: Optional[Index] = Field(None, description="Existing `upstash_vector.Index` client instance")
+    collection_name: str = Field("mem0", description="Namespace to use for the index")
+    enable_embeddings: bool = Field(
+        False, description="Whether to use the built-in Upstash embeddings. Defaults to False."
+    )
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_credentials_or_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        client = values.get("client")
+        url = values.get("url") or os.environ.get("UPSTASH_VECTOR_REST_URL")
+        token = values.get("token") or os.environ.get("UPSTASH_VECTOR_REST_TOKEN")
+
+        if not client and not (url and token):
+            raise ValueError("Either a client or URL and token must be provided.")
+        return values
+
+    model_config = {
+        "arbitrary_types_allowed": True,
+    }
diff --git a/mem0/embeddings/mock.py b/mem0/embeddings/mock.py
new file mode 100644
index 00000000..0e411d79
--- /dev/null
+++ b/mem0/embeddings/mock.py
@@ -0,0 +1,11 @@
+from typing import Literal, Optional
+
+from mem0.embeddings.base import EmbeddingBase
+
+
+class MockEmbeddings(EmbeddingBase):
+    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
+        """
+        Return a fixed 10-dimensional embedding. Used as a stand-in when the
+        vector store (e.g., Upstash Vector) performs embedding server-side.
+        """
+        return [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
diff --git a/mem0/memory/main.py b/mem0/memory/main.py
index 81710c29..73801995 100644
--- a/mem0/memory/main.py
+++ b/mem0/memory/main.py
@@ -12,12 +12,20 @@ from pydantic import ValidationError
 
 from mem0.configs.base import MemoryConfig, MemoryItem
 from mem0.configs.enums import MemoryType
-from mem0.configs.prompts import PROCEDURAL_MEMORY_SYSTEM_PROMPT, get_update_memory_messages
+from mem0.configs.prompts import (
+    PROCEDURAL_MEMORY_SYSTEM_PROMPT,
+    get_update_memory_messages,
+)
 from mem0.memory.base import MemoryBase
 from mem0.memory.setup import setup_config
 from mem0.memory.storage import SQLiteManager
 from mem0.memory.telemetry import capture_event
-from mem0.memory.utils import get_fact_retrieval_messages, parse_messages, parse_vision_messages, remove_code_blocks
+from mem0.memory.utils import (
+    get_fact_retrieval_messages,
+    parse_messages,
+    parse_vision_messages,
+    remove_code_blocks,
+)
 from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
 
 # Setup user config
@@ -32,7 +40,11 @@ class Memory(MemoryBase):
         self.custom_fact_extraction_prompt = self.config.custom_fact_extraction_prompt
         self.custom_update_memory_prompt = self.config.custom_update_memory_prompt
-        self.embedding_model = EmbedderFactory.create(self.config.embedder.provider, self.config.embedder.config)
+        self.embedding_model = EmbedderFactory.create(
+            self.config.embedder.provider,
+            self.config.embedder.config,
+            self.config.vector_store.config,
+        )
         self.vector_store = VectorStoreFactory.create(
             self.config.vector_store.provider, self.config.vector_store.config
         )
@@ -260,7 +272,9 @@ class Memory(MemoryBase):
                     continue
                 elif resp.get("event") == "ADD":
                     memory_id = self._create_memory(
-                        data=resp.get("text"), existing_embeddings=new_message_embeddings, metadata=metadata
+                        data=resp.get("text"),
+                        existing_embeddings=new_message_embeddings,
+                        metadata=metadata,
                     )
                     returned_memories.append(
                         {
@@ -300,7 +314,11 @@ class Memory(MemoryBase):
         except Exception as e:
             logging.error(f"Error in new_memories_with_actions: {e}")
 
-        capture_event("mem0.add", self, {"version": self.api_version, "keys": list(filters.keys())})
+        capture_event(
+            "mem0.add",
+            self,
+            {"version": self.api_version, "keys": list(filters.keys())},
+        )
 
         return returned_memories
 
@@ -342,7 +360,16 @@ class Memory(MemoryBase):
         ).model_dump(exclude={"score"})
 
         # Add metadata if there are additional keys
-        excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at", "id"}
+        excluded_keys = {
+            "user_id",
+            "agent_id",
+            "run_id",
+            "hash",
+            "data",
+            "created_at",
+            "updated_at",
+            "id",
+        }
         additional_metadata = {k: v for k, v in memory.payload.items() if k not in excluded_keys}
         if additional_metadata:
             memory_item["metadata"] = additional_metadata
@@ -643,7 +670,10 @@ class Memory(MemoryBase):
         parsed_messages = [
             {"role": "system", "content": prompt or PROCEDURAL_MEMORY_SYSTEM_PROMPT},
             *messages,
-            {"role": "user", "content": "Create procedural memory of the above conversation."},
+            {
+                "role": "user",
+                "content": "Create procedural memory of the above conversation.",
+            },
         ]
 
         try:
diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py
index 9e10dca1..4126cc15 100644
--- a/mem0/utils/factory.py
+++ b/mem0/utils/factory.py
@@ -1,7 +1,9 @@
 import importlib
+from typing import Optional
 
 from mem0.configs.embeddings.base import BaseEmbedderConfig
 from mem0.configs.llms.base import BaseLlmConfig
+from mem0.embeddings.mock import MockEmbeddings
 
 
 def load_class(class_type):
@@ -54,7 +56,9 @@ class EmbedderFactory:
     }
 
     @classmethod
-    def create(cls, provider_name, config):
+    def create(cls, provider_name, config, vector_config: Optional[dict] = None):
+        # When the vector store embeds server-side (Upstash Vector with
+        # `enable_embeddings=True`), skip the external embedder entirely.
+        if vector_config is not None and getattr(vector_config, "enable_embeddings", False):
+            return MockEmbeddings()
         class_type = cls.provider_to_class.get(provider_name)
         if class_type:
             embedder_instance = load_class(class_type)
@@ -70,6 +74,7 @@ class VectorStoreFactory:
         "chroma": "mem0.vector_stores.chroma.ChromaDB",
         "pgvector": "mem0.vector_stores.pgvector.PGVector",
         "milvus": "mem0.vector_stores.milvus.MilvusDB",
+        "upstash_vector": "mem0.vector_stores.upstash_vector.UpstashVector",
         "azure_ai_search": "mem0.vector_stores.azure_ai_search.AzureAISearch",
         "pinecone": "mem0.vector_stores.pinecone.PineconeDB",
         "redis": "mem0.vector_stores.redis.RedisDB",
diff --git a/mem0/vector_stores/configs.py b/mem0/vector_stores/configs.py
index 649dc2d5..0d1f9f7c 100644
--- a/mem0/vector_stores/configs.py
+++ b/mem0/vector_stores/configs.py
@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, model_validator
 
 class VectorStoreConfig(BaseModel):
     provider: str = Field(
-        description="Provider of the vector store (e.g., 'qdrant', 'chroma')",
+        description="Provider of the vector store (e.g., 'qdrant', 'chroma', 'upstash_vector')",
         default="qdrant",
     )
     config: Optional[Dict] = Field(description="Configuration for the specific vector store", default=None)
@@ -16,6 +16,7 @@ class VectorStoreConfig(BaseModel):
         "pgvector": "PGVectorConfig",
         "pinecone": "PineconeConfig",
         "milvus": "MilvusDBConfig",
+        "upstash_vector": "UpstashVectorConfig",
         "azure_ai_search": "AzureAISearchConfig",
         "redis": "RedisDBConfig",
         "elasticsearch": "ElasticsearchConfig",
diff --git a/mem0/vector_stores/upstash_vector.py b/mem0/vector_stores/upstash_vector.py
new file mode 100644
index 00000000..66a010db
--- /dev/null
+++ b/mem0/vector_stores/upstash_vector.py
@@ -0,0 +1,287 @@
+import logging
+from typing import Dict, List, Optional
+
+from pydantic import BaseModel
+
+from mem0.vector_stores.base import VectorStoreBase
+
+try:
+    from upstash_vector import Index
+except ImportError:
+    raise ImportError("The 'upstash_vector' library is required. Please install it using 'pip install upstash_vector'.")
+
+
+logger = logging.getLogger(__name__)
+
+
+class OutputData(BaseModel):
+    id: Optional[str]  # memory id
+    score: Optional[float]  # None for `get` results
+    payload: Optional[Dict]  # metadata
+
+
+class UpstashVector(VectorStoreBase):
+    def __init__(
+        self,
+        collection_name: str,
+        url: Optional[str] = None,
+        token: Optional[str] = None,
+        client: Optional[Index] = None,
+        enable_embeddings: bool = False,
+    ):
+        """
+        Initialize the UpstashVector vector store.
+
+        Args:
+            collection_name (str): Default namespace to use for the index.
+            url (str, optional): URL for the Upstash Vector index. Defaults to None.
+            token (str, optional): Token for the Upstash Vector index. Defaults to None.
+            client (Index, optional): Existing `upstash_vector.Index` client instance. Defaults to None.
+            enable_embeddings (bool, optional): Whether to use Upstash's built-in embedding models. Defaults to False.
+        """
+        if client:
+            self.client = client
+        elif url and token:
+            self.client = Index(url, token)
+        else:
+            raise ValueError("Either a client or URL and token must be provided.")
+
+        self.collection_name = collection_name
+        self.enable_embeddings = enable_embeddings
+
+    def insert(
+        self,
+        vectors: List[list],
+        payloads: Optional[List[Dict]] = None,
+        ids: Optional[List[str]] = None,
+    ):
+        """
+        Insert vectors into the namespace.
+
+        Args:
+            vectors (list): List of vectors to insert.
+            payloads (list, optional): List of payloads corresponding to vectors. These are passed as metadata to the Upstash Vector client. Defaults to None.
+            ids (list, optional): List of IDs corresponding to vectors. Defaults to None.
+        """
+        logger.info(f"Inserting {len(vectors)} vectors into namespace {self.collection_name}")
+
+        if self.enable_embeddings:
+            if not payloads or any("data" not in m or m["data"] is None for m in payloads):
+                raise ValueError("When embeddings are enabled, all payloads must contain a 'data' field.")
+            # Let Upstash embed server-side: send the raw text as `data` instead of a vector.
+            processed_vectors = [
+                {
+                    "id": ids[i] if ids else None,
+                    "data": payloads[i]["data"],
+                    "metadata": payloads[i],
+                }
+                for i, v in enumerate(vectors)
+            ]
+        else:
+            processed_vectors = [
+                {
+                    "id": ids[i] if ids else None,
+                    "vector": vectors[i],
+                    "metadata": payloads[i] if payloads else None,
+                }
+                for i, v in enumerate(vectors)
+            ]
+
+        self.client.upsert(
+            vectors=processed_vectors,
+            namespace=self.collection_name,
+        )
+
+    def _stringify(self, x):
+        # Quote strings for the Upstash filter DSL, e.g. name = "John".
+        return f'"{x}"' if isinstance(x, str) else x
+
+    def search(
+        self,
+        query: str,
+        vectors: List[list],
+        limit: int = 5,
+        filters: Optional[Dict] = None,
+    ) -> List[OutputData]:
+        """
+        Search for similar vectors.
+
+        Args:
+            query (str): Query text, used when the built-in embeddings are enabled.
+            vectors (list): Query vectors, used when embeddings are handled externally.
+            limit (int, optional): Number of results to return. Defaults to 5.
+            filters (dict, optional): Filters to apply to the search. Defaults to None.
+
+        Returns:
+            List[OutputData]: Search results.
+        """
+        filters_str = " AND ".join([f"{k} = {self._stringify(v)}" for k, v in filters.items()]) if filters else None
+
+        response = []
+
+        if self.enable_embeddings:
+            response = self.client.query(
+                data=query,
+                top_k=limit,
+                filter=filters_str or "",
+                include_metadata=True,
+                namespace=self.collection_name,
+            )
+        else:
+            queries = [
+                {
+                    "vector": v,
+                    "top_k": limit,
+                    "filter": filters_str or "",
+                    "include_metadata": True,
+                    "namespace": self.collection_name,
+                }
+                for v in vectors
+            ]
+            responses = self.client.query_many(queries=queries)
+            # Flatten the per-query result lists into one list.
+            response = [res for res_list in responses for res in res_list]
+
+        return [
+            OutputData(
+                id=res.id,
+                score=res.score,
+                payload=res.metadata,
+            )
+            for res in response
+        ]
+
+    def delete(self, vector_id: int):
+        """
+        Delete a vector by ID.
+
+        Args:
+            vector_id (int): ID of the vector to delete.
+        """
+        self.client.delete(
+            ids=[str(vector_id)],
+            namespace=self.collection_name,
+        )
+
+    def update(
+        self,
+        vector_id: int,
+        vector: Optional[list] = None,
+        payload: Optional[dict] = None,
+    ):
+        """
+        Update a vector and its payload.
+
+        Args:
+            vector_id (int): ID of the vector to update.
+            vector (list, optional): Updated vector. Defaults to None.
+            payload (dict, optional): Updated payload. Defaults to None.
+        """
+        self.client.update(
+            id=str(vector_id),
+            vector=vector,
+            data=payload.get("data") if payload else None,
+            metadata=payload,
+            namespace=self.collection_name,
+        )
+
+    def get(self, vector_id: int) -> Optional[OutputData]:
+        """
+        Retrieve a vector by ID.
+
+        Args:
+            vector_id (int): ID of the vector to retrieve.
+
+        Returns:
+            OutputData: Retrieved vector, or None if it does not exist.
+        """
+        response = self.client.fetch(
+            ids=[str(vector_id)],
+            namespace=self.collection_name,
+            include_metadata=True,
+        )
+        if len(response) == 0:
+            return None
+        vector = response[0]
+        if not vector:
+            return None
+        return OutputData(id=vector.id, score=None, payload=vector.metadata)
+
+    def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[List[OutputData]]:
+        """
+        List all memories in the namespace.
+
+        Args:
+            filters (dict, optional): Filters to apply to the search. Defaults to None.
+            limit (int, optional): Number of results to return. Defaults to 100.
+
+        Returns:
+            List[List[OutputData]]: A single-element list containing the results.
+        """
+        filters_str = " AND ".join([f"{k} = {self._stringify(v)}" for k, v in filters.items()]) if filters else None
+
+        info = self.client.info()
+        ns_info = info.namespaces.get(self.collection_name)
+
+        if not ns_info or ns_info.vector_count == 0:
+            return [[]]
+
+        # Upstash has no scan API, so query with a constant vector and page
+        # through the matches with a resumable query.
+        random_vector = [1.0] * info.dimension
+
+        results, query = self.client.resumable_query(
+            vector=random_vector,
+            filter=filters_str or "",
+            include_metadata=True,
+            namespace=self.collection_name,
+            top_k=100,
+        )
+        with query:
+            while len(results) < limit:
+                res = query.fetch_next(100)
+                if not res:
+                    break
+                results.extend(res)
+
+        parsed_result = [
+            OutputData(
+                id=res.id,
+                score=res.score,
+                payload=res.metadata,
+            )
+            for res in results
+        ]
+        return [parsed_result]
+
+    def create_col(self, name, vector_size, distance):
+        """
+        Upstash Vector uses namespaces instead of collections, and a namespace
+        is created automatically when the first vector is inserted.
+
+        This method is a placeholder to satisfy the `VectorStoreBase` interface.
+        """
+        pass
+
+    def list_cols(self) -> List[str]:
+        """
+        List all namespaces in the Upstash Vector index.
+
+        Returns:
+            List[str]: List of namespaces.
+        """
+        return self.client.list_namespaces()
+
+    def delete_col(self):
+        """
+        Delete the namespace and all vectors in it.
+        """
+        self.client.reset(namespace=self.collection_name)
+
+    def col_info(self):
+        """
+        Return general information about the Upstash Vector index:
+
+        - Total number of vectors across all namespaces
+        - Total number of vectors waiting to be indexed across all namespaces
+        - Total size of the index on disk in bytes
+        - Vector dimension
+        - Similarity function used
+        - Per-namespace vector and pending vector counts
+        """
+        return self.client.info()
diff --git a/tests/vector_stores/test_upstash_vector.py b/tests/vector_stores/test_upstash_vector.py
new file mode 100644
index 00000000..e5a38846
--- /dev/null
+++ b/tests/vector_stores/test_upstash_vector.py
@@ -0,0 +1,384 @@
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+from unittest.mock import MagicMock, call, patch
+
+import pytest
+
+from mem0.vector_stores.upstash_vector import UpstashVector
+
+
+@dataclass
+class QueryResult:
+    id: str
+    score: Optional[float]
+    vector: Optional[List[float]] = None
+    metadata: Optional[Dict] = None
+    data: Optional[str] = None
+
+
+@pytest.fixture
+def mock_index():
+    with patch("upstash_vector.Index") as mock_index:
+        yield mock_index
+
+
+@pytest.fixture
+def upstash_instance(mock_index):
+    return UpstashVector(client=mock_index.return_value, collection_name="ns")
+
+
+@pytest.fixture
+def upstash_instance_with_embeddings(mock_index):
+    return UpstashVector(
+        client=mock_index.return_value, collection_name="ns", enable_embeddings=True
+    )
+
+
+def test_insert_vectors(upstash_instance, mock_index):
+    vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+    payloads = [{"name": "vector1"}, {"name": "vector2"}]
+    ids = ["id1", "id2"]
+
+    upstash_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
+
+    upstash_instance.client.upsert.assert_called_once_with(
+        vectors=[
+            {"id": "id1", "vector": [0.1, 0.2, 0.3], "metadata": {"name": "vector1"}},
+            {"id": "id2", "vector": [0.4, 0.5, 0.6], "metadata": {"name": "vector2"}},
+        ],
+        namespace="ns",
+    )
+
+
+def test_search_vectors(upstash_instance, mock_index):
+    mock_result = [
+        QueryResult(
+            id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None
+        ),
+        QueryResult(
+            id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data=None
+        ),
+    ]
+
+    upstash_instance.client.query_many.return_value = [mock_result]
+
+    vectors = [[0.1, 0.2, 0.3]]
+    results = upstash_instance.search(
+        query="hello world",
+        vectors=vectors,
+        limit=2,
+        filters={"age": 30, "name": "John"},
+    )
+
+    upstash_instance.client.query_many.assert_called_once_with(
+        queries=[
+            {
+                "vector": vectors[0],
+                "top_k": 2,
+                "namespace": "ns",
+                "include_metadata": True,
+                "filter": 'age = 30 AND name = "John"',
+            }
+        ]
+    )
+
+    assert len(results) == 2
+    assert results[0].id == "id1"
+    assert results[0].score == 0.1
+    assert results[0].payload == {"name": "vector1"}
+
+
+def test_delete_vector(upstash_instance):
+    vector_id = "id1"
+
+    upstash_instance.delete(vector_id=vector_id)
+
+    upstash_instance.client.delete.assert_called_once_with(
+        ids=[vector_id], namespace="ns"
+    )
+
+
+def test_update_vector(upstash_instance):
+    vector_id = "id1"
+    new_vector = [0.7, 0.8, 0.9]
+    new_payload = {"name": "updated_vector"}
+
+    upstash_instance.update(vector_id=vector_id, vector=new_vector, payload=new_payload)
+
+    upstash_instance.client.update.assert_called_once_with(
+        id="id1",
+        vector=new_vector,
+        data=None,
+        metadata={"name": "updated_vector"},
+        namespace="ns",
+    )
+
+
+def test_get_vector(upstash_instance):
+    mock_result = [
+        QueryResult(
+            id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None
+        )
+    ]
+    upstash_instance.client.fetch.return_value = mock_result
+
+    result = upstash_instance.get(vector_id="id1")
+
+    upstash_instance.client.fetch.assert_called_once_with(
+        ids=["id1"], namespace="ns", include_metadata=True
+    )
+
+    assert result.id == "id1"
+    assert result.payload == {"name": "vector1"}
+
+
+def test_list_vectors(upstash_instance):
+    mock_result = [
+        QueryResult(
+            id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None
+        ),
+        QueryResult(
+            id="id2", score=None, vector=None, metadata={"name": "vector2"}, data=None
+        ),
+        QueryResult(
+            id="id3", score=None, vector=None, metadata={"name": "vector3"}, data=None
+        ),
+    ]
+    handler = MagicMock()
+
+    upstash_instance.client.info.return_value.dimension = 10
+    upstash_instance.client.resumable_query.return_value = (mock_result[0:1], handler)
+    handler.fetch_next.side_effect = [mock_result[1:2], mock_result[2:3], []]
+
+    filters = {"age": 30, "name": "John"}
+    [results] = upstash_instance.list(filters=filters, limit=15)
+
+    upstash_instance.client.resumable_query.assert_called_once_with(
+        vector=[1.0] * 10,
+        filter='age = 30 AND name = "John"',
+        include_metadata=True,
+        namespace="ns",
+        top_k=100,
+    )
+
+    handler.fetch_next.assert_has_calls([call(100), call(100), call(100)])
+    handler.__exit__.assert_called_once()
+
+    assert len(results) == len(mock_result)
+    assert results[0].id == "id1"
+    assert results[0].payload == {"name": "vector1"}
+
+
+def test_insert_vectors_with_embeddings(upstash_instance_with_embeddings, mock_index):
+    vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+    payloads = [
+        {"name": "vector1", "data": "data1"},
+        {"name": "vector2", "data": "data2"},
+    ]
+    ids = ["id1", "id2"]
+
+    upstash_instance_with_embeddings.insert(vectors=vectors, payloads=payloads, ids=ids)
+
+    upstash_instance_with_embeddings.client.upsert.assert_called_once_with(
+        vectors=[
+            {
+                "id": "id1",
+                # Sends the raw text in `data` instead of a vector
+                "data": "data1",
+                "metadata": {"name": "vector1", "data": "data1"},
+            },
+            {
+                "id": "id2",
+                "data": "data2",
+                "metadata": {"name": "vector2", "data": "data2"},
+            },
+        ],
+        namespace="ns",
+    )
+
+
+def test_search_vectors_with_embeddings(upstash_instance_with_embeddings, mock_index):
+    mock_result = [
+        QueryResult(
+            id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data="data1"
+        ),
+        QueryResult(
+            id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data="data2"
+        ),
+    ]
+
+    upstash_instance_with_embeddings.client.query.return_value = mock_result
+
+    results = upstash_instance_with_embeddings.search(
+        query="hello world",
+        vectors=[],
+        limit=2,
+        filters={"age": 30, "name": "John"},
+    )
+
+    upstash_instance_with_embeddings.client.query.assert_called_once_with(
+        # Queries with the raw text instead of a vector
+        data="hello world",
+        top_k=2,
+        filter='age = 30 AND name = "John"',
+        include_metadata=True,
+        namespace="ns",
+    )
+
+    assert len(results) == 2
+    assert results[0].id == "id1"
+    assert results[0].score == 0.1
+    assert results[0].payload == {"name": "vector1"}
+
+
+def test_update_vector_with_embeddings(upstash_instance_with_embeddings):
+    vector_id = "id1"
+    new_payload = {"name": "updated_vector", "data": "updated_data"}
+
+    upstash_instance_with_embeddings.update(vector_id=vector_id, payload=new_payload)
+
+    upstash_instance_with_embeddings.client.update.assert_called_once_with(
+        id="id1",
+        vector=None,
+        data="updated_data",
+        metadata={"name": "updated_vector", "data": "updated_data"},
+        namespace="ns",
+    )
+
+
+def test_insert_vectors_with_embeddings_missing_data(upstash_instance_with_embeddings):
+    vectors = [[0.1, 0.2, 0.3]]
+    payloads = [{"name": "vector1"}]  # Missing "data" field
+    ids = ["id1"]
+
+    with pytest.raises(
+        ValueError,
+        match="When embeddings are enabled, all payloads must contain a 'data' field",
+    ):
+        upstash_instance_with_embeddings.insert(
+            vectors=vectors, payloads=payloads, ids=ids
+        )
+
+
+def test_update_vector_with_embeddings_missing_data(upstash_instance_with_embeddings):
+    # Should still work; `data` is not required for updates
+    vector_id = "id1"
+    new_payload = {"name": "updated_vector"}  # Missing "data" field
+
+    upstash_instance_with_embeddings.update(vector_id=vector_id, payload=new_payload)
+
+    upstash_instance_with_embeddings.client.update.assert_called_once_with(
+        id="id1",
+        vector=None,
+        data=None,
+        metadata={"name": "updated_vector"},
+        namespace="ns",
+    )
+
+
+def test_list_cols(upstash_instance):
+    mock_namespaces = ["ns1", "ns2", "ns3"]
+    upstash_instance.client.list_namespaces.return_value = mock_namespaces
+
+    result = upstash_instance.list_cols()
+
+    upstash_instance.client.list_namespaces.assert_called_once()
+    assert result == mock_namespaces
+
+
+def test_delete_col(upstash_instance):
+    upstash_instance.delete_col()
+    upstash_instance.client.reset.assert_called_once_with(namespace="ns")
+
+
+def test_col_info(upstash_instance):
+    mock_info = {
+        "dimension": 10,
+        "total_vectors": 100,
+        "pending_vectors": 0,
+        "disk_size": 1024,
+    }
+    upstash_instance.client.info.return_value = mock_info
+
+    result = upstash_instance.col_info()
+
+    upstash_instance.client.info.assert_called_once()
+    assert result == mock_info
+
+
+def test_get_vector_not_found(upstash_instance):
+    upstash_instance.client.fetch.return_value = []
+
+    result = upstash_instance.get(vector_id="nonexistent")
+
+    upstash_instance.client.fetch.assert_called_once_with(
+        ids=["nonexistent"], namespace="ns", include_metadata=True
+    )
+    assert result is None
+
+
+def test_search_vectors_empty_filters(upstash_instance):
+    mock_result = [
+        QueryResult(
+            id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None
+        )
+    ]
+    upstash_instance.client.query_many.return_value = [mock_result]
+
+    vectors = [[0.1, 0.2, 0.3]]
+    results = upstash_instance.search(
+        query="hello world",
+        vectors=vectors,
+        limit=1,
+        filters=None,
+    )
+
+    upstash_instance.client.query_many.assert_called_once_with(
+        queries=[
+            {
+                "vector": vectors[0],
+                "top_k": 1,
+                "namespace": "ns",
+                "include_metadata": True,
+                "filter": "",
+            }
+        ]
+    )
+
+    assert len(results) == 1
+    assert results[0].id == "id1"
+
+
+def test_insert_vectors_no_payloads(upstash_instance):
+    vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+    ids = ["id1", "id2"]
+
+    upstash_instance.insert(vectors=vectors, ids=ids)
+
+    upstash_instance.client.upsert.assert_called_once_with(
+        vectors=[
+            {"id": "id1", "vector": [0.1, 0.2, 0.3], "metadata": None},
+            {"id": "id2", "vector": [0.4, 0.5, 0.6], "metadata": None},
+        ],
+        namespace="ns",
+    )
+
+
+def test_insert_vectors_no_ids(upstash_instance):
+    vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+    payloads = [{"name": "vector1"}, {"name": "vector2"}]
+
+    upstash_instance.insert(vectors=vectors, payloads=payloads)
+
+    upstash_instance.client.upsert.assert_called_once_with(
+        vectors=[
+            {"id": None, "vector": [0.1, 0.2, 0.3], "metadata": {"name": "vector1"}},
+            {"id": None, "vector": [0.4, 0.5, 0.6], "metadata": {"name": "vector2"}},
+        ],
+        namespace="ns",
+    )
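+
+
+def test_init_requires_credentials_or_client():
+    # Added sketch, not part of the original suite: per UpstashVector.__init__,
+    # constructing the store without a client or url/token should raise.
+    with pytest.raises(ValueError, match="Either a client or URL and token"):
+        UpstashVector(collection_name="ns")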