121 lines
3.8 KiB
Python
121 lines
3.8 KiB
Python
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
from mem0.vector_stores.pinecone import PineconeDB
|
|
|
|
|
|
@pytest.fixture
def mock_pinecone_client():
    """Provide a MagicMock standing in for the Pinecone client.

    ``Index()`` returns a fresh MagicMock, and ``list_indexes().names()``
    reports no index names. Tests that need an existing index override the
    ``names`` return value themselves.
    """
    fake_client = MagicMock()
    fake_client.Index.return_value = MagicMock()
    # Default: the account reports no indexes.
    fake_client.list_indexes.return_value.names.return_value = []
    return fake_client
|
|
|
|
@pytest.fixture
def pinecone_db(mock_pinecone_client):
    """Build a PineconeDB bound to the mocked client.

    Uses a 128-dimensional "test_index" collection with the cosine metric,
    hybrid search disabled, and no serverless/pod configuration.
    """
    settings = {
        "collection_name": "test_index",
        "embedding_model_dims": 128,
        "api_key": "fake_api_key",
        "environment": "us-west1-gcp",
        "serverless_config": None,
        "pod_config": None,
        "hybrid_search": False,
        "metric": "cosine",
        "batch_size": 100,
        "extra_params": None,
    }
    return PineconeDB(client=mock_pinecone_client, **settings)
|
|
|
|
def test_create_col_existing_index(mock_pinecone_client):
    """create_col must not call create_index when the index already exists."""
    # Report the index as existing *before* PineconeDB is constructed.
    mock_pinecone_client.list_indexes.return_value.names.return_value = ["test_index"]

    settings = {
        "collection_name": "test_index",
        "embedding_model_dims": 128,
        "api_key": "fake_api_key",
        "environment": "us-west1-gcp",
        "serverless_config": None,
        "pod_config": None,
        "hybrid_search": False,
        "metric": "cosine",
        "batch_size": 100,
        "extra_params": None,
    }
    db = PineconeDB(client=mock_pinecone_client, **settings)

    # Discard any calls recorded during construction, so the assertion below
    # concerns create_col alone.
    mock_pinecone_client.create_index.reset_mock()

    db.create_col(128, "cosine")

    mock_pinecone_client.create_index.assert_not_called()
|
|
|
|
def test_create_col_new_index(pinecone_db, mock_pinecone_client):
    """create_col must call create_index when the target index does not exist."""
    mock_pinecone_client.list_indexes.return_value.names.return_value = []
    # The pinecone_db fixture was built on this same client while it reported
    # no indexes. create_index may therefore already have been called during
    # construction. Reset it so the assertion below is attributable to
    # create_col and cannot pass vacuously.
    mock_pinecone_client.create_index.reset_mock()

    pinecone_db.create_col(128, "cosine")

    mock_pinecone_client.create_index.assert_called_once()
|
|
|
|
def test_insert_vectors(pinecone_db):
    """insert() results in an upsert call on the underlying index."""
    pinecone_db.insert(
        [[0.1] * 128, [0.2] * 128],
        [{"name": "vector1"}, {"name": "vector2"}],
        ["id1", "id2"],
    )
    pinecone_db.index.upsert.assert_called()
|
|
|
|
def test_search_vectors(pinecone_db):
    """search() turns each query match into a result exposing id and score."""
    hit = {"id": "id1", "score": 0.9, "metadata": {"name": "vector1"}}
    pinecone_db.index.query.return_value.matches = [hit]

    results = pinecone_db.search("test query", [0.1] * 128, limit=1)

    assert len(results) == 1
    top = results[0]
    assert top.id == "id1"
    assert top.score == 0.9
|
|
|
|
def test_update_vector(pinecone_db):
    """update() with a new vector and payload results in an upsert call."""
    new_vector = [0.5] * 128
    new_payload = {"name": "updated"}
    pinecone_db.update("id1", vector=new_vector, payload=new_payload)
    pinecone_db.index.upsert.assert_called()
|
|
|
|
def test_get_vector_found(pinecone_db):
    """get() returns a result carrying the fetched vector's id and metadata."""
    # NOTE(review): PineconeDB's output parsing appears to expect a real
    # pinecone Vector (or a list of dicts), not a plain dict with an "id"
    # key, so build a genuine Vector here. Imported locally, as in the
    # original test.
    from pinecone.data.dataclasses.vector import Vector

    stored = Vector(id="id1", values=[0.1] * 128, metadata={"name": "vector1"})

    # fetch() returns an object whose ``vectors`` maps id -> Vector.
    fetch_response = MagicMock()
    fetch_response.vectors = {"id1": stored}
    pinecone_db.index.fetch.return_value = fetch_response

    found = pinecone_db.get("id1")

    assert found is not None
    assert found.id == "id1"
    assert found.payload == {"name": "vector1"}
|
|
|
|
def test_delete_vector(pinecone_db):
    """delete() passes the id to the index as a one-element ``ids`` list."""
    target_id = "id1"
    pinecone_db.delete(target_id)
    pinecone_db.index.delete.assert_called_with(ids=[target_id])
|
|
|
|
def test_get_vector_not_found(pinecone_db):
    """get() returns None when fetch() yields no vectors."""
    fetch_result = pinecone_db.index.fetch.return_value
    fetch_result.vectors = {}

    assert pinecone_db.get("id1") is None
|
|
|
|
def test_list_cols(pinecone_db):
    """list_cols() must query the client for its indexes."""
    # Building the fixture may already call list_indexes (this module
    # configures list_indexes().names() to decide whether the index exists),
    # which would make a bare assert_called() pass vacuously. Reset it so
    # the check below reflects list_cols only.
    pinecone_db.client.list_indexes.reset_mock()

    pinecone_db.list_cols()

    pinecone_db.client.list_indexes.assert_called_once()
|
|
|
|
def test_delete_col(pinecone_db):
    """delete_col() deletes the index named by the fixture's collection."""
    pinecone_db.delete_col()
    client = pinecone_db.client
    client.delete_index.assert_called_with("test_index")
|
|
|
|
def test_col_info(pinecone_db):
    """col_info() describes the index named by the fixture's collection."""
    pinecone_db.col_info()
    client = pinecone_db.client
    client.describe_index.assert_called_with("test_index")
|