diff --git a/Makefile b/Makefile
index aeac2993..c2189736 100644
--- a/Makefile
+++ b/Makefile
@@ -13,7 +13,8 @@ install:
install_all:
poetry install
poetry run pip install groq together boto3 litellm ollama chromadb weaviate weaviate-client sentence_transformers vertexai \
- google-generativeai elasticsearch opensearch-py vecs pinecone pinecone-text faiss-cpu langchain-community
+ google-generativeai elasticsearch opensearch-py vecs pinecone pinecone-text faiss-cpu langchain-community \
+ upstash-vector
# Format code with ruff
format:
diff --git a/docs/components/vectordbs/config.mdx b/docs/components/vectordbs/config.mdx
index 4ce16fae..a36e5579 100644
--- a/docs/components/vectordbs/config.mdx
+++ b/docs/components/vectordbs/config.mdx
@@ -8,7 +8,7 @@ iconType: "solid"
The `config` is defined as an object with two main keys:
- `vector_store`: Specifies the vector database provider and its configuration
- - `provider`: The name of the vector database (e.g., "chroma", "pgvector", "qdrant", "milvus","azure_ai_search", "vertex_ai_vector_search")
+ - `provider`: The name of the vector database (e.g., "chroma", "pgvector", "qdrant", "milvus", "upstash_vector", "azure_ai_search", "vertex_ai_vector_search")
- `config`: A nested dictionary containing provider-specific settings
diff --git a/docs/components/vectordbs/dbs/upstash-vector.mdx b/docs/components/vectordbs/dbs/upstash-vector.mdx
new file mode 100644
index 00000000..c4536d90
--- /dev/null
+++ b/docs/components/vectordbs/dbs/upstash-vector.mdx
@@ -0,0 +1,70 @@
+[Upstash Vector](https://upstash.com/docs/vector) is a serverless vector database with built-in embedding models.
+
+### Usage with Upstash embeddings
+
+You can enable the built-in embedding models by setting `enable_embeddings` to `True`. This allows you to use Upstash's embedding models for vectorization.
+
+```python
+import os
+from mem0 import Memory
+
+os.environ["UPSTASH_VECTOR_REST_URL"] = "..."
+os.environ["UPSTASH_VECTOR_REST_TOKEN"] = "..."
+
+config = {
+    "vector_store": {
+        "provider": "upstash_vector",
+        "config": {
+            "enable_embeddings": True,
+        },
+    }
+}
+
+m = Memory.from_config(config)
+m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
+```
+
+
+ Setting `enable_embeddings` to `True` will bypass any external embedding provider you have configured.
+
+
+### Usage with external embedding providers
+
+```python
+import os
+from mem0 import Memory
+
+os.environ["OPENAI_API_KEY"] = "..."
+os.environ["UPSTASH_VECTOR_REST_URL"] = "..."
+os.environ["UPSTASH_VECTOR_REST_TOKEN"] = "..."
+
+config = {
+ "vector_store": {
+ "provider": "upstash_vector",
+ },
+ "embedder": {
+ "provider": "openai",
+ "config": {
+ "model": "text-embedding-3-large"
+ },
+ }
+}
+
+m = Memory.from_config(config)
+m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
+```
+
+### Config
+
+Here are the parameters available for configuring Upstash Vector:
+
+| Parameter | Description | Default Value |
+| ------------------- | ---------------------------------- | ------------- |
+| `url` | URL for the Upstash Vector index | `None` |
+| `token` | Token for the Upstash Vector index | `None` |
+| `client` | An `upstash_vector.Index` instance | `None` |
+| `collection_name`   | The default namespace used         | `"mem0"`      |
+| `enable_embeddings` | Whether to use Upstash embeddings | `False` |
+
+
+ When `url` and `token` are not provided, the `UPSTASH_VECTOR_REST_URL` and
+ `UPSTASH_VECTOR_REST_TOKEN` environment variables are used.
+
diff --git a/docs/components/vectordbs/overview.mdx b/docs/components/vectordbs/overview.mdx
index b9b451b5..3e60d34d 100644
--- a/docs/components/vectordbs/overview.mdx
+++ b/docs/components/vectordbs/overview.mdx
@@ -18,6 +18,7 @@ See the list of supported vector databases below.
+
diff --git a/mem0/configs/vector_stores/upstash_vector.py b/mem0/configs/vector_stores/upstash_vector.py
new file mode 100644
index 00000000..b7c4d14f
--- /dev/null
+++ b/mem0/configs/vector_stores/upstash_vector.py
@@ -0,0 +1,36 @@
+import os
+from typing import Any, ClassVar, Dict, Optional
+
+from pydantic import BaseModel, Field, model_validator
+
+try:
+ from upstash_vector import Index
+except ImportError:
+ raise ImportError("The 'upstash_vector' library is required. Please install it using 'pip install upstash_vector'.")
+
+
+class UpstashVectorConfig(BaseModel):
+    """
+    Configuration for the Upstash Vector store.
+
+    Either an existing `upstash_vector.Index` client or a URL/token pair must be
+    available. A missing `url` or `token` falls back to the
+    `UPSTASH_VECTOR_REST_URL` / `UPSTASH_VECTOR_REST_TOKEN` environment variables.
+    """
+
+    # Class-level handle to the client type; ClassVar keeps pydantic from treating it as a field.
+    Index: ClassVar[type] = Index
+
+    url: Optional[str] = Field(None, description="URL for Upstash Vector index")
+    token: Optional[str] = Field(None, description="Token for Upstash Vector index")
+    client: Optional[Index] = Field(None, description="Existing `upstash_vector.Index` client instance")
+    collection_name: str = Field("mem0", description="Namespace to use for the index")
+    enable_embeddings: bool = Field(
+        False, description="Whether to use built-in upstash embeddings or not. Default is False."
+    )
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_credentials_or_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Ensure a client or a complete URL/token pair is available.
+
+        Args:
+            values (Dict[str, Any]): Raw input values, before field validation.
+
+        Returns:
+            Dict[str, Any]: The input values; when no client is given, `url` and
+            `token` are filled in with the resolved credentials.
+
+        Raises:
+            ValueError: If there is no client and the URL or the token is missing.
+        """
+        client = values.get("client")
+        url = values.get("url") or os.environ.get("UPSTASH_VECTOR_REST_URL")
+        token = values.get("token") or os.environ.get("UPSTASH_VECTOR_REST_TOKEN")
+
+        if client:
+            return values
+        if not (url and token):
+            raise ValueError("Either a client or URL and token must be provided.")
+
+        # Write the resolved credentials back. Otherwise values that came only from
+        # the environment pass this check but leave `url`/`token` as None, and
+        # UpstashVector.__init__ (which never reads the environment) rejects them.
+        return {**values, "url": url, "token": token}
+
+    model_config = {
+        "arbitrary_types_allowed": True,
+    }
diff --git a/mem0/embeddings/mock.py b/mem0/embeddings/mock.py
new file mode 100644
index 00000000..0e411d79
--- /dev/null
+++ b/mem0/embeddings/mock.py
@@ -0,0 +1,11 @@
+from typing import Literal, Optional
+
+from mem0.embeddings.base import EmbeddingBase
+
+
+class MockEmbeddings(EmbeddingBase):
+    """Embedder stand-in that returns the same fixed vector for every input."""
+
+    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
+        """
+        Return a constant placeholder embedding with dimension of 10.
+
+        Args:
+            text: Ignored.
+            memory_action (str, optional): Ignored. Defaults to None.
+
+        Returns:
+            list: The values 0.1 through 1.0 in steps of 0.1.
+        """
+        return [step / 10 for step in range(1, 11)]
diff --git a/mem0/memory/main.py b/mem0/memory/main.py
index 81710c29..73801995 100644
--- a/mem0/memory/main.py
+++ b/mem0/memory/main.py
@@ -12,12 +12,20 @@ from pydantic import ValidationError
from mem0.configs.base import MemoryConfig, MemoryItem
from mem0.configs.enums import MemoryType
-from mem0.configs.prompts import PROCEDURAL_MEMORY_SYSTEM_PROMPT, get_update_memory_messages
+from mem0.configs.prompts import (
+ PROCEDURAL_MEMORY_SYSTEM_PROMPT,
+ get_update_memory_messages,
+)
from mem0.memory.base import MemoryBase
from mem0.memory.setup import setup_config
from mem0.memory.storage import SQLiteManager
from mem0.memory.telemetry import capture_event
-from mem0.memory.utils import get_fact_retrieval_messages, parse_messages, parse_vision_messages, remove_code_blocks
+from mem0.memory.utils import (
+ get_fact_retrieval_messages,
+ parse_messages,
+ parse_vision_messages,
+ remove_code_blocks,
+)
from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
# Setup user config
@@ -32,7 +40,11 @@ class Memory(MemoryBase):
self.custom_fact_extraction_prompt = self.config.custom_fact_extraction_prompt
self.custom_update_memory_prompt = self.config.custom_update_memory_prompt
- self.embedding_model = EmbedderFactory.create(self.config.embedder.provider, self.config.embedder.config)
+ self.embedding_model = EmbedderFactory.create(
+ self.config.embedder.provider,
+ self.config.embedder.config,
+ self.config.vector_store.config,
+ )
self.vector_store = VectorStoreFactory.create(
self.config.vector_store.provider, self.config.vector_store.config
)
@@ -260,7 +272,9 @@ class Memory(MemoryBase):
continue
elif resp.get("event") == "ADD":
memory_id = self._create_memory(
- data=resp.get("text"), existing_embeddings=new_message_embeddings, metadata=metadata
+ data=resp.get("text"),
+ existing_embeddings=new_message_embeddings,
+ metadata=metadata,
)
returned_memories.append(
{
@@ -300,7 +314,11 @@ class Memory(MemoryBase):
except Exception as e:
logging.error(f"Error in new_memories_with_actions: {e}")
- capture_event("mem0.add", self, {"version": self.api_version, "keys": list(filters.keys())})
+ capture_event(
+ "mem0.add",
+ self,
+ {"version": self.api_version, "keys": list(filters.keys())},
+ )
return returned_memories
@@ -342,7 +360,16 @@ class Memory(MemoryBase):
).model_dump(exclude={"score"})
# Add metadata if there are additional keys
- excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at", "id"}
+ excluded_keys = {
+ "user_id",
+ "agent_id",
+ "run_id",
+ "hash",
+ "data",
+ "created_at",
+ "updated_at",
+ "id",
+ }
additional_metadata = {k: v for k, v in memory.payload.items() if k not in excluded_keys}
if additional_metadata:
memory_item["metadata"] = additional_metadata
@@ -631,7 +658,7 @@ class Memory(MemoryBase):
prompt (str, optional): Prompt to use for the procedural memory creation. Defaults to None.
"""
try:
- from langchain_core.messages.utils import convert_to_messages # type: ignore
+ pass
except Exception:
logger.error(
"Import error while loading langchain-core. Please install 'langchain-core' to use procedural memory."
@@ -643,7 +670,10 @@ class Memory(MemoryBase):
parsed_messages = [
{"role": "system", "content": prompt or PROCEDURAL_MEMORY_SYSTEM_PROMPT},
*messages,
- {"role": "user", "content": "Create procedural memory of the above conversation."},
+ {
+ "role": "user",
+ "content": "Create procedural memory of the above conversation.",
+ },
]
try:
@@ -728,7 +758,9 @@ class Memory(MemoryBase):
self.vector_store = VectorStoreFactory.create(
self.config.vector_store.provider, self.config.vector_store.config
)
+ print("before dbreset")
self.db.reset()
+ print("after dbreset")
capture_event("mem0.reset", self)
def chat(self, query):
diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py
index 9e10dca1..4126cc15 100644
--- a/mem0/utils/factory.py
+++ b/mem0/utils/factory.py
@@ -1,7 +1,9 @@
import importlib
+from typing import Optional
from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.configs.llms.base import BaseLlmConfig
+from mem0.embeddings.mock import MockEmbeddings
def load_class(class_type):
@@ -54,7 +56,9 @@ class EmbedderFactory:
}
@classmethod
- def create(cls, provider_name, config):
+ def create(cls, provider_name, config, vector_config: Optional[dict]):
+ if provider_name == "upstash_vector" and vector_config and vector_config.enable_embeddings:
+ return MockEmbeddings()
class_type = cls.provider_to_class.get(provider_name)
if class_type:
embedder_instance = load_class(class_type)
@@ -70,6 +74,7 @@ class VectorStoreFactory:
"chroma": "mem0.vector_stores.chroma.ChromaDB",
"pgvector": "mem0.vector_stores.pgvector.PGVector",
"milvus": "mem0.vector_stores.milvus.MilvusDB",
+ "upstash_vector": "mem0.vector_stores.upstash_vector.UpstashVector",
"azure_ai_search": "mem0.vector_stores.azure_ai_search.AzureAISearch",
"pinecone": "mem0.vector_stores.pinecone.PineconeDB",
"redis": "mem0.vector_stores.redis.RedisDB",
diff --git a/mem0/vector_stores/configs.py b/mem0/vector_stores/configs.py
index 649dc2d5..0d1f9f7c 100644
--- a/mem0/vector_stores/configs.py
+++ b/mem0/vector_stores/configs.py
@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, model_validator
class VectorStoreConfig(BaseModel):
provider: str = Field(
- description="Provider of the vector store (e.g., 'qdrant', 'chroma')",
+ description="Provider of the vector store (e.g., 'qdrant', 'chroma', 'upstash_vector')",
default="qdrant",
)
config: Optional[Dict] = Field(description="Configuration for the specific vector store", default=None)
@@ -16,6 +16,7 @@ class VectorStoreConfig(BaseModel):
"pgvector": "PGVectorConfig",
"pinecone": "PineconeConfig",
"milvus": "MilvusDBConfig",
+ "upstash_vector": "UpstashVectorConfig",
"azure_ai_search": "AzureAISearchConfig",
"redis": "RedisDBConfig",
"elasticsearch": "ElasticsearchConfig",
diff --git a/mem0/vector_stores/upstash_vector.py b/mem0/vector_stores/upstash_vector.py
new file mode 100644
index 00000000..66a010db
--- /dev/null
+++ b/mem0/vector_stores/upstash_vector.py
@@ -0,0 +1,287 @@
+import logging
+from typing import Dict, List, Optional
+
+from pydantic import BaseModel
+
+from mem0.vector_stores.base import VectorStoreBase
+
+try:
+ from upstash_vector import Index
+except ImportError:
+ raise ImportError("The 'upstash_vector' library is required. Please install it using 'pip install upstash_vector'.")
+
+
+logger = logging.getLogger(__name__)
+
+
+class OutputData(BaseModel):
+    """Result record returned by the `UpstashVector` search, get and list methods."""
+
+    id: Optional[str]  # memory id
+    score: Optional[float]  # similarity score; is None for `get` method
+    payload: Optional[Dict]  # metadata stored alongside the vector
+
+
+class UpstashVector(VectorStoreBase):
+    def __init__(
+        self,
+        collection_name: str,
+        url: Optional[str] = None,
+        token: Optional[str] = None,
+        client: Optional[Index] = None,
+        enable_embeddings: bool = False,
+    ):
+        """
+        Initialize the UpstashVector vector store.
+
+        Args:
+            collection_name (str): Namespace passed to every client call made by this store.
+            url (str, optional): URL for Upstash Vector index. Defaults to None.
+            token (str, optional): Token for Upstash Vector index. Defaults to None.
+            client (Index, optional): Existing `upstash_vector.Index` client instance.
+                Takes precedence over `url`/`token`. Defaults to None.
+            enable_embeddings (bool, optional): When True, `insert` and `search` send
+                text (`data`) to the index instead of vectors. Defaults to False.
+
+        Raises:
+            ValueError: If no client is given and the URL or the token is missing.
+        """
+        if not client:
+            if not (url and token):
+                raise ValueError("Either a client or URL and token must be provided.")
+            client = Index(url, token)
+
+        self.client = client
+        self.collection_name = collection_name
+        self.enable_embeddings = enable_embeddings
+
+    def insert(
+        self,
+        vectors: List[list],
+        payloads: Optional[List[Dict]] = None,
+        ids: Optional[List[str]] = None,
+    ):
+        """
+        Upsert one record per entry of `vectors` into the store's namespace.
+
+        Args:
+            vectors (list): List of vectors to insert. When embeddings are enabled only
+                its length is used; the `data` field of each payload is sent instead.
+            payloads (list, optional): List of payloads corresponding to vectors. These
+                are passed as metadata to the Upstash Vector client. Defaults to None.
+            ids (list, optional): List of IDs corresponding to vectors. Defaults to None.
+
+        Raises:
+            ValueError: If embeddings are enabled and any payload lacks a non-None 'data' field.
+        """
+        logger.info(f"Inserting {len(vectors)} vectors into namespace {self.collection_name}")
+
+        if self.enable_embeddings and (
+            not payloads or any("data" not in item or item["data"] is None for item in payloads)
+        ):
+            raise ValueError("When embeddings are enabled, all payloads must contain a 'data' field.")
+
+        records = []
+        for position, embedding in enumerate(vectors):
+            record = {"id": ids[position] if ids else None}
+            if self.enable_embeddings:
+                # The index embeds the text itself, so the vector is not sent.
+                record["data"] = payloads[position]["data"]
+                record["metadata"] = payloads[position]
+            else:
+                record["vector"] = embedding
+                record["metadata"] = payloads[position] if payloads else None
+            records.append(record)
+
+        self.client.upsert(
+            vectors=records,
+            namespace=self.collection_name,
+        )
+
+    def _stringify(self, x):
+        """Wrap a string in double quotes for use in a filter expression; return any other value unchanged."""
+        if isinstance(x, str):
+            return f'"{x}"'
+        return x
+
+    def search(
+        self,
+        query: str,
+        vectors: List[list],
+        limit: int = 5,
+        filters: Optional[Dict] = None,
+    ) -> List[OutputData]:
+        """
+        Search for similar vectors.
+
+        Args:
+            query (str): Query text. Only used when embeddings are enabled.
+            vectors (list): Query vectors. Only used when embeddings are disabled;
+                one query is issued per vector and the results are concatenated.
+            limit (int, optional): Number of results to return per query. Defaults to 5.
+            filters (Dict, optional): Equality filters, combined with AND.
+
+        Returns:
+            List[OutputData]: Search results.
+        """
+        # An empty string is what the client receives when there is nothing to filter on.
+        filter_expr = ""
+        if filters:
+            filter_expr = " AND ".join(f"{key} = {self._stringify(value)}" for key, value in filters.items())
+
+        if self.enable_embeddings:
+            matches = self.client.query(
+                data=query,
+                top_k=limit,
+                filter=filter_expr,
+                include_metadata=True,
+                namespace=self.collection_name,
+            )
+        else:
+            per_vector_hits = self.client.query_many(
+                queries=[
+                    {
+                        "vector": vector,
+                        "top_k": limit,
+                        "filter": filter_expr,
+                        "include_metadata": True,
+                        "namespace": self.collection_name,
+                    }
+                    for vector in vectors
+                ]
+            )
+            # query_many yields one result list per query; merge them into one flat list.
+            matches = [hit for hits in per_vector_hits for hit in hits]
+
+        return [OutputData(id=hit.id, score=hit.score, payload=hit.metadata) for hit in matches]
+
+    def delete(self, vector_id: int):
+        """
+        Remove a single vector from the store's namespace.
+
+        Args:
+            vector_id (int): ID of the vector to delete; converted with `str()` before the call.
+        """
+        target_ids = [str(vector_id)]
+        self.client.delete(ids=target_ids, namespace=self.collection_name)
+
+    def update(
+        self,
+        vector_id: int,
+        vector: Optional[list] = None,
+        payload: Optional[dict] = None,
+    ):
+        """
+        Update a vector and its payload.
+
+        Args:
+            vector_id (int): ID of the vector to update; converted with `str()` before the call.
+            vector (list, optional): Updated vector. Defaults to None.
+            payload (dict, optional): Updated payload. Sent as metadata; its "data"
+                entry, if any, is also sent as `data`. Defaults to None.
+        """
+        record_id = str(vector_id)
+        text = payload.get("data") if payload else None
+
+        self.client.update(
+            id=record_id,
+            vector=vector,
+            data=text,
+            metadata=payload,
+            namespace=self.collection_name,
+        )
+
+    def get(self, vector_id: int) -> Optional[OutputData]:
+        """
+        Retrieve a vector by ID.
+
+        Args:
+            vector_id (int): ID of the vector to retrieve; converted with `str()` before the call.
+
+        Returns:
+            Optional[OutputData]: The record with `score` set to None, or None when
+            the fetch returns nothing or a falsy entry for the ID.
+        """
+        fetched = self.client.fetch(
+            ids=[str(vector_id)],
+            namespace=self.collection_name,
+            include_metadata=True,
+        )
+        record = fetched[0] if len(fetched) != 0 else None
+        if not record:
+            return None
+        return OutputData(id=record.id, score=None, payload=record.metadata)
+
+    def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[List[OutputData]]:
+        """
+        List memories stored in the namespace.
+
+        Args:
+            filters (Dict, optional): Equality filters, combined with AND. Defaults to None.
+            limit (int, optional): Maximum number of results to return. Defaults to 100.
+
+        Returns:
+            List[List[OutputData]]: A single-element list wrapping the results
+            (at most `limit` of them); `[[]]` when the namespace is missing or empty.
+        """
+        filters_str = " AND ".join([f"{k} = {self._stringify(v)}" for k, v in filters.items()]) if filters else None
+
+        info = self.client.info()
+        ns_info = info.namespaces.get(self.collection_name)
+
+        if not ns_info or ns_info.vector_count == 0:
+            return [[]]
+
+        # resumable_query needs a query vector of the index dimension. A constant one is
+        # used because the goal is enumeration, not ranking. Reuse `info` instead of
+        # asking the index a second time.
+        probe_vector = [1.0] * info.dimension
+
+        results, query = self.client.resumable_query(
+            vector=probe_vector,
+            filter=filters_str or "",
+            include_metadata=True,
+            namespace=self.collection_name,
+            top_k=100,
+        )
+        with query:
+            while len(results) < limit:
+                res = query.fetch_next(100)
+                if not res:
+                    break
+                results.extend(res)
+
+        # Pages arrive 100 at a time, so trim the overshoot to honor `limit`.
+        parsed_result = [
+            OutputData(
+                id=res.id,
+                score=res.score,
+                payload=res.metadata,
+            )
+            for res in results[: max(limit, 0)]
+        ]
+        return [parsed_result]
+
+    def create_col(self, name, vector_size, distance):
+        """
+        No-op kept to satisfy the vector store interface.
+
+        Upstash Vector has namespaces instead of collections, and a namespace is
+        created when the first vector is inserted, so there is nothing to do here.
+        """
+        return None
+
+    def list_cols(self) -> List[str]:
+        """
+        Return the namespaces reported by the Upstash Vector client.
+
+        Returns:
+            List[str]: List of namespaces.
+        """
+        namespaces = self.client.list_namespaces()
+        return namespaces
+
+    def delete_col(self):
+        """
+        Reset the store's namespace through the client, removing the vectors in it.
+        """
+        namespace = self.collection_name
+        self.client.reset(namespace=namespace)
+
+    def col_info(self):
+        """
+        Return the client's `info()` result for the Upstash Vector index.
+
+        The information covers the whole index, not only this store's namespace:
+
+        - Total number of vectors across all namespaces
+        - Total number of vectors waiting to be indexed across all namespaces
+        - Total size of the index on disk in bytes
+        - Vector dimension
+        - Similarity function used
+        - Per-namespace vector and pending vector counts
+        """
+        index_info = self.client.info()
+        return index_info
diff --git a/tests/vector_stores/test_upstash_vector.py b/tests/vector_stores/test_upstash_vector.py
new file mode 100644
index 00000000..e5a38846
--- /dev/null
+++ b/tests/vector_stores/test_upstash_vector.py
@@ -0,0 +1,384 @@
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+from unittest.mock import MagicMock, call, patch
+
+import pytest
+
+from mem0.vector_stores.upstash_vector import UpstashVector
+
+
+@dataclass
+class QueryResult:
+ id: str
+ score: Optional[float]
+ vector: Optional[List[float]] = None
+ metadata: Optional[Dict] = None
+ data: Optional[str] = None
+
+
+@pytest.fixture
+def mock_index():
+ with patch("upstash_vector.Index") as mock_index:
+ yield mock_index
+
+
+@pytest.fixture
+def upstash_instance(mock_index):
+ return UpstashVector(client=mock_index.return_value, collection_name="ns")
+
+
+@pytest.fixture
+def upstash_instance_with_embeddings(mock_index):
+ return UpstashVector(
+ client=mock_index.return_value, collection_name="ns", enable_embeddings=True
+ )
+
+
+def test_insert_vectors(upstash_instance, mock_index):
+ vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+ payloads = [{"name": "vector1"}, {"name": "vector2"}]
+ ids = ["id1", "id2"]
+
+ upstash_instance.insert(vectors=vectors, payloads=payloads, ids=ids)
+
+ upstash_instance.client.upsert.assert_called_once_with(
+ vectors=[
+ {"id": "id1", "vector": [0.1, 0.2, 0.3], "metadata": {"name": "vector1"}},
+ {"id": "id2", "vector": [0.4, 0.5, 0.6], "metadata": {"name": "vector2"}},
+ ],
+ namespace="ns",
+ )
+
+
+def test_search_vectors(upstash_instance, mock_index):
+ mock_result = [
+ QueryResult(
+ id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None
+ ),
+ QueryResult(
+ id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data=None
+ ),
+ ]
+
+ upstash_instance.client.query_many.return_value = [mock_result]
+
+ vectors = [[0.1, 0.2, 0.3]]
+ results = upstash_instance.search(
+ query="hello world",
+ vectors=vectors,
+ limit=2,
+ filters={"age": 30, "name": "John"},
+ )
+
+ upstash_instance.client.query_many.assert_called_once_with(
+ queries=[
+ {
+ "vector": vectors[0],
+ "top_k": 2,
+ "namespace": "ns",
+ "include_metadata": True,
+ "filter": 'age = 30 AND name = "John"',
+ }
+ ]
+ )
+
+ assert len(results) == 2
+ assert results[0].id == "id1"
+ assert results[0].score == 0.1
+ assert results[0].payload == {"name": "vector1"}
+
+
+def test_delete_vector(upstash_instance):
+ vector_id = "id1"
+
+ upstash_instance.delete(vector_id=vector_id)
+
+ upstash_instance.client.delete.assert_called_once_with(
+ ids=[vector_id], namespace="ns"
+ )
+
+
+def test_update_vector(upstash_instance):
+ vector_id = "id1"
+ new_vector = [0.7, 0.8, 0.9]
+ new_payload = {"name": "updated_vector"}
+
+ upstash_instance.update(vector_id=vector_id, vector=new_vector, payload=new_payload)
+
+ upstash_instance.client.update.assert_called_once_with(
+ id="id1",
+ vector=new_vector,
+ data=None,
+ metadata={"name": "updated_vector"},
+ namespace="ns",
+ )
+
+
+def test_get_vector(upstash_instance):
+ mock_result = [
+ QueryResult(
+ id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None
+ )
+ ]
+ upstash_instance.client.fetch.return_value = mock_result
+
+ result = upstash_instance.get(vector_id="id1")
+
+ upstash_instance.client.fetch.assert_called_once_with(
+ ids=["id1"], namespace="ns", include_metadata=True
+ )
+
+ assert result.id == "id1"
+ assert result.payload == {"name": "vector1"}
+
+
+def test_list_vectors(upstash_instance):
+    """`list` pages through a resumable query and returns every fetched record."""
+    mock_result = [
+        QueryResult(
+            id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None
+        ),
+        QueryResult(
+            id="id2", score=None, vector=None, metadata={"name": "vector2"}, data=None
+        ),
+        QueryResult(
+            id="id3", score=None, vector=None, metadata={"name": "vector3"}, data=None
+        ),
+    ]
+    handler = MagicMock()
+
+    upstash_instance.client.info.return_value.dimension = 10
+    # The first page comes back with the query itself; later pages come from
+    # fetch_next, and an empty page ends the loop.
+    upstash_instance.client.resumable_query.return_value = (mock_result[0:1], handler)
+    handler.fetch_next.side_effect = [mock_result[1:2], mock_result[2:3], []]
+
+    filters = {"age": 30, "name": "John"}
+    [results] = upstash_instance.list(filters=filters, limit=15)
+
+    upstash_instance.client.resumable_query.assert_called_once_with(
+        vector=[1.0] * 10,
+        filter='age = 30 AND name = "John"',
+        include_metadata=True,
+        namespace="ns",
+        top_k=100,
+    )
+
+    handler.fetch_next.assert_has_calls([call(100), call(100), call(100)])
+    # The handler is used as a context manager and must be closed exactly once.
+    handler.__exit__.assert_called_once()
+
+    assert len(results) == len(mock_result)
+    assert results[0].id == "id1"
+    assert results[0].payload == {"name": "vector1"}
+
+
+def test_insert_vectors_with_embeddings(upstash_instance_with_embeddings, mock_index):
+ vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+ payloads = [
+ {"name": "vector1", "data": "data1"},
+ {"name": "vector2", "data": "data2"},
+ ]
+ ids = ["id1", "id2"]
+
+ upstash_instance_with_embeddings.insert(vectors=vectors, payloads=payloads, ids=ids)
+
+ upstash_instance_with_embeddings.client.upsert.assert_called_once_with(
+ vectors=[
+ {
+ "id": "id1",
+ # Uses the data field instead of using vectors
+ "data": "data1",
+ "metadata": {"name": "vector1", "data": "data1"},
+ },
+ {
+ "id": "id2",
+ "data": "data2",
+ "metadata": {"name": "vector2", "data": "data2"},
+ },
+ ],
+ namespace="ns",
+ )
+
+
+def test_search_vectors_with_embeddings(upstash_instance_with_embeddings, mock_index):
+ mock_result = [
+ QueryResult(
+ id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data="data1"
+ ),
+ QueryResult(
+ id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data="data2"
+ ),
+ ]
+
+ upstash_instance_with_embeddings.client.query.return_value = mock_result
+
+ results = upstash_instance_with_embeddings.search(
+ query="hello world",
+ vectors=[],
+ limit=2,
+ filters={"age": 30, "name": "John"},
+ )
+
+ upstash_instance_with_embeddings.client.query.assert_called_once_with(
+ # Uses the data field instead of using vectors
+ data="hello world",
+ top_k=2,
+ filter='age = 30 AND name = "John"',
+ include_metadata=True,
+ namespace="ns",
+ )
+
+ assert len(results) == 2
+ assert results[0].id == "id1"
+ assert results[0].score == 0.1
+ assert results[0].payload == {"name": "vector1"}
+
+
+def test_update_vector_with_embeddings(upstash_instance_with_embeddings):
+ vector_id = "id1"
+ new_payload = {"name": "updated_vector", "data": "updated_data"}
+
+ upstash_instance_with_embeddings.update(vector_id=vector_id, payload=new_payload)
+
+ upstash_instance_with_embeddings.client.update.assert_called_once_with(
+ id="id1",
+ vector=None,
+ data="updated_data",
+ metadata={"name": "updated_vector", "data": "updated_data"},
+ namespace="ns",
+ )
+
+
+def test_insert_vectors_with_embeddings_missing_data(upstash_instance_with_embeddings):
+ vectors = [[0.1, 0.2, 0.3]]
+ payloads = [{"name": "vector1"}] # Missing data field
+ ids = ["id1"]
+
+ with pytest.raises(
+ ValueError,
+ match="When embeddings are enabled, all payloads must contain a 'data' field",
+ ):
+ upstash_instance_with_embeddings.insert(
+ vectors=vectors, payloads=payloads, ids=ids
+ )
+
+
+def test_update_vector_with_embeddings_missing_data(upstash_instance_with_embeddings):
+ # Should still work, data is not required for update
+ vector_id = "id1"
+ new_payload = {"name": "updated_vector"} # Missing data field
+
+ upstash_instance_with_embeddings.update(vector_id=vector_id, payload=new_payload)
+
+ upstash_instance_with_embeddings.client.update.assert_called_once_with(
+ id="id1",
+ vector=None,
+ data=None,
+ metadata={"name": "updated_vector"},
+ namespace="ns",
+ )
+
+
+def test_list_cols(upstash_instance):
+ mock_namespaces = ["ns1", "ns2", "ns3"]
+ upstash_instance.client.list_namespaces.return_value = mock_namespaces
+
+ result = upstash_instance.list_cols()
+
+ upstash_instance.client.list_namespaces.assert_called_once()
+ assert result == mock_namespaces
+
+
+def test_delete_col(upstash_instance):
+ upstash_instance.delete_col()
+ upstash_instance.client.reset.assert_called_once_with(namespace="ns")
+
+
+def test_col_info(upstash_instance):
+ mock_info = {
+ "dimension": 10,
+ "total_vectors": 100,
+ "pending_vectors": 0,
+ "disk_size": 1024,
+ }
+ upstash_instance.client.info.return_value = mock_info
+
+ result = upstash_instance.col_info()
+
+ upstash_instance.client.info.assert_called_once()
+ assert result == mock_info
+
+
+def test_get_vector_not_found(upstash_instance):
+ upstash_instance.client.fetch.return_value = []
+
+ result = upstash_instance.get(vector_id="nonexistent")
+
+ upstash_instance.client.fetch.assert_called_once_with(
+ ids=["nonexistent"], namespace="ns", include_metadata=True
+ )
+ assert result is None
+
+
+def test_search_vectors_empty_filters(upstash_instance):
+ mock_result = [
+ QueryResult(
+ id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None
+ )
+ ]
+ upstash_instance.client.query_many.return_value = [mock_result]
+
+ vectors = [[0.1, 0.2, 0.3]]
+ results = upstash_instance.search(
+ query="hello world",
+ vectors=vectors,
+ limit=1,
+ filters=None,
+ )
+
+ upstash_instance.client.query_many.assert_called_once_with(
+ queries=[
+ {
+ "vector": vectors[0],
+ "top_k": 1,
+ "namespace": "ns",
+ "include_metadata": True,
+ "filter": "",
+ }
+ ]
+ )
+
+ assert len(results) == 1
+ assert results[0].id == "id1"
+
+
+def test_insert_vectors_no_payloads(upstash_instance):
+ vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+ ids = ["id1", "id2"]
+
+ upstash_instance.insert(vectors=vectors, ids=ids)
+
+ upstash_instance.client.upsert.assert_called_once_with(
+ vectors=[
+ {"id": "id1", "vector": [0.1, 0.2, 0.3], "metadata": None},
+ {"id": "id2", "vector": [0.4, 0.5, 0.6], "metadata": None},
+ ],
+ namespace="ns",
+ )
+
+
+def test_insert_vectors_no_ids(upstash_instance):
+ vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+ payloads = [{"name": "vector1"}, {"name": "vector2"}]
+
+ upstash_instance.insert(vectors=vectors, payloads=payloads)
+
+ upstash_instance.client.upsert.assert_called_once_with(
+ vectors=[
+ {"id": None, "vector": [0.1, 0.2, 0.3], "metadata": {"name": "vector1"}},
+ {"id": None, "vector": [0.4, 0.5, 0.6], "metadata": {"name": "vector2"}},
+ ],
+ namespace="ns",
+ )