[Improvement] update pinecone client v3 (#1200)
Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
@@ -1,139 +1,225 @@
|
||||
from unittest import mock
|
||||
from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig
|
||||
from embedchain.config.vectordb.pinecone import PineconeDBConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.vectordb.pinecone import PineconeDB
|
||||
|
||||
|
||||
class TestPinecone:
|
||||
@patch("embedchain.vectordb.pinecone.pinecone")
|
||||
def test_init(self, pinecone_mock):
|
||||
"""Test that the PineconeDB can be initialized."""
|
||||
# Create a PineconeDB instance
|
||||
PineconeDB()
|
||||
@pytest.fixture
def pinecone_pod_config():
    """Provide a ``PineconeDBConfig`` that carries a pod configuration.

    The config uses a 3-dimensional vector space and fixed placeholder
    credentials, so the tests that request it get the same values every time.
    """
    pod_settings = {"environment": "test_environment", "metadata_config": {"indexed": ["*"]}}
    config = PineconeDBConfig(
        collection_name="test_collection",
        api_key="test_api_key",
        vector_dimension=3,
        pod_config=pod_settings,
    )
    return config
|
||||
|
||||
# Assert that the Pinecone client was initialized
|
||||
pinecone_mock.init.assert_called_once()
|
||||
pinecone_mock.list_indexes.assert_called_once()
|
||||
pinecone_mock.Index.assert_called_once()
|
||||
|
||||
@patch("embedchain.vectordb.pinecone.pinecone")
|
||||
def test_set_embedder(self, pinecone_mock):
|
||||
"""Test that the embedder can be set."""
|
||||
@pytest.fixture
def pinecone_serverless_config():
    """Provide a ``PineconeDBConfig`` that carries a serverless configuration.

    Mirrors the pod fixture (same collection name, API key and dimension) but
    passes ``serverless_config`` instead of ``pod_config``.
    """
    serverless_settings = {
        "cloud": "test_cloud",
        "region": "test_region",
    }
    config = PineconeDBConfig(
        collection_name="test_collection",
        api_key="test_api_key",
        vector_dimension=3,
        serverless_config=serverless_settings,
    )
    return config
|
||||
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
|
||||
# Create a PineconeDB instance
|
||||
def test_pinecone_init_without_config(monkeypatch):
    """Construct ``PineconeDB`` with no config and check the config it ends up with.

    Both ``_setup_pinecone_index`` and ``_get_or_create_db`` are replaced by a
    pass-through so only the constructor's own config handling is exercised.
    """

    def passthrough(instance):
        # Identity stub: takes the instance and hands it straight back.
        return instance

    monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
    monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", passthrough)
    monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", passthrough)

    # Pod configuration the test expects when nothing is supplied by the caller.
    expected_pod_config = {"environment": "gcp-starter", "metadata_config": {"indexed": ["*"]}}

    pinecone_db = PineconeDB()
    assert isinstance(pinecone_db, PineconeDB)

    resolved_config = pinecone_db.config
    assert isinstance(resolved_config, PineconeDBConfig)
    assert resolved_config.pod_config == expected_pod_config

    monkeypatch.delenv("PINECONE_API_KEY")
|
||||
|
||||
|
||||
def test_pinecone_init_with_config(pinecone_pod_config, pinecone_serverless_config, monkeypatch):
    """Check that ``PineconeDB`` keeps the config it is given, for both fixtures.

    Args:
        pinecone_pod_config: fixture config carrying ``pod_config``.
        pinecone_serverless_config: fixture config carrying ``serverless_config``.
        monkeypatch: pytest fixture used to stub the index/db setup methods.

    ``_setup_pinecone_index`` and ``_get_or_create_db`` are replaced by identity
    lambdas so the constructor can run without their real bodies.
    """
    monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
    monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)

    # Pod-based configuration.
    pinecone_db = PineconeDB(config=pinecone_pod_config)

    assert isinstance(pinecone_db, PineconeDB)
    assert isinstance(pinecone_db.config, PineconeDBConfig)
    assert pinecone_db.config.pod_config == pinecone_pod_config.pod_config

    # Serverless configuration. This half previously re-used the pod fixture for
    # both the constructor and the expected value, so the serverless path was
    # never actually exercised.
    pinecone_db = PineconeDB(config=pinecone_serverless_config)

    assert isinstance(pinecone_db, PineconeDB)
    assert isinstance(pinecone_db.config, PineconeDBConfig)
    assert pinecone_db.config.serverless_config == pinecone_serverless_config.serverless_config
|
||||
|
||||
|
||||
class MockListIndexes:
    """Test double for the object that ``MockPineconeClient.list_indexes`` returns."""

    def names(self):
        """Return the fixed list of index names this mock reports."""
        index_names = ["test_collection"]
        return index_names
|
||||
|
||||
|
||||
class MockPineconeIndex:
    """In-memory test double for a Pinecone index object.

    ``upsert`` records items in ``self.db``; ``describe_index_stats`` reports how
    many were recorded. ``query`` and ``fetch`` return fixed canned payloads and
    ignore their arguments; ``delete`` is a no-op.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore any constructor arguments; start with empty storage."""
        # Storage is per instance: a class-level list would be shared by every
        # mock created in the test session, so vector counts would depend on
        # which tests ran earlier.
        self.db = []

    def upsert(self, chunk, **kwargs):
        """Append every item of ``chunk`` to the stored items. Returns ``None``."""
        self.db.extend(chunk)

    def delete(self, *args, **kwargs):
        """Do nothing; stored items are left untouched."""

    def query(self, *args, **kwargs):
        """Return a canned two-match result, regardless of the arguments."""
        return {
            "matches": [
                {
                    "metadata": {
                        "key": "value",
                        "text": "text_1",
                    },
                    "score": 0.1,
                },
                {
                    "metadata": {
                        "key": "value",
                        "text": "text_2",
                    },
                    "score": 0.2,
                },
            ]
        }

    def fetch(self, *args, **kwargs):
        """Return a canned two-vector payload, regardless of the arguments."""
        return {
            "vectors": {
                "key_1": {
                    "metadata": {
                        "source": "1",
                    }
                },
                "key_2": {
                    "metadata": {
                        "source": "2",
                    }
                },
            }
        }

    def describe_index_stats(self, *args, **kwargs):
        """Return ``{"total_vector_count": n}`` where ``n`` is the number of upserted items."""
        return {"total_vector_count": len(self.db)}
|
||||
|
||||
|
||||
class MockPineconeClient:
    """Test double for the Pinecone client object.

    Index creation and deletion are no-ops; listing and index lookup hand back
    the other mocks defined in this module.
    """

    def __init__(*args, **kwargs):
        """Accept and ignore any constructor arguments."""
        return None

    def list_indexes(self):
        """Return a fresh ``MockListIndexes``."""
        listing = MockListIndexes()
        return listing

    def create_index(self, *args, **kwargs):
        """Do nothing; arguments are ignored."""
        return None

    def Index(self, *args, **kwargs):
        """Return a fresh ``MockPineconeIndex``; arguments are ignored."""
        index = MockPineconeIndex()
        return index

    def delete_index(self, *args, **kwargs):
        """Do nothing; arguments are ignored."""
        return None
|
||||
|
||||
|
||||
class MockPinecone:
    """Test double that is monkeypatched over ``embedchain.vectordb.pinecone.pinecone``.

    The attributes below are looked up on the class itself, so each one accepts
    arbitrary arguments and ignores them.
    """

    def __init__(*args, **kwargs):
        """Accept and ignore any constructor arguments."""
        return None

    @staticmethod
    def Pinecone(*args, **kwargs):
        """Return a fresh ``MockPineconeClient``; arguments are ignored."""
        client = MockPineconeClient()
        return client

    @staticmethod
    def PodSpec(*args, **kwargs):
        """Return ``None``; arguments are ignored."""
        return None

    @staticmethod
    def ServerlessSpec(*args, **kwargs):
        """Return ``None``; arguments are ignored."""
        return None
|
||||
|
||||
|
||||
class MockEmbedder:
    """Test double for an embedder that maps every document to the same vector."""

    def embedding_fn(self, documents):
        """Return one new ``[1, 1, 1]`` list per item in ``documents``."""
        vectors = []
        for _ in documents:
            # A new list each time, so the returned rows are independent objects.
            vectors.append([1, 1, 1])
        return vectors
|
||||
|
||||
|
||||
def test_setup_pinecone_index(pinecone_pod_config, pinecone_serverless_config, monkeypatch):
    """Run ``_setup_pinecone_index`` against the mocked pinecone module.

    The same four checks are applied first with the pod config and then with
    the serverless config.
    """
    monkeypatch.setattr("embedchain.vectordb.pinecone.pinecone", MockPinecone)
    monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")

    # Pod first, serverless second: the same order as the original stanzas.
    for db_config in (pinecone_pod_config, pinecone_serverless_config):
        pinecone_db = PineconeDB(config=db_config)
        pinecone_db._setup_pinecone_index()

        assert pinecone_db.client is not None
        assert pinecone_db.config.index_name == "test-collection-3"
        assert pinecone_db.client.list_indexes().names() == ["test_collection"]
        assert pinecone_db.pinecone_index is not None
|
||||
|
||||
|
||||
def test_get(monkeypatch):
|
||||
def mock_pinecone_db():
|
||||
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
|
||||
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
|
||||
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
|
||||
db = PineconeDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
db.pinecone_index = MockPineconeIndex()
|
||||
return db
|
||||
|
||||
# Assert that the embedder was set
|
||||
assert db.embedder == embedder
|
||||
pinecone_mock.init.assert_called_once()
|
||||
pinecone_db = mock_pinecone_db()
|
||||
ids = pinecone_db.get(["key_1", "key_2"])
|
||||
assert ids == {"ids": ["key_1", "key_2"], "metadatas": [{"source": "1"}, {"source": "2"}]}
|
||||
|
||||
@patch("embedchain.vectordb.pinecone.pinecone")
|
||||
def test_add_documents(self, pinecone_mock):
|
||||
"""Test that documents can be added to the database."""
|
||||
pinecone_client_mock = pinecone_mock.Index.return_value
|
||||
|
||||
embedding_function = mock.Mock()
|
||||
base_embedder = BaseEmbedder()
|
||||
base_embedder.set_embedding_fn(embedding_function)
|
||||
embedding_function.return_value = [[0, 0, 0], [1, 1, 1]]
|
||||
|
||||
# Create a PineconeDb instance
|
||||
def test_add(monkeypatch):
|
||||
def mock_pinecone_db():
|
||||
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
|
||||
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
|
||||
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
|
||||
db = PineconeDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedding_model=base_embedder)
|
||||
db.pinecone_index = MockPineconeIndex()
|
||||
db._set_embedder(MockEmbedder())
|
||||
return db
|
||||
|
||||
# Add some documents to the database
|
||||
documents = ["This is a document.", "This is another document."]
|
||||
metadatas = [{}, {}]
|
||||
ids = ["doc1", "doc2"]
|
||||
db.add(documents, metadatas, ids)
|
||||
pinecone_db = mock_pinecone_db()
|
||||
pinecone_db.add(["text_1", "text_2"], [{"key_1": "value_1"}, {"key_2": "value_2"}], ["key_1", "key_2"])
|
||||
assert pinecone_db.count() == 2
|
||||
|
||||
expected_pinecone_upsert_args = [
|
||||
{"id": "doc1", "values": [0, 0, 0], "metadata": {"text": "This is a document."}},
|
||||
{"id": "doc2", "values": [1, 1, 1], "metadata": {"text": "This is another document."}},
|
||||
]
|
||||
# Assert that the Pinecone client was called to upsert the documents
|
||||
pinecone_client_mock.upsert.assert_called_once_with(tuple(expected_pinecone_upsert_args))
|
||||
pinecone_db.add(["text_3", "text_4"], [{"key_3": "value_3"}, {"key_4": "value_4"}], ["key_3", "key_4"])
|
||||
assert pinecone_db.count() == 4
|
||||
|
||||
@patch("embedchain.vectordb.pinecone.pinecone")
|
||||
def test_query_documents(self, pinecone_mock):
|
||||
"""Test that documents can be queried from the database."""
|
||||
pinecone_client_mock = pinecone_mock.Index.return_value
|
||||
|
||||
embedding_function = mock.Mock()
|
||||
base_embedder = BaseEmbedder()
|
||||
base_embedder.set_embedding_fn(embedding_function)
|
||||
vectors = [[0, 0, 0]]
|
||||
embedding_function.return_value = vectors
|
||||
# Create a PineconeDB instance
|
||||
def test_query(monkeypatch):
|
||||
def mock_pinecone_db():
|
||||
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
|
||||
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
|
||||
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
|
||||
db = PineconeDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedding_model=base_embedder)
|
||||
db.pinecone_index = MockPineconeIndex()
|
||||
db._set_embedder(MockEmbedder())
|
||||
return db
|
||||
|
||||
# Query the database for documents that are similar to "document"
|
||||
input_query = ["document"]
|
||||
n_results = 1
|
||||
db.query(input_query, n_results, where={})
|
||||
|
||||
# Assert that the Pinecone client was called to query the database
|
||||
pinecone_client_mock.query.assert_called_once_with(
|
||||
vector=db.embedder.embedding_fn(input_query)[0], top_k=n_results, filter={}, include_metadata=True
|
||||
)
|
||||
|
||||
@patch("embedchain.vectordb.pinecone.pinecone")
|
||||
def test_reset(self, pinecone_mock):
|
||||
"""Test that the database can be reset."""
|
||||
# Create a PineconeDb instance
|
||||
db = PineconeDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedding_model=BaseEmbedder())
|
||||
|
||||
# Reset the database
|
||||
db.reset()
|
||||
|
||||
# Assert that the Pinecone client was called to delete the index
|
||||
pinecone_mock.delete_index.assert_called_once_with(db.config.index_name)
|
||||
|
||||
# Assert that the index is recreated
|
||||
pinecone_mock.Index.assert_called_with(db.config.index_name)
|
||||
|
||||
@patch("embedchain.vectordb.pinecone.pinecone")
|
||||
def test_custom_index_name_if_it_exists(self, pinecone_mock):
|
||||
"""Tests custom index name is used if it exists"""
|
||||
pinecone_mock.list_indexes.return_value = ["custom_index_name"]
|
||||
db_config = PineconeDBConfig(index_name="custom_index_name")
|
||||
_ = PineconeDB(config=db_config)
|
||||
|
||||
pinecone_mock.list_indexes.assert_called_once()
|
||||
pinecone_mock.create_index.assert_not_called()
|
||||
pinecone_mock.Index.assert_called_with("custom_index_name")
|
||||
|
||||
@patch("embedchain.vectordb.pinecone.pinecone")
|
||||
def test_custom_index_name_creation(self, pinecone_mock):
|
||||
"""Test custom index name is created if it doesn't exists already"""
|
||||
pinecone_mock.list_indexes.return_value = []
|
||||
db_config = PineconeDBConfig(index_name="custom_index_name")
|
||||
_ = PineconeDB(config=db_config)
|
||||
|
||||
pinecone_mock.list_indexes.assert_called_once()
|
||||
pinecone_mock.create_index.assert_called_once()
|
||||
pinecone_mock.Index.assert_called_with("custom_index_name")
|
||||
|
||||
@patch("embedchain.vectordb.pinecone.pinecone")
|
||||
def test_default_index_name_is_used(self, pinecone_mock):
|
||||
"""Test default index name is used if custom index name is not provided"""
|
||||
db_config = PineconeDBConfig(collection_name="my-collection")
|
||||
_ = PineconeDB(config=db_config)
|
||||
|
||||
pinecone_mock.list_indexes.assert_called_once()
|
||||
pinecone_mock.create_index.assert_called_once()
|
||||
pinecone_mock.Index.assert_called_with(f"{db_config.collection_name}-{db_config.vector_dimension}")
|
||||
pinecone_db = mock_pinecone_db()
|
||||
# without citations
|
||||
results = pinecone_db.query(["text_1", "text_2"], n_results=2, where={})
|
||||
assert results == ["text_1", "text_2"]
|
||||
# with citations
|
||||
results = pinecone_db.query(["text_1", "text_2"], n_results=2, where={}, citations=True)
|
||||
assert results == [
|
||||
("text_1", {"key": "value", "text": "text_1", "score": 0.1}),
|
||||
("text_2", {"key": "value", "text": "text_2", "score": 0.2}),
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user