Add Upstash Vector support (#2493)
This commit is contained in:
3
Makefile
3
Makefile
@@ -13,7 +13,8 @@ install:
|
|||||||
install_all:
|
install_all:
|
||||||
poetry install
|
poetry install
|
||||||
poetry run pip install groq together boto3 litellm ollama chromadb weaviate weaviate-client sentence_transformers vertexai \
|
poetry run pip install groq together boto3 litellm ollama chromadb weaviate weaviate-client sentence_transformers vertexai \
|
||||||
google-generativeai elasticsearch opensearch-py vecs pinecone pinecone-text faiss-cpu langchain-community
|
google-generativeai elasticsearch opensearch-py vecs pinecone pinecone-text faiss-cpu langchain-community \
|
||||||
|
upstash-vector
|
||||||
|
|
||||||
# Format code with ruff
|
# Format code with ruff
|
||||||
format:
|
format:
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ iconType: "solid"
|
|||||||
|
|
||||||
The `config` is defined as an object with two main keys:
|
The `config` is defined as an object with two main keys:
|
||||||
- `vector_store`: Specifies the vector database provider and its configuration
|
- `vector_store`: Specifies the vector database provider and its configuration
|
||||||
- `provider`: The name of the vector database (e.g., "chroma", "pgvector", "qdrant", "milvus","azure_ai_search", "vertex_ai_vector_search")
|
- `provider`: The name of the vector database (e.g., "chroma", "pgvector", "qdrant", "milvus", "upstash_vector", "azure_ai_search", "vertex_ai_vector_search")
|
||||||
- `config`: A nested dictionary containing provider-specific settings
|
- `config`: A nested dictionary containing provider-specific settings
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
70
docs/components/vectordbs/dbs/upstash-vector.mdx
Normal file
70
docs/components/vectordbs/dbs/upstash-vector.mdx
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
[Upstash Vector](https://upstash.com/docs/vector) is a serverless vector database with built-in embedding models.
|
||||||
|
|
||||||
|
### Usage with Upstash embeddings
|
||||||
|
|
||||||
|
You can enable the built-in embedding models by setting `enable_embeddings` to `True`. This allows you to use Upstash's embedding models for vectorization.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
from mem0 import Memory
|
||||||
|
|
||||||
|
os.environ["UPSTASH_VECTOR_REST_URL"] = "..."
|
||||||
|
os.environ["UPSTASH_VECTOR_REST_TOKEN"] = "..."
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"vector_store": {
|
||||||
|
"provider": "upstash_vector",
|
||||||
|
"config": {"enable_embeddings": True},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m = Memory.from_config(config)
|
||||||
|
m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
|
||||||
|
```
|
||||||
|
|
||||||
|
<Note>
|
||||||
|
Setting `enable_embeddings` to `True` will bypass any external embedding provider you have configured.
|
||||||
|
</Note>
|
||||||
|
|
||||||
|
### Usage with external embedding providers
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
from mem0 import Memory
|
||||||
|
|
||||||
|
os.environ["OPENAI_API_KEY"] = "..."
|
||||||
|
os.environ["UPSTASH_VECTOR_REST_URL"] = "..."
|
||||||
|
os.environ["UPSTASH_VECTOR_REST_TOKEN"] = "..."
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"vector_store": {
|
||||||
|
"provider": "upstash_vector",
|
||||||
|
},
|
||||||
|
"embedder": {
|
||||||
|
"provider": "openai",
|
||||||
|
"config": {
|
||||||
|
"model": "text-embedding-3-large"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m = Memory.from_config(config)
|
||||||
|
m.add("Likes to play cricket on weekends", user_id="alice", metadata={"category": "hobbies"})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Config
|
||||||
|
|
||||||
|
Here are the parameters available for configuring Upstash Vector:
|
||||||
|
|
||||||
|
| Parameter | Description | Default Value |
|
||||||
|
| ------------------- | ---------------------------------- | ------------- |
|
||||||
|
| `url` | URL for the Upstash Vector index | `None` |
|
||||||
|
| `token` | Token for the Upstash Vector index | `None` |
|
||||||
|
| `client` | An `upstash_vector.Index` instance | `None` |
|
||||||
|
| `collection_name`   | The default namespace used         | `"mem0"`      |
|
||||||
|
| `enable_embeddings` | Whether to use Upstash embeddings | `False` |
|
||||||
|
|
||||||
|
<Note>
|
||||||
|
When `url` and `token` are not provided, the `UPSTASH_VECTOR_REST_URL` and
|
||||||
|
`UPSTASH_VECTOR_REST_TOKEN` environment variables are used.
|
||||||
|
</Note>
|
||||||
@@ -18,6 +18,7 @@ See the list of supported vector databases below.
|
|||||||
<Card title="Qdrant" href="/components/vectordbs/dbs/qdrant"></Card>
|
<Card title="Qdrant" href="/components/vectordbs/dbs/qdrant"></Card>
|
||||||
<Card title="Chroma" href="/components/vectordbs/dbs/chroma"></Card>
|
<Card title="Chroma" href="/components/vectordbs/dbs/chroma"></Card>
|
||||||
<Card title="Pgvector" href="/components/vectordbs/dbs/pgvector"></Card>
|
<Card title="Pgvector" href="/components/vectordbs/dbs/pgvector"></Card>
|
||||||
|
<Card title="Upstash Vector" href="/components/vectordbs/dbs/upstash-vector"></Card>
|
||||||
<Card title="Milvus" href="/components/vectordbs/dbs/milvus"></Card>
|
<Card title="Milvus" href="/components/vectordbs/dbs/milvus"></Card>
|
||||||
<Card title="Pinecone" href="/components/vectordbs/dbs/pinecone"></Card>
|
<Card title="Pinecone" href="/components/vectordbs/dbs/pinecone"></Card>
|
||||||
<Card title="Azure" href="/components/vectordbs/dbs/azure"></Card>
|
<Card title="Azure" href="/components/vectordbs/dbs/azure"></Card>
|
||||||
|
|||||||
36
mem0/configs/vector_stores/upstash_vector.py
Normal file
36
mem0/configs/vector_stores/upstash_vector.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
import os
|
||||||
|
from typing import Any, ClassVar, Dict, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
try:
|
||||||
|
from upstash_vector import Index
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("The 'upstash_vector' library is required. Please install it using 'pip install upstash_vector'.")
|
||||||
|
|
||||||
|
|
||||||
|
class UpstashVectorConfig(BaseModel):
    """Configuration for the Upstash Vector store.

    Either an existing ``client`` or a ``url``/``token`` pair is required.
    When ``url`` or ``token`` is omitted, the ``UPSTASH_VECTOR_REST_URL`` and
    ``UPSTASH_VECTOR_REST_TOKEN`` environment variables are used instead.
    """

    # ClassVar: keeps pydantic from treating the client type as a model field.
    Index: ClassVar[type] = Index

    url: Optional[str] = Field(None, description="URL for Upstash Vector index")
    token: Optional[str] = Field(None, description="Token for Upstash Vector index")
    client: Optional[Index] = Field(None, description="Existing `upstash_vector.Index` client instance")
    collection_name: str = Field("mem0", description="Namespace to use for the index")
    enable_embeddings: bool = Field(
        False, description="Whether to use built-in upstash embeddings or not. Default is False."
    )

    @model_validator(mode="before")
    @classmethod
    def check_credentials_or_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Require a client or a complete URL/token pair before field validation.

        Args:
            values: Raw input passed to the model.

        Returns:
            The input values. When no client is given, ``url`` and ``token``
            are filled in from the environment if they were missing, so the
            resolved credentials travel with the config.

        Raises:
            ValueError: If neither a client nor both URL and token are available.
        """
        client = values.get("client")
        url = values.get("url") or os.environ.get("UPSTASH_VECTOR_REST_URL")
        token = values.get("token") or os.environ.get("UPSTASH_VECTOR_REST_TOKEN")

        if client:
            return values
        if not (url and token):
            raise ValueError("Either a client or URL and token must be provided.")

        # Copy instead of mutating the caller's dict; carry the resolved
        # credentials so consumers of this config do not see url/token as None.
        return {**values, "url": url, "token": token}

    model_config = {
        "arbitrary_types_allowed": True,
    }
|
||||||
11
mem0/embeddings/mock.py
Normal file
11
mem0/embeddings/mock.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
|
from mem0.embeddings.base import EmbeddingBase
|
||||||
|
|
||||||
|
|
||||||
|
class MockEmbeddings(EmbeddingBase):
    """Stand-in embedder that always produces the same small vector."""

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Return a constant mock embedding of dimension 10 (0.1 through 1.0).

        Neither `text` nor `memory_action` influences the result.
        """
        # step / 10 yields exactly the floats 0.1, 0.2, ..., 1.0.
        return [step / 10 for step in range(1, 11)]
|
||||||
@@ -12,12 +12,20 @@ from pydantic import ValidationError
|
|||||||
|
|
||||||
from mem0.configs.base import MemoryConfig, MemoryItem
|
from mem0.configs.base import MemoryConfig, MemoryItem
|
||||||
from mem0.configs.enums import MemoryType
|
from mem0.configs.enums import MemoryType
|
||||||
from mem0.configs.prompts import PROCEDURAL_MEMORY_SYSTEM_PROMPT, get_update_memory_messages
|
from mem0.configs.prompts import (
|
||||||
|
PROCEDURAL_MEMORY_SYSTEM_PROMPT,
|
||||||
|
get_update_memory_messages,
|
||||||
|
)
|
||||||
from mem0.memory.base import MemoryBase
|
from mem0.memory.base import MemoryBase
|
||||||
from mem0.memory.setup import setup_config
|
from mem0.memory.setup import setup_config
|
||||||
from mem0.memory.storage import SQLiteManager
|
from mem0.memory.storage import SQLiteManager
|
||||||
from mem0.memory.telemetry import capture_event
|
from mem0.memory.telemetry import capture_event
|
||||||
from mem0.memory.utils import get_fact_retrieval_messages, parse_messages, parse_vision_messages, remove_code_blocks
|
from mem0.memory.utils import (
|
||||||
|
get_fact_retrieval_messages,
|
||||||
|
parse_messages,
|
||||||
|
parse_vision_messages,
|
||||||
|
remove_code_blocks,
|
||||||
|
)
|
||||||
from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
|
from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
|
||||||
|
|
||||||
# Setup user config
|
# Setup user config
|
||||||
@@ -32,7 +40,11 @@ class Memory(MemoryBase):
|
|||||||
|
|
||||||
self.custom_fact_extraction_prompt = self.config.custom_fact_extraction_prompt
|
self.custom_fact_extraction_prompt = self.config.custom_fact_extraction_prompt
|
||||||
self.custom_update_memory_prompt = self.config.custom_update_memory_prompt
|
self.custom_update_memory_prompt = self.config.custom_update_memory_prompt
|
||||||
self.embedding_model = EmbedderFactory.create(self.config.embedder.provider, self.config.embedder.config)
|
self.embedding_model = EmbedderFactory.create(
|
||||||
|
self.config.embedder.provider,
|
||||||
|
self.config.embedder.config,
|
||||||
|
self.config.vector_store.config,
|
||||||
|
)
|
||||||
self.vector_store = VectorStoreFactory.create(
|
self.vector_store = VectorStoreFactory.create(
|
||||||
self.config.vector_store.provider, self.config.vector_store.config
|
self.config.vector_store.provider, self.config.vector_store.config
|
||||||
)
|
)
|
||||||
@@ -260,7 +272,9 @@ class Memory(MemoryBase):
|
|||||||
continue
|
continue
|
||||||
elif resp.get("event") == "ADD":
|
elif resp.get("event") == "ADD":
|
||||||
memory_id = self._create_memory(
|
memory_id = self._create_memory(
|
||||||
data=resp.get("text"), existing_embeddings=new_message_embeddings, metadata=metadata
|
data=resp.get("text"),
|
||||||
|
existing_embeddings=new_message_embeddings,
|
||||||
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
returned_memories.append(
|
returned_memories.append(
|
||||||
{
|
{
|
||||||
@@ -300,7 +314,11 @@ class Memory(MemoryBase):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error in new_memories_with_actions: {e}")
|
logging.error(f"Error in new_memories_with_actions: {e}")
|
||||||
|
|
||||||
capture_event("mem0.add", self, {"version": self.api_version, "keys": list(filters.keys())})
|
capture_event(
|
||||||
|
"mem0.add",
|
||||||
|
self,
|
||||||
|
{"version": self.api_version, "keys": list(filters.keys())},
|
||||||
|
)
|
||||||
|
|
||||||
return returned_memories
|
return returned_memories
|
||||||
|
|
||||||
@@ -342,7 +360,16 @@ class Memory(MemoryBase):
|
|||||||
).model_dump(exclude={"score"})
|
).model_dump(exclude={"score"})
|
||||||
|
|
||||||
# Add metadata if there are additional keys
|
# Add metadata if there are additional keys
|
||||||
excluded_keys = {"user_id", "agent_id", "run_id", "hash", "data", "created_at", "updated_at", "id"}
|
excluded_keys = {
|
||||||
|
"user_id",
|
||||||
|
"agent_id",
|
||||||
|
"run_id",
|
||||||
|
"hash",
|
||||||
|
"data",
|
||||||
|
"created_at",
|
||||||
|
"updated_at",
|
||||||
|
"id",
|
||||||
|
}
|
||||||
additional_metadata = {k: v for k, v in memory.payload.items() if k not in excluded_keys}
|
additional_metadata = {k: v for k, v in memory.payload.items() if k not in excluded_keys}
|
||||||
if additional_metadata:
|
if additional_metadata:
|
||||||
memory_item["metadata"] = additional_metadata
|
memory_item["metadata"] = additional_metadata
|
||||||
@@ -631,7 +658,7 @@ class Memory(MemoryBase):
|
|||||||
prompt (str, optional): Prompt to use for the procedural memory creation. Defaults to None.
|
prompt (str, optional): Prompt to use for the procedural memory creation. Defaults to None.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from langchain_core.messages.utils import convert_to_messages # type: ignore
|
pass
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Import error while loading langchain-core. Please install 'langchain-core' to use procedural memory."
|
"Import error while loading langchain-core. Please install 'langchain-core' to use procedural memory."
|
||||||
@@ -643,7 +670,10 @@ class Memory(MemoryBase):
|
|||||||
parsed_messages = [
|
parsed_messages = [
|
||||||
{"role": "system", "content": prompt or PROCEDURAL_MEMORY_SYSTEM_PROMPT},
|
{"role": "system", "content": prompt or PROCEDURAL_MEMORY_SYSTEM_PROMPT},
|
||||||
*messages,
|
*messages,
|
||||||
{"role": "user", "content": "Create procedural memory of the above conversation."},
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Create procedural memory of the above conversation.",
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -728,7 +758,9 @@ class Memory(MemoryBase):
|
|||||||
self.vector_store = VectorStoreFactory.create(
|
self.vector_store = VectorStoreFactory.create(
|
||||||
self.config.vector_store.provider, self.config.vector_store.config
|
self.config.vector_store.provider, self.config.vector_store.config
|
||||||
)
|
)
|
||||||
|
print("before dbreset")
|
||||||
self.db.reset()
|
self.db.reset()
|
||||||
|
print("after dbreset")
|
||||||
capture_event("mem0.reset", self)
|
capture_event("mem0.reset", self)
|
||||||
|
|
||||||
def chat(self, query):
|
def chat(self, query):
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
import importlib
|
import importlib
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
||||||
from mem0.configs.llms.base import BaseLlmConfig
|
from mem0.configs.llms.base import BaseLlmConfig
|
||||||
|
from mem0.embeddings.mock import MockEmbeddings
|
||||||
|
|
||||||
|
|
||||||
def load_class(class_type):
|
def load_class(class_type):
|
||||||
@@ -54,7 +56,9 @@ class EmbedderFactory:
|
|||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, provider_name, config):
|
def create(cls, provider_name, config, vector_config: Optional[dict]):
|
||||||
|
if provider_name == "upstash_vector" and vector_config and vector_config.enable_embeddings:
|
||||||
|
return MockEmbeddings()
|
||||||
class_type = cls.provider_to_class.get(provider_name)
|
class_type = cls.provider_to_class.get(provider_name)
|
||||||
if class_type:
|
if class_type:
|
||||||
embedder_instance = load_class(class_type)
|
embedder_instance = load_class(class_type)
|
||||||
@@ -70,6 +74,7 @@ class VectorStoreFactory:
|
|||||||
"chroma": "mem0.vector_stores.chroma.ChromaDB",
|
"chroma": "mem0.vector_stores.chroma.ChromaDB",
|
||||||
"pgvector": "mem0.vector_stores.pgvector.PGVector",
|
"pgvector": "mem0.vector_stores.pgvector.PGVector",
|
||||||
"milvus": "mem0.vector_stores.milvus.MilvusDB",
|
"milvus": "mem0.vector_stores.milvus.MilvusDB",
|
||||||
|
"upstash_vector": "mem0.vector_stores.upstash_vector.UpstashVector",
|
||||||
"azure_ai_search": "mem0.vector_stores.azure_ai_search.AzureAISearch",
|
"azure_ai_search": "mem0.vector_stores.azure_ai_search.AzureAISearch",
|
||||||
"pinecone": "mem0.vector_stores.pinecone.PineconeDB",
|
"pinecone": "mem0.vector_stores.pinecone.PineconeDB",
|
||||||
"redis": "mem0.vector_stores.redis.RedisDB",
|
"redis": "mem0.vector_stores.redis.RedisDB",
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, model_validator
|
|||||||
|
|
||||||
class VectorStoreConfig(BaseModel):
|
class VectorStoreConfig(BaseModel):
|
||||||
provider: str = Field(
|
provider: str = Field(
|
||||||
description="Provider of the vector store (e.g., 'qdrant', 'chroma')",
|
description="Provider of the vector store (e.g., 'qdrant', 'chroma', 'upstash_vector')",
|
||||||
default="qdrant",
|
default="qdrant",
|
||||||
)
|
)
|
||||||
config: Optional[Dict] = Field(description="Configuration for the specific vector store", default=None)
|
config: Optional[Dict] = Field(description="Configuration for the specific vector store", default=None)
|
||||||
@@ -16,6 +16,7 @@ class VectorStoreConfig(BaseModel):
|
|||||||
"pgvector": "PGVectorConfig",
|
"pgvector": "PGVectorConfig",
|
||||||
"pinecone": "PineconeConfig",
|
"pinecone": "PineconeConfig",
|
||||||
"milvus": "MilvusDBConfig",
|
"milvus": "MilvusDBConfig",
|
||||||
|
"upstash_vector": "UpstashVectorConfig",
|
||||||
"azure_ai_search": "AzureAISearchConfig",
|
"azure_ai_search": "AzureAISearchConfig",
|
||||||
"redis": "RedisDBConfig",
|
"redis": "RedisDBConfig",
|
||||||
"elasticsearch": "ElasticsearchConfig",
|
"elasticsearch": "ElasticsearchConfig",
|
||||||
|
|||||||
287
mem0/vector_stores/upstash_vector.py
Normal file
287
mem0/vector_stores/upstash_vector.py
Normal file
@@ -0,0 +1,287 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from mem0.vector_stores.base import VectorStoreBase
|
||||||
|
|
||||||
|
try:
|
||||||
|
from upstash_vector import Index
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("The 'upstash_vector' library is required. Please install it using 'pip install upstash_vector'.")
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OutputData(BaseModel):
    """Single record returned by the Upstash Vector store methods."""

    # Memory id.
    id: Optional[str]
    # Similarity score; None when the record comes from the `get` method.
    score: Optional[float]
    # Metadata stored with the vector.
    payload: Optional[Dict]
|
||||||
|
|
||||||
|
|
||||||
|
class UpstashVector(VectorStoreBase):
    """Vector store backed by an Upstash Vector index.

    The configured ``collection_name`` is used as the Upstash namespace for
    every operation.
    """

    def __init__(
        self,
        collection_name: str,
        url: Optional[str] = None,
        token: Optional[str] = None,
        client: Optional[Index] = None,
        enable_embeddings: bool = False,
    ):
        """
        Initialize the UpstashVector vector store.

        Args:
            collection_name (str): Namespace used for all operations on the index.
            url (str, optional): URL for Upstash Vector index. Falls back to the
                `UPSTASH_VECTOR_REST_URL` environment variable. Defaults to None.
            token (str, optional): Token for Upstash Vector index. Falls back to the
                `UPSTASH_VECTOR_REST_TOKEN` environment variable. Defaults to None.
            client (Index, optional): Existing `upstash_vector.Index` client instance. Defaults to None.
            enable_embeddings (bool, optional): When True, the payload's `data` text is sent
                to Upstash instead of the supplied vectors. Defaults to False.

        Raises:
            ValueError: If neither a client nor a URL/token pair is available.
        """
        import os  # local import: only needed for the credential fallback below

        if client:
            self.client = client
        else:
            # Honour the documented environment-variable fallback instead of
            # rejecting a configuration that only sets the env vars.
            url = url or os.environ.get("UPSTASH_VECTOR_REST_URL")
            token = token or os.environ.get("UPSTASH_VECTOR_REST_TOKEN")
            if not (url and token):
                raise ValueError("Either a client or URL and token must be provided.")
            self.client = Index(url, token)

        self.collection_name = collection_name
        self.enable_embeddings = enable_embeddings

    def insert(
        self,
        vectors: List[list],
        payloads: Optional[List[Dict]] = None,
        ids: Optional[List[str]] = None,
    ):
        """
        Insert vectors

        Args:
            vectors (list): List of vectors to insert. With embeddings enabled only their
                count is used; the payload's `data` text is sent instead.
            payloads (list, optional): List of payloads corresponding to vectors. These will be passed as metadatas to the Upstash Vector client. Defaults to None.
            ids (list, optional): List of IDs corresponding to vectors. Defaults to None.

        Raises:
            ValueError: If embeddings are enabled and any payload lacks a non-None 'data' field.
        """
        logger.info("Inserting %d vectors into namespace %s", len(vectors), self.collection_name)

        if self.enable_embeddings:
            if not payloads or any("data" not in m or m["data"] is None for m in payloads):
                raise ValueError("When embeddings are enabled, all payloads must contain a 'data' field.")
            processed_vectors = [
                {
                    "id": ids[i] if ids else None,
                    "data": payloads[i]["data"],
                    "metadata": payloads[i],
                }
                for i, _ in enumerate(vectors)
            ]
        else:
            processed_vectors = [
                {
                    "id": ids[i] if ids else None,
                    "vector": vector,
                    "metadata": payloads[i] if payloads else None,
                }
                for i, vector in enumerate(vectors)
            ]

        self.client.upsert(
            vectors=processed_vectors,
            namespace=self.collection_name,
        )

    def _stringify(self, x):
        """Wrap strings in double quotes for a filter expression; return other values unchanged."""
        # NOTE(review): double quotes inside `x` are not escaped — confirm that
        # filter values cannot contain them.
        return f'"{x}"' if isinstance(x, str) else x

    def _filter_expression(self, filters: Optional[Dict]) -> str:
        """Join `filters` into a `key = value AND ...` expression; empty string when there are none."""
        if not filters:
            return ""
        return " AND ".join([f"{k} = {self._stringify(v)}" for k, v in filters.items()])

    def search(
        self,
        query: str,
        vectors: List[list],
        limit: int = 5,
        filters: Optional[Dict] = None,
    ) -> List[OutputData]:
        """
        Search for similar vectors.

        Args:
            query (str): Query text. Only used when embeddings are enabled.
            vectors (list): Query vectors. Only used when embeddings are disabled;
                one query is issued per vector and the results are concatenated.
            limit (int, optional): Number of results to return per query. Defaults to 5.
            filters (Dict, optional): Equality filters to apply to the search.

        Returns:
            List[OutputData]: Search results.
        """
        filter_expr = self._filter_expression(filters)

        if self.enable_embeddings:
            response = self.client.query(
                data=query,
                top_k=limit,
                filter=filter_expr,
                include_metadata=True,
                namespace=self.collection_name,
            )
        else:
            queries = [
                {
                    "vector": v,
                    "top_k": limit,
                    "filter": filter_expr,
                    "include_metadata": True,
                    "namespace": self.collection_name,
                }
                for v in vectors
            ]
            responses = self.client.query_many(queries=queries)
            # flatten the per-query result lists into one list
            response = [res for res_list in responses for res in res_list]

        return [
            OutputData(
                id=res.id,
                score=res.score,
                payload=res.metadata,
            )
            for res in response
        ]

    def delete(self, vector_id: int):
        """
        Delete a vector by ID.

        Args:
            vector_id: ID of the vector to delete; converted to a string before the request.
        """
        self.client.delete(
            ids=[str(vector_id)],
            namespace=self.collection_name,
        )

    def update(
        self,
        vector_id: int,
        vector: Optional[list] = None,
        payload: Optional[dict] = None,
    ):
        """
        Update a vector and its payload.

        Args:
            vector_id: ID of the vector to update; converted to a string before the request.
            vector (list, optional): Updated vector. Defaults to None.
            payload (dict, optional): Updated payload. It is sent as metadata, and its
                `data` entry (if any) is sent as the record's data. Defaults to None.
        """
        self.client.update(
            id=str(vector_id),
            vector=vector,
            data=payload.get("data") if payload else None,
            metadata=payload,
            namespace=self.collection_name,
        )

    def get(self, vector_id: int) -> Optional[OutputData]:
        """
        Retrieve a vector by ID.

        Args:
            vector_id: ID of the vector to retrieve; converted to a string before the request.

        Returns:
            Optional[OutputData]: The record (with `score` set to None), or None when
            nothing was fetched for the ID.
        """
        response = self.client.fetch(
            ids=[str(vector_id)],
            namespace=self.collection_name,
            include_metadata=True,
        )
        if len(response) == 0:
            return None
        vector = response[0]
        if not vector:
            return None
        return OutputData(id=vector.id, score=None, payload=vector.metadata)

    def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[List[OutputData]]:
        """
        List all memories.

        Args:
            filters (Dict, optional): Equality filters to apply to the search. Defaults to None.
            limit (int, optional): Maximum number of results to return. Defaults to 100.

        Returns:
            List[List[OutputData]]: A single-element list wrapping the results.
        """
        filter_expr = self._filter_expression(filters)

        info = self.client.info()
        ns_info = info.namespaces.get(self.collection_name)

        if not ns_info or ns_info.vector_count == 0:
            return [[]]

        # Listing is done as a resumable similarity query against a constant
        # vector; reuse the info fetched above instead of asking for it again.
        probe_vector = [1.0] * info.dimension

        results, query = self.client.resumable_query(
            vector=probe_vector,
            filter=filter_expr,
            include_metadata=True,
            namespace=self.collection_name,
            top_k=100,
        )
        with query:
            while len(results) < limit:
                batch = query.fetch_next(100)
                if not batch:
                    break
                results.extend(batch)

        # Pages arrive 100 at a time, so trim the overshoot to honour `limit`.
        parsed_result = [
            OutputData(
                id=res.id,
                score=res.score,
                payload=res.metadata,
            )
            for res in results[:limit]
        ]
        return [parsed_result]

    def create_col(self, name, vector_size, distance):
        """
        Upstash Vector has namespaces instead of collections. A namespace is created when the first vector is inserted.

        This method is a placeholder to maintain the interface.
        """
        pass

    def list_cols(self) -> List[str]:
        """
        Lists all namespaces in the Upstash Vector index.

        Returns:
            List[str]: List of namespaces.
        """
        return self.client.list_namespaces()

    def delete_col(self):
        """
        Delete the namespace and all vectors in it.
        """
        self.client.reset(namespace=self.collection_name)

    def col_info(self):
        """
        Return general information about the Upstash Vector index.

        - Total number of vectors across all namespaces
        - Total number of vectors waiting to be indexed across all namespaces
        - Total size of the index on disk in bytes
        - Vector dimension
        - Similarity function used
        - Per-namespace vector and pending vector counts
        """
        return self.client.info()
|
||||||
384
tests/vector_stores/test_upstash_vector.py
Normal file
384
tests/vector_stores/test_upstash_vector.py
Normal file
@@ -0,0 +1,384 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
from unittest.mock import MagicMock, call, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from mem0.vector_stores.upstash_vector import UpstashVector
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class QueryResult:
    """Lightweight stand-in for a result record returned by the mocked client."""

    # Required fields.
    id: str
    score: Optional[float]

    # Optional fields, all absent by default.
    vector: Optional[List[float]] = None
    metadata: Optional[Dict] = None
    data: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_index():
    """Patch ``upstash_vector.Index`` for the duration of a test and yield the mock."""
    index_patch = patch("upstash_vector.Index")
    with index_patch as patched_index:
        yield patched_index
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def upstash_instance(mock_index):
    """Build an ``UpstashVector`` on the mocked client, using namespace ``ns``."""
    mocked_client = mock_index.return_value
    store = UpstashVector(client=mocked_client, collection_name="ns")
    return store
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def upstash_instance_with_embeddings(mock_index):
    """Build an ``UpstashVector`` on the mocked client with embeddings enabled."""
    store_kwargs = {
        "client": mock_index.return_value,
        "collection_name": "ns",
        "enable_embeddings": True,
    }
    return UpstashVector(**store_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_insert_vectors(upstash_instance, mock_index):
    """Plain inserts are upserted as id/vector/metadata records into the namespace."""
    records = [
        ("id1", [0.1, 0.2, 0.3], {"name": "vector1"}),
        ("id2", [0.4, 0.5, 0.6], {"name": "vector2"}),
    ]

    upstash_instance.insert(
        vectors=[vec for _, vec, _ in records],
        payloads=[meta for _, _, meta in records],
        ids=[rec_id for rec_id, _, _ in records],
    )

    expected_records = [{"id": rec_id, "vector": vec, "metadata": meta} for rec_id, vec, meta in records]
    upstash_instance.client.upsert.assert_called_once_with(
        vectors=expected_records,
        namespace="ns",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_search_vectors(upstash_instance, mock_index):
    """Vector search issues one `query_many` entry per vector and maps the hits to outputs."""
    hits = [
        QueryResult(id=f"id{n}", score=n / 10, vector=None, metadata={"name": f"vector{n}"}, data=None)
        for n in (1, 2)
    ]
    client = upstash_instance.client
    client.query_many.return_value = [hits]

    query_vector = [0.1, 0.2, 0.3]
    results = upstash_instance.search(
        query="hello world",
        vectors=[query_vector],
        limit=2,
        filters={"age": 30, "name": "John"},
    )

    expected_query = {
        "vector": query_vector,
        "top_k": 2,
        "namespace": "ns",
        "include_metadata": True,
        "filter": 'age = 30 AND name = "John"',
    }
    client.query_many.assert_called_once_with(queries=[expected_query])

    assert len(results) == 2
    first = results[0]
    assert first.id == "id1"
    assert first.score == 0.1
    assert first.payload == {"name": "vector1"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_vector(upstash_instance):
    """Deleting by id forwards the id and namespace to the client."""
    target_id = "id1"

    upstash_instance.delete(vector_id=target_id)

    delete_mock = upstash_instance.client.delete
    delete_mock.assert_called_once_with(ids=[target_id], namespace="ns")
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_vector(upstash_instance):
    """Updating forwards vector and metadata; `data` is None when the payload has none."""
    replacement_vector = [0.7, 0.8, 0.9]
    replacement_payload = {"name": "updated_vector"}

    upstash_instance.update(
        vector_id="id1",
        vector=replacement_vector,
        payload=replacement_payload,
    )

    expected_call = {
        "id": "id1",
        "vector": replacement_vector,
        "data": None,
        "metadata": {"name": "updated_vector"},
        "namespace": "ns",
    }
    upstash_instance.client.update.assert_called_once_with(**expected_call)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_vector(upstash_instance):
    """Fetching by id returns the first fetched record as an output item."""
    stored = QueryResult(id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None)
    fetch_mock = upstash_instance.client.fetch
    fetch_mock.return_value = [stored]

    fetched = upstash_instance.get(vector_id="id1")

    fetch_mock.assert_called_once_with(ids=["id1"], namespace="ns", include_metadata=True)

    assert fetched.id == "id1"
    assert fetched.payload == {"name": "vector1"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_vectors(upstash_instance):
    """list() should page through a resumable query and return every fetched vector.

    The client is stubbed so that the initial resumable query yields one result
    and the handler yields one more result on each of the next two pages, then
    an empty page. The test checks the query arguments, the paging calls, that
    the handler is exited, and the shape of the returned results.
    """
    mock_result = [
        QueryResult(id="id1", score=None, vector=None, metadata={"name": "vector1"}, data=None),
        QueryResult(id="id2", score=None, vector=None, metadata={"name": "vector2"}, data=None),
        QueryResult(id="id3", score=None, vector=None, metadata={"name": "vector3"}, data=None),
    ]
    handler = MagicMock()

    # Stub everything the client is asked for *before* calling list().
    upstash_instance.client.info.return_value.dimension = 10
    upstash_instance.client.resumable_query.return_value = (mock_result[0:1], handler)
    handler.fetch_next.side_effect = [mock_result[1:2], mock_result[2:3], []]

    filters = {"age": 30, "name": "John"}
    # list() is expected to return a one-element sequence wrapping the results.
    [results] = upstash_instance.list(filters=filters, limit=15)

    # The query vector is all ones with the stubbed index dimension (10).
    upstash_instance.client.resumable_query.assert_called_once_with(
        vector=[1.0] * 10,
        filter='age = 30 AND name = "John"',
        include_metadata=True,
        namespace="ns",
        top_k=100,
    )

    # Three pages are requested: two with data, then the empty one.
    handler.fetch_next.assert_has_calls([call(100), call(100), call(100)])
    handler.__exit__.assert_called_once()

    assert len(results) == len(mock_result)
    assert results[0].id == "id1"
    assert results[0].payload == {"name": "vector1"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_insert_vectors_with_embeddings(upstash_instance_with_embeddings, mock_index):
    """insert() on the embeddings-enabled instance should upsert "data" text, not vectors."""
    store = upstash_instance_with_embeddings
    input_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
    input_payloads = [
        {"name": "vector1", "data": "data1"},
        {"name": "vector2", "data": "data2"},
    ]
    input_ids = ["id1", "id2"]

    store.insert(vectors=input_vectors, payloads=input_payloads, ids=input_ids)

    # Each record carries the payload's "data" field instead of a "vector" entry,
    # while the full payload is still sent as metadata.
    expected_records = [
        {"id": "id1", "data": "data1", "metadata": {"name": "vector1", "data": "data1"}},
        {"id": "id2", "data": "data2", "metadata": {"name": "vector2", "data": "data2"}},
    ]
    store.client.upsert.assert_called_once_with(vectors=expected_records, namespace="ns")
|
||||||
|
|
||||||
|
|
||||||
|
def test_search_vectors_with_embeddings(upstash_instance_with_embeddings, mock_index):
    """search() on the embeddings-enabled instance should query by text via client.query."""
    store = upstash_instance_with_embeddings
    hits = [
        QueryResult(id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data="data1"),
        QueryResult(id="id2", score=0.2, vector=None, metadata={"name": "vector2"}, data="data2"),
    ]
    store.client.query.return_value = hits

    search_kwargs = {
        "query": "hello world",
        "vectors": [],
        "limit": 2,
        "filters": {"age": 30, "name": "John"},
    }
    results = store.search(**search_kwargs)

    # The query text is sent in the "data" field; no vector is passed.
    store.client.query.assert_called_once_with(
        data="hello world",
        top_k=2,
        filter='age = 30 AND name = "John"',
        include_metadata=True,
        namespace="ns",
    )

    assert len(results) == 2
    top_hit = results[0]
    assert top_hit.id == "id1"
    assert top_hit.score == 0.1
    assert top_hit.payload == {"name": "vector1"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_vector_with_embeddings(upstash_instance_with_embeddings):
    """update() on the embeddings-enabled instance should forward the payload's "data" text."""
    store = upstash_instance_with_embeddings
    target_id = "id1"
    replacement_payload = {"name": "updated_vector", "data": "updated_data"}

    store.update(vector_id=target_id, payload=replacement_payload)

    # No vector was supplied, so vector=None is expected alongside the data text.
    expected_kwargs = {
        "id": "id1",
        "vector": None,
        "data": "updated_data",
        "metadata": {"name": "updated_vector", "data": "updated_data"},
        "namespace": "ns",
    }
    store.client.update.assert_called_once_with(**expected_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_insert_vectors_with_embeddings_missing_data(upstash_instance_with_embeddings):
    """insert() on the embeddings-enabled instance should reject payloads lacking "data"."""
    insert_kwargs = {
        "vectors": [[0.1, 0.2, 0.3]],
        "payloads": [{"name": "vector1"}],  # Deliberately has no "data" field.
        "ids": ["id1"],
    }
    expected_message = "When embeddings are enabled, all payloads must contain a 'data' field"

    with pytest.raises(ValueError, match=expected_message):
        upstash_instance_with_embeddings.insert(**insert_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_vector_with_embeddings_missing_data(upstash_instance_with_embeddings):
    """update() should succeed without a "data" field, passing data=None to the client."""
    store = upstash_instance_with_embeddings
    target_id = "id1"
    # Unlike insert, update is expected to accept a payload with no "data" field.
    replacement_payload = {"name": "updated_vector"}

    store.update(vector_id=target_id, payload=replacement_payload)

    expected_kwargs = {
        "id": "id1",
        "vector": None,
        "data": None,
        "metadata": {"name": "updated_vector"},
        "namespace": "ns",
    }
    store.client.update.assert_called_once_with(**expected_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_cols(upstash_instance):
    """list_cols() should return whatever client.list_namespaces reports."""
    mock_namespaces = ["ns1", "ns2", "ns3"]
    list_namespaces_mock = upstash_instance.client.list_namespaces
    list_namespaces_mock.return_value = mock_namespaces

    listed = upstash_instance.list_cols()

    list_namespaces_mock.assert_called_once()
    assert listed == mock_namespaces
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_col(upstash_instance):
    """delete_col() should call client.reset with namespace "ns"."""
    upstash_instance.delete_col()

    reset_mock = upstash_instance.client.reset
    reset_mock.assert_called_once_with(namespace="ns")
|
||||||
|
|
||||||
|
|
||||||
|
def test_col_info(upstash_instance):
    """col_info() should return the value produced by client.info unchanged."""
    mock_info = dict(
        dimension=10,
        total_vectors=100,
        pending_vectors=0,
        disk_size=1024,
    )
    info_mock = upstash_instance.client.info
    info_mock.return_value = mock_info

    reported = upstash_instance.col_info()

    info_mock.assert_called_once()
    assert reported == mock_info
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_vector_not_found(upstash_instance):
    """get() should return None when client.fetch yields an empty list."""
    missing_id = "nonexistent"
    fetch_mock = upstash_instance.client.fetch
    fetch_mock.return_value = []

    record = upstash_instance.get(vector_id=missing_id)

    fetch_mock.assert_called_once_with(ids=["nonexistent"], namespace="ns", include_metadata=True)
    assert record is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_search_vectors_empty_filters(upstash_instance):
    """search() with filters=None should send an empty filter string to client.query_many."""
    single_hit = QueryResult(id="id1", score=0.1, vector=None, metadata={"name": "vector1"}, data=None)
    query_many_mock = upstash_instance.client.query_many
    # query_many returns one result list per submitted query.
    query_many_mock.return_value = [[single_hit]]

    query_vectors = [[0.1, 0.2, 0.3]]
    results = upstash_instance.search(
        query="hello world",
        vectors=query_vectors,
        limit=1,
        filters=None,
    )

    expected_query = {
        "vector": query_vectors[0],
        "top_k": 1,
        "namespace": "ns",
        "include_metadata": True,
        "filter": "",
    }
    query_many_mock.assert_called_once_with(queries=[expected_query])

    assert len(results) == 1
    assert results[0].id == "id1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_insert_vectors_no_payloads(upstash_instance):
    """insert() without payloads should upsert each vector with metadata=None."""
    input_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
    input_ids = ["id1", "id2"]

    upstash_instance.insert(vectors=input_vectors, ids=input_ids)

    expected_records = [
        {"id": "id1", "vector": [0.1, 0.2, 0.3], "metadata": None},
        {"id": "id2", "vector": [0.4, 0.5, 0.6], "metadata": None},
    ]
    upstash_instance.client.upsert.assert_called_once_with(vectors=expected_records, namespace="ns")
|
||||||
|
|
||||||
|
|
||||||
|
def test_insert_vectors_no_ids(upstash_instance):
    """insert() without ids should upsert each vector with id=None."""
    input_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
    input_payloads = [{"name": "vector1"}, {"name": "vector2"}]

    upstash_instance.insert(vectors=input_vectors, payloads=input_payloads)

    expected_records = [
        {"id": None, "vector": [0.1, 0.2, 0.3], "metadata": {"name": "vector1"}},
        {"id": None, "vector": [0.4, 0.5, 0.6], "metadata": {"name": "vector2"}},
    ]
    upstash_instance.client.upsert.assert_called_once_with(vectors=expected_records, namespace="ns")
|
||||||
Reference in New Issue
Block a user