Add: Pinecone integration (#2395)
This commit is contained in:
2
Makefile
2
Makefile
@@ -13,7 +13,7 @@ install:
|
||||
install_all:
|
||||
poetry install
|
||||
poetry run pip install groq together boto3 litellm ollama chromadb weaviate weaviate-client sentence_transformers vertexai \
|
||||
google-generativeai elasticsearch opensearch-py vecs
|
||||
google-generativeai elasticsearch opensearch-py vecs pinecone pinecone-text
|
||||
|
||||
# Format code with ruff
|
||||
format:
|
||||
|
||||
88
docs/components/vectordbs/dbs/pinecone.mdx
Normal file
88
docs/components/vectordbs/dbs/pinecone.mdx
Normal file
@@ -0,0 +1,88 @@
|
||||
# Pinecone
|
||||
|
||||
[Pinecone](https://www.pinecone.io/) is a fully managed vector database designed for machine learning applications, offering high performance vector search with low latency at scale. It's particularly well-suited for semantic search, recommendation systems, and other AI-powered applications.
|
||||
|
||||
### Usage
|
||||
|
||||
```python
|
||||
import os
|
||||
from mem0 import Memory
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "sk-xx"
|
||||
os.environ["PINECONE_API_KEY"] = "your-api-key"
|
||||
|
||||
config = {
|
||||
"vector_store": {
|
||||
"provider": "pinecone",
|
||||
"config": {
|
||||
"collection_name": "memory_index",
|
||||
"embedding_model_dims": 1536,
|
||||
"environment": "us-west1-gcp",
|
||||
"metric": "cosine"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m = Memory.from_config(config)
|
||||
messages = [
|
||||
{"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
|
||||
    {"role": "assistant", "content": "How about thriller movies? They can be quite engaging."},
|
||||
{"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."},
|
||||
{"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
|
||||
]
|
||||
m.add(messages, user_id="alice", metadata={"category": "movies"})
|
||||
```
|
||||
|
||||
### Config
|
||||
|
||||
Here are the parameters available for configuring Pinecone:
|
||||
|
||||
| Parameter | Description | Default Value |
|
||||
| --- | --- | --- |
|
||||
| `collection_name` | Name of the index/collection | `"mem0"` |
|
||||
| `embedding_model_dims` | Dimensions of the embedding model | `1536` |
|
||||
| `client` | Existing Pinecone client instance | `None` |
|
||||
| `api_key` | API key for Pinecone | Environment variable: `PINECONE_API_KEY` |
|
||||
| `environment` | Pinecone environment | `None` |
|
||||
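| `serverless_config` | Configuration for serverless deployment. If neither `serverless_config` nor `pod_config` is set, a serverless index on AWS `us-west-2` is created | `None` |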
|
||||
| `pod_config` | Configuration for pod-based deployment | `None` |
|
||||
| `hybrid_search` | Whether to enable hybrid search | `False` |
|
||||
| `metric` | Distance metric for vector similarity | `"cosine"` |
|
||||
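| `batch_size` | Batch size for operations | `100` |
| `extra_params` | Additional keyword arguments passed to the Pinecone client constructor | `None` |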
|
||||
|
||||
#### Serverless Config Example
|
||||
|
||||
```python
|
||||
config = {
|
||||
"vector_store": {
|
||||
"provider": "pinecone",
|
||||
"config": {
|
||||
"collection_name": "memory_index",
|
||||
"embedding_model_dims": 1536,
|
||||
"serverless_config": {
|
||||
"cloud": "aws",
|
||||
"region": "us-west-2"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### Pod Config Example
|
||||
|
||||
```python
|
||||
config = {
|
||||
"vector_store": {
|
||||
"provider": "pinecone",
|
||||
"config": {
|
||||
"collection_name": "memory_index",
|
||||
"embedding_model_dims": 1536,
|
||||
"pod_config": {
|
||||
"environment": "gcp-starter",
|
||||
"replicas": 1,
|
||||
"pod_type": "starter"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
@@ -19,6 +19,7 @@ See the list of supported vector databases below.
|
||||
<Card title="Chroma" href="/components/vectordbs/dbs/chroma"></Card>
|
||||
<Card title="Pgvector" href="/components/vectordbs/dbs/pgvector"></Card>
|
||||
<Card title="Milvus" href="/components/vectordbs/dbs/milvus"></Card>
|
||||
<Card title="Pinecone" href="/components/vectordbs/dbs/pinecone"></Card>
|
||||
<Card title="Azure AI Search" href="/components/vectordbs/dbs/azure_ai_search"></Card>
|
||||
<Card title="Redis" href="/components/vectordbs/dbs/redis"></Card>
|
||||
<Card title="Elasticsearch" href="/components/vectordbs/dbs/elasticsearch"></Card>
|
||||
|
||||
@@ -129,6 +129,7 @@
|
||||
"components/vectordbs/dbs/chroma",
|
||||
"components/vectordbs/dbs/pgvector",
|
||||
"components/vectordbs/dbs/milvus",
|
||||
"components/vectordbs/dbs/pinecone",
|
||||
"components/vectordbs/dbs/azure_ai_search",
|
||||
"components/vectordbs/dbs/redis",
|
||||
"components/vectordbs/dbs/elasticsearch",
|
||||
|
||||
56
mem0/configs/vector_stores/pinecone.py
Normal file
56
mem0/configs/vector_stores/pinecone.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class PineconeConfig(BaseModel):
    """Configuration for Pinecone vector database.

    All validators run in ``before`` mode, i.e. on the raw input mapping, and
    together enforce that:

    * credentials are available (``api_key``, ``client`` or the
      ``PINECONE_API_KEY`` environment variable),
    * at most one of ``pod_config`` / ``serverless_config`` is supplied,
    * no keys other than the declared fields are passed.
    """

    collection_name: str = Field("mem0", description="Name of the index/collection")
    embedding_model_dims: int = Field(1536, description="Dimensions of the embedding model")
    client: Optional[Any] = Field(None, description="Existing Pinecone client instance")
    api_key: Optional[str] = Field(None, description="API key for Pinecone")
    environment: Optional[str] = Field(None, description="Pinecone environment")
    serverless_config: Optional[Dict[str, Any]] = Field(None, description="Configuration for serverless deployment")
    pod_config: Optional[Dict[str, Any]] = Field(None, description="Configuration for pod-based deployment")
    hybrid_search: bool = Field(False, description="Whether to enable hybrid search")
    metric: str = Field("cosine", description="Distance metric for vector similarity")
    batch_size: int = Field(100, description="Batch size for operations")
    extra_params: Optional[Dict[str, Any]] = Field(None, description="Additional parameters for Pinecone client")

    @model_validator(mode="before")
    @classmethod
    def check_api_key_or_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Ensure some form of Pinecone credentials is available.

        Raises:
            ValueError: If neither ``api_key`` nor ``client`` is given and the
                ``PINECONE_API_KEY`` environment variable is unset or empty.
        """
        api_key, client = values.get("api_key"), values.get("client")
        # An empty PINECONE_API_KEY is treated like a missing one, so that a
        # configuration accepted here always carries a non-empty key source.
        if not api_key and not client and not os.environ.get("PINECONE_API_KEY"):
            raise ValueError(
                "Either 'api_key' or 'client' must be provided, or PINECONE_API_KEY environment variable must be set."
            )
        return values

    @model_validator(mode="before")
    @classmethod
    def check_pod_or_serverless(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject configurations that specify both deployment options.

        Raises:
            ValueError: If both ``pod_config`` and ``serverless_config`` are truthy.
        """
        pod_config, serverless_config = values.get("pod_config"), values.get("serverless_config")
        if pod_config and serverless_config:
            raise ValueError(
                "Both 'pod_config' and 'serverless_config' cannot be specified. Choose one deployment option."
            )
        return values

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject input keys that are not declared fields of this model.

        Raises:
            ValueError: If any unknown key is present. Field names in the
                message are sorted so the text is stable between runs.
        """
        allowed_fields = set(cls.model_fields.keys())
        input_fields = set(values.keys())
        extra_fields = input_fields - allowed_fields
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(sorted(extra_fields))}. "
                f"Please input only the following fields: {', '.join(sorted(allowed_fields))}"
            )
        return values

    # `client` may hold an arbitrary Pinecone client object.
    model_config = {
        "arbitrary_types_allowed": True,
    }
|
||||
@@ -67,6 +67,7 @@ class VectorStoreFactory:
|
||||
"pgvector": "mem0.vector_stores.pgvector.PGVector",
|
||||
"milvus": "mem0.vector_stores.milvus.MilvusDB",
|
||||
"azure_ai_search": "mem0.vector_stores.azure_ai_search.AzureAISearch",
|
||||
"pinecone": "mem0.vector_stores.pinecone.PineconeDB",
|
||||
"redis": "mem0.vector_stores.redis.RedisDB",
|
||||
"elasticsearch": "mem0.vector_stores.elasticsearch.ElasticsearchDB",
|
||||
"vertex_ai_vector_search": "mem0.vector_stores.vertex_ai_vector_search.GoogleMatchingEngine",
|
||||
|
||||
@@ -14,6 +14,7 @@ class VectorStoreConfig(BaseModel):
|
||||
"qdrant": "QdrantConfig",
|
||||
"chroma": "ChromaDbConfig",
|
||||
"pgvector": "PGVectorConfig",
|
||||
"pinecone": "PineconeConfig",
|
||||
"milvus": "MilvusDBConfig",
|
||||
"azure_ai_search": "AzureAISearchConfig",
|
||||
"redis": "RedisDBConfig",
|
||||
|
||||
368
mem0/vector_stores/pinecone.py
Normal file
368
mem0/vector_stores/pinecone.py
Normal file
@@ -0,0 +1,368 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
try:
|
||||
from pinecone import Pinecone, PodSpec, ServerlessSpec
|
||||
from pinecone.data.dataclasses.vector import Vector
|
||||
except ImportError:
|
||||
raise ImportError("Pinecone requires extra dependencies. Install with `pip install pinecone pinecone-text`") from None
|
||||
|
||||
from mem0.vector_stores.base import VectorStoreBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OutputData(BaseModel):
    """Record returned by the Pinecone vector store.

    Attributes:
        id: Memory id.
        score: Distance.
        payload: Metadata.
    """

    id: Optional[str]
    score: Optional[float]
    payload: Optional[Dict]
|
||||
|
||||
|
||||
class PineconeDB(VectorStoreBase):
    """Vector store implementation backed by a Pinecone index."""

    def __init__(
        self,
        collection_name: str,
        embedding_model_dims: int,
        client: Optional["Pinecone"],
        api_key: Optional[str],
        environment: Optional[str],
        serverless_config: Optional[Dict[str, Any]],
        pod_config: Optional[Dict[str, Any]],
        hybrid_search: bool,
        metric: str,
        batch_size: int,
        extra_params: Optional[Dict[str, Any]],
    ):
        """
        Initialize the Pinecone vector store.

        Builds (or reuses) a Pinecone client, optionally sets up a BM25 sparse
        encoder for hybrid search, and makes sure the target index exists.

        Args:
            collection_name (str): Name of the index/collection.
            embedding_model_dims (int): Dimensions of the embedding model.
            client (Pinecone, optional): Existing Pinecone client instance. When given,
                `api_key` and `extra_params` are not used.
            api_key (str, optional): API key for Pinecone. Falls back to the
                `PINECONE_API_KEY` environment variable.
            environment (str, optional): Pinecone environment. Stored on the instance only.
            serverless_config (Dict, optional): Keyword arguments for `ServerlessSpec`.
            pod_config (Dict, optional): Keyword arguments for `PodSpec`.
            hybrid_search (bool): Whether to enable hybrid search.
            metric (str): Distance metric for vector similarity.
            batch_size (int): Number of records sent per upsert call.
            extra_params (Dict, optional): Additional keyword arguments for the Pinecone client.

        Raises:
            ValueError: If no client is given and no API key can be found.
        """
        if client:
            self.client = client
        else:
            api_key = api_key or os.environ.get("PINECONE_API_KEY")
            if not api_key:
                raise ValueError(
                    "Pinecone API key must be provided either as a parameter or as an environment variable"
                )

            params = extra_params or {}
            self.client = Pinecone(api_key=api_key, **params)

        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.environment = environment
        self.serverless_config = serverless_config
        self.pod_config = pod_config
        self.hybrid_search = hybrid_search
        self.metric = metric
        self.batch_size = batch_size

        self.sparse_encoder = None
        if self.hybrid_search:
            try:
                from pinecone_text.sparse import BM25Encoder

                logger.info("Initializing BM25Encoder for sparse vectors...")
                self.sparse_encoder = BM25Encoder.default()
            except ImportError:
                # Degrade to dense-only search rather than failing construction.
                logger.warning("pinecone-text not installed. Hybrid search will be disabled.")
                self.hybrid_search = False

        self.create_col(embedding_model_dims, metric)

    def create_col(self, vector_size: int, metric: str = "cosine"):
        """
        Create a new index/collection, or attach to it if it already exists.

        The deployment spec is taken from `serverless_config`, then `pod_config`;
        when neither is set a serverless index on AWS `us-west-2` is requested.

        Args:
            vector_size (int): Size of the vectors to be stored.
            metric (str, optional): Distance metric for vector similarity. Defaults to "cosine".
        """
        existing_indexes = self.list_cols().names()

        if self.collection_name in existing_indexes:
            logger.debug(f"Index {self.collection_name} already exists. Skipping creation.")
            self.index = self.client.Index(self.collection_name)
            return

        if self.serverless_config:
            spec = ServerlessSpec(**self.serverless_config)
        elif self.pod_config:
            spec = PodSpec(**self.pod_config)
        else:
            spec = ServerlessSpec(cloud="aws", region="us-west-2")

        self.client.create_index(
            name=self.collection_name,
            dimension=vector_size,
            metric=metric,
            spec=spec,
        )

        self.index = self.client.Index(self.collection_name)

    def insert(
        self,
        vectors: List[List[float]],
        payloads: Optional[List[Dict]] = None,
        ids: Optional[List[Union[str, int]]] = None,
    ):
        """
        Insert vectors into an index, upserting in batches of `batch_size`.

        Args:
            vectors (list): List of vectors to insert.
            payloads (list, optional): List of payloads corresponding to vectors. Defaults to None.
            ids (list, optional): List of IDs corresponding to vectors. Defaults to None.
                When omitted, the position of each vector in `vectors` is used as its id.
        """
        logger.info(f"Inserting {len(vectors)} vectors into index {self.collection_name}")
        items = []

        for idx, vector in enumerate(vectors):
            # NOTE(review): positional fallback ids repeat across calls, so two
            # id-less inserts presumably overwrite each other — confirm callers
            # always pass `ids`.
            item_id = str(ids[idx]) if ids is not None else str(idx)
            # A missing or None payload entry is stored as empty metadata.
            payload = (payloads[idx] if payloads else None) or {}

            vector_record = {"id": item_id, "values": vector, "metadata": payload}

            if self.hybrid_search and self.sparse_encoder and "text" in payload:
                sparse_vector = self.sparse_encoder.encode_documents(payload["text"])
                vector_record["sparse_values"] = sparse_vector

            items.append(vector_record)

            if len(items) >= self.batch_size:
                self.index.upsert(vectors=items)
                items = []

        # Flush the final, partially filled batch.
        if items:
            self.index.upsert(vectors=items)

    def _parse_output(self, data: Any) -> Union[OutputData, List[OutputData]]:
        """
        Parse the output data from Pinecone fetch/query results.

        Args:
            data: Either a single `Vector` (as returned by fetch) or an iterable
                of matches exposing `.get("id" | "score" | "metadata")`.

        Returns:
            OutputData for a single `Vector` (score fixed to 0.0, since a fetch
            carries no score), otherwise a list of OutputData, one per match.
        """
        if isinstance(data, Vector):
            return OutputData(
                id=data.id,
                score=0.0,
                payload=data.metadata,
            )

        return [
            OutputData(
                id=match.get("id"),
                score=match.get("score"),
                payload=match.get("metadata"),
            )
            for match in data
        ]

    def _create_filter(self, filters: Optional[Dict]) -> Dict:
        """
        Create a Pinecone filter dictionary from the provided filters.

        A dict value carrying `gte` and/or `lte` becomes a range condition
        (`$gte` / `$lte`); every other value becomes an `$eq` condition.

        Args:
            filters (dict, optional): Mapping of metadata key to value or range.

        Returns:
            dict: Pinecone filter; empty when `filters` is empty or None.
        """
        if not filters:
            return {}

        pinecone_filter = {}

        for key, value in filters.items():
            if isinstance(value, dict) and ("gte" in value or "lte" in value):
                # One-sided ranges are supported too; only the bounds that are
                # present are translated.
                pinecone_filter[key] = {f"${bound}": value[bound] for bound in ("gte", "lte") if bound in value}
            else:
                pinecone_filter[key] = {"$eq": value}

        return pinecone_filter

    def search(self, query: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
        """
        Search for similar vectors.

        Args:
            query (list): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (dict, optional): Filters to apply to the search. Defaults to None.
                With hybrid search enabled, a `text` entry is also encoded as the
                sparse query vector.

        Returns:
            list: Search results.
        """
        filter_dict = self._create_filter(filters) if filters else None

        query_params = {
            "vector": query,
            "top_k": limit,
            "include_metadata": True,
            "include_values": False,
        }

        if filter_dict:
            query_params["filter"] = filter_dict

        # `filters` may be None, so check it before looking for the query text.
        # NOTE(review): the `text` entry also stays in the metadata filter built
        # above — confirm that is intended.
        if self.hybrid_search and self.sparse_encoder and filters and "text" in filters:
            query_text = filters.get("text")
            if query_text:
                sparse_vector = self.sparse_encoder.encode_queries(query_text)
                query_params["sparse_vector"] = sparse_vector

        response = self.index.query(**query_params)

        return self._parse_output(response.matches)

    def delete(self, vector_id: Union[str, int]):
        """
        Delete a vector by ID.

        Args:
            vector_id (Union[str, int]): ID of the vector to delete.
        """
        self.index.delete(ids=[str(vector_id)])

    def update(self, vector_id: Union[str, int], vector: Optional[List[float]] = None, payload: Optional[Dict] = None):
        """
        Update a vector and its payload via an upsert.

        Args:
            vector_id (Union[str, int]): ID of the vector to update.
            vector (list, optional): Updated vector. Defaults to None.
            payload (dict, optional): Updated payload. Defaults to None.
        """
        # NOTE(review): when `vector` is None the record is upserted without
        # values — verify that Pinecone accepts this.
        item = {
            "id": str(vector_id),
        }

        if vector is not None:
            item["values"] = vector

        if payload is not None:
            item["metadata"] = payload

            # Sparse values can only be derived from the payload text, so this
            # stays inside the payload guard (payload may be None).
            if self.hybrid_search and self.sparse_encoder and "text" in payload:
                sparse_vector = self.sparse_encoder.encode_documents(payload["text"])
                item["sparse_values"] = sparse_vector

        self.index.upsert(vectors=[item])

    def get(self, vector_id: Union[str, int]) -> Optional[OutputData]:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (Union[str, int]): ID of the vector to retrieve.

        Returns:
            OutputData: Retrieved vector, or None if it is not found or the
            fetch fails (the error is logged).
        """
        try:
            response = self.index.fetch(ids=[str(vector_id)])
            if str(vector_id) in response.vectors:
                return self._parse_output(response.vectors[str(vector_id)])
            return None
        except Exception as e:
            logger.error(f"Error retrieving vector {vector_id}: {e}")
            return None

    def list_cols(self):
        """
        List all indexes/collections.

        Returns:
            The result of the client's `list_indexes()` call.
        """
        return self.client.list_indexes()

    def delete_col(self):
        """Delete an index/collection; failures are logged, not raised."""
        try:
            self.client.delete_index(self.collection_name)
            logger.info(f"Index {self.collection_name} deleted successfully")
        except Exception as e:
            logger.error(f"Error deleting index {self.collection_name}: {e}")

    def col_info(self) -> Dict:
        """
        Get information about an index/collection.

        Returns:
            The result of the client's `describe_index()` call for this index.
        """
        return self.client.describe_index(self.collection_name)

    def list(self, filters: Optional[Dict] = None, limit: int = 100) -> List[List[OutputData]]:
        """
        List vectors in an index with optional filtering.

        Implemented as a query with an all-zero vector of the index dimension.

        Args:
            filters (dict, optional): Filters to apply to the list. Defaults to None.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            list: A one-element list wrapping the list of parsed vectors. If the
            query fails, the error is logged and `[[]]` is returned so the
            result has the same shape on both paths.
        """
        filter_dict = self._create_filter(filters) if filters else None

        stats = self.index.describe_index_stats()
        dimension = stats.dimension

        zero_vector = [0.0] * dimension

        query_params = {
            "vector": zero_vector,
            "top_k": limit,
            "include_metadata": True,
            "include_values": True,
        }

        if filter_dict:
            query_params["filter"] = filter_dict

        try:
            response = self.index.query(**query_params)
            response = response.to_dict()
            results = self._parse_output(response["matches"])
            return [results]
        except Exception as e:
            logger.error(f"Error listing vectors: {e}")
            return [[]]

    def count(self) -> int:
        """
        Count number of vectors in the index.

        Returns:
            int: Total number of vectors, as reported by the index stats.
        """
        stats = self.index.describe_index_stats()
        return stats.total_vector_count

    def reset(self):
        """
        Reset the index by deleting and recreating it.
        """
        self.delete_col()
        self.create_col(self.embedding_model_dims, self.metric)
|
||||
120
tests/vector_stores/test_pinecone.py
Normal file
120
tests/vector_stores/test_pinecone.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from mem0.vector_stores.pinecone import PineconeDB
|
||||
|
||||
|
||||
@pytest.fixture
def mock_pinecone_client():
    """Mocked Pinecone client that reports no existing indexes."""
    fake_client = MagicMock()
    fake_client.Index.return_value = MagicMock()
    fake_client.list_indexes.return_value.names.return_value = []
    return fake_client
|
||||
|
||||
@pytest.fixture
def pinecone_db(mock_pinecone_client):
    """PineconeDB instance wired to the mocked client."""
    init_kwargs = dict(
        collection_name="test_index",
        embedding_model_dims=128,
        client=mock_pinecone_client,
        api_key="fake_api_key",
        environment="us-west1-gcp",
        serverless_config=None,
        pod_config=None,
        hybrid_search=False,
        metric="cosine",
        batch_size=100,
        extra_params=None,
    )
    return PineconeDB(**init_kwargs)
|
||||
|
||||
def test_create_col_existing_index(mock_pinecone_client):
    """create_col must not create an index that is already listed."""
    # The index has to be reported as existing before the constructor runs.
    mock_pinecone_client.list_indexes.return_value.names.return_value = ["test_index"]

    init_kwargs = dict(
        collection_name="test_index",
        embedding_model_dims=128,
        client=mock_pinecone_client,
        api_key="fake_api_key",
        environment="us-west1-gcp",
        serverless_config=None,
        pod_config=None,
        hybrid_search=False,
        metric="cosine",
        batch_size=100,
        extra_params=None,
    )
    store = PineconeDB(**init_kwargs)

    # Clear recorded calls so only the explicit create_col below is checked.
    mock_pinecone_client.create_index.reset_mock()

    store.create_col(128, "cosine")

    mock_pinecone_client.create_index.assert_not_called()
|
||||
|
||||
def test_create_col_new_index(pinecone_db, mock_pinecone_client):
    """create_col creates the index when it is not listed yet."""
    mock_pinecone_client.list_indexes.return_value.names.return_value = []
    # The fixture's constructor already created the index once; clear that call
    # so the assertions below reflect only the explicit create_col invocation.
    mock_pinecone_client.create_index.reset_mock()

    pinecone_db.create_col(128, "cosine")

    mock_pinecone_client.create_index.assert_called_once()
    create_kwargs = mock_pinecone_client.create_index.call_args.kwargs
    assert create_kwargs["name"] == "test_index"
    assert create_kwargs["dimension"] == 128
    assert create_kwargs["metric"] == "cosine"
|
||||
|
||||
def test_insert_vectors(pinecone_db):
    """insert upserts one record per vector with its id and metadata."""
    vectors = [[0.1] * 128, [0.2] * 128]
    payloads = [{"name": "vector1"}, {"name": "vector2"}]
    ids = ["id1", "id2"]

    pinecone_db.insert(vectors, payloads, ids)

    # Two vectors are below the batch size of 100, so a single upsert is expected.
    pinecone_db.index.upsert.assert_called_once_with(
        vectors=[
            {"id": "id1", "values": [0.1] * 128, "metadata": {"name": "vector1"}},
            {"id": "id2", "values": [0.2] * 128, "metadata": {"name": "vector2"}},
        ]
    )
|
||||
|
||||
def test_search_vectors(pinecone_db):
    """search converts the query matches into OutputData entries."""
    fake_matches = [{"id": "id1", "score": 0.9, "metadata": {"name": "vector1"}}]
    pinecone_db.index.query.return_value.matches = fake_matches

    found = pinecone_db.search([0.1] * 128, limit=1)

    assert len(found) == 1
    first = found[0]
    assert first.id == "id1"
    assert first.score == 0.9
|
||||
|
||||
def test_update_vector(pinecone_db):
    """update upserts a single record carrying the new values and metadata."""
    pinecone_db.update("id1", vector=[0.5] * 128, payload={"name": "updated"})

    pinecone_db.index.upsert.assert_called_once_with(
        vectors=[{"id": "id1", "values": [0.5] * 128, "metadata": {"name": "updated"}}]
    )
|
||||
|
||||
def test_get_vector_found(pinecone_db):
    """get returns the parsed record when fetch contains the requested id."""
    # _parse_output treats a Vector instance as a single record, so the fetch
    # response has to hold a real Vector rather than a plain dict.
    from pinecone.data.dataclasses.vector import Vector

    stored_vector = Vector(id="id1", values=[0.1] * 128, metadata={"name": "vector1"})

    fetch_response = MagicMock()
    fetch_response.vectors = {"id1": stored_vector}
    pinecone_db.index.fetch.return_value = fetch_response

    fetched = pinecone_db.get("id1")

    assert fetched is not None
    assert fetched.id == "id1"
    assert fetched.payload == {"name": "vector1"}
|
||||
|
||||
def test_delete_vector(pinecone_db):
    """delete forwards the id, as a one-element list, to the index."""
    target_id = "id1"
    pinecone_db.delete(target_id)
    pinecone_db.index.delete.assert_called_with(ids=["id1"])
|
||||
|
||||
def test_get_vector_not_found(pinecone_db):
    """get returns None when fetch yields no vector for the id."""
    pinecone_db.index.fetch.return_value.vectors = {}

    missing = pinecone_db.get("id1")

    assert missing is None
|
||||
|
||||
def test_list_cols(pinecone_db):
    """list_cols delegates to the client's list_indexes."""
    pinecone_db.list_cols()

    list_indexes_mock = pinecone_db.client.list_indexes
    list_indexes_mock.assert_called()
|
||||
|
||||
def test_delete_col(pinecone_db):
    """delete_col asks the client to delete the configured index."""
    pinecone_db.delete_col()

    delete_index_mock = pinecone_db.client.delete_index
    delete_index_mock.assert_called_with("test_index")
|
||||
|
||||
def test_col_info(pinecone_db):
    """col_info asks the client to describe the configured index."""
    pinecone_db.col_info()

    describe_index_mock = pinecone_db.client.describe_index
    describe_index_mock.assert_called_with("test_index")
|
||||
Reference in New Issue
Block a user