[Improvement] update pinecone client v3 (#1200)

Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
Deven Patel
2024-01-26 09:08:37 +05:30
committed by GitHub
parent d2a5b50ff8
commit e75c05112e
6 changed files with 290 additions and 209 deletions

View File

@@ -1,139 +1,225 @@
from unittest import mock
from unittest.mock import patch
import pytest
from embedchain import App
from embedchain.config import AppConfig
from embedchain.config.vectordb.pinecone import PineconeDBConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.vectordb.pinecone import PineconeDB
class TestPinecone:
@patch("embedchain.vectordb.pinecone.pinecone")
def test_init(self, pinecone_mock):
"""Test that the PineconeDB can be initialized."""
# Create a PineconeDB instance
PineconeDB()
@pytest.fixture
def pinecone_pod_config():
return PineconeDBConfig(
collection_name="test_collection",
api_key="test_api_key",
vector_dimension=3,
pod_config={"environment": "test_environment", "metadata_config": {"indexed": ["*"]}},
)
# Assert that the Pinecone client was initialized
pinecone_mock.init.assert_called_once()
pinecone_mock.list_indexes.assert_called_once()
pinecone_mock.Index.assert_called_once()
@patch("embedchain.vectordb.pinecone.pinecone")
def test_set_embedder(self, pinecone_mock):
"""Test that the embedder can be set."""
@pytest.fixture
def pinecone_serverless_config():
return PineconeDBConfig(
collection_name="test_collection",
api_key="test_api_key",
vector_dimension=3,
serverless_config={
"cloud": "test_cloud",
"region": "test_region",
},
)
# Set the embedder
embedder = BaseEmbedder()
# Create a PineconeDB instance
def test_pinecone_init_without_config(monkeypatch):
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
pinecone_db = PineconeDB()
assert isinstance(pinecone_db, PineconeDB)
assert isinstance(pinecone_db.config, PineconeDBConfig)
assert pinecone_db.config.pod_config == {"environment": "gcp-starter", "metadata_config": {"indexed": ["*"]}}
monkeypatch.delenv("PINECONE_API_KEY")
def test_pinecone_init_with_config(pinecone_pod_config, pinecone_serverless_config, monkeypatch):
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
pinecone_db = PineconeDB(config=pinecone_pod_config)
assert isinstance(pinecone_db, PineconeDB)
assert isinstance(pinecone_db.config, PineconeDBConfig)
assert pinecone_db.config.pod_config == pinecone_pod_config.pod_config
pinecone_db = PineconeDB(config=pinecone_pod_config)
assert isinstance(pinecone_db, PineconeDB)
assert isinstance(pinecone_db.config, PineconeDBConfig)
assert pinecone_db.config.serverless_config == pinecone_pod_config.serverless_config
class MockListIndexes:
def names(self):
return ["test_collection"]
class MockPineconeIndex:
db = []
def __init__(*args, **kwargs):
pass
def upsert(self, chunk, **kwargs):
self.db.extend([c for c in chunk])
return
def delete(self, *args, **kwargs):
pass
def query(self, *args, **kwargs):
return {
"matches": [
{
"metadata": {
"key": "value",
"text": "text_1",
},
"score": 0.1,
},
{
"metadata": {
"key": "value",
"text": "text_2",
},
"score": 0.2,
},
]
}
def fetch(self, *args, **kwargs):
return {
"vectors": {
"key_1": {
"metadata": {
"source": "1",
}
},
"key_2": {
"metadata": {
"source": "2",
}
},
}
}
def describe_index_stats(self, *args, **kwargs):
return {"total_vector_count": len(self.db)}
class MockPineconeClient:
def __init__(*args, **kwargs):
pass
def list_indexes(self):
return MockListIndexes()
def create_index(self, *args, **kwargs):
pass
def Index(self, *args, **kwargs):
return MockPineconeIndex()
def delete_index(self, *args, **kwargs):
pass
class MockPinecone:
def __init__(*args, **kwargs):
pass
def Pinecone(*args, **kwargs):
return MockPineconeClient()
def PodSpec(*args, **kwargs):
pass
def ServerlessSpec(*args, **kwargs):
pass
class MockEmbedder:
def embedding_fn(self, documents):
return [[1, 1, 1] for d in documents]
def test_setup_pinecone_index(pinecone_pod_config, pinecone_serverless_config, monkeypatch):
monkeypatch.setattr("embedchain.vectordb.pinecone.pinecone", MockPinecone)
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
pinecone_db = PineconeDB(config=pinecone_pod_config)
pinecone_db._setup_pinecone_index()
assert pinecone_db.client is not None
assert pinecone_db.config.index_name == "test-collection-3"
assert pinecone_db.client.list_indexes().names() == ["test_collection"]
assert pinecone_db.pinecone_index is not None
pinecone_db = PineconeDB(config=pinecone_serverless_config)
pinecone_db._setup_pinecone_index()
assert pinecone_db.client is not None
assert pinecone_db.config.index_name == "test-collection-3"
assert pinecone_db.client.list_indexes().names() == ["test_collection"]
assert pinecone_db.pinecone_index is not None
def test_get(monkeypatch):
def mock_pinecone_db():
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
db = PineconeDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedding_model=embedder)
db.pinecone_index = MockPineconeIndex()
return db
# Assert that the embedder was set
assert db.embedder == embedder
pinecone_mock.init.assert_called_once()
pinecone_db = mock_pinecone_db()
ids = pinecone_db.get(["key_1", "key_2"])
assert ids == {"ids": ["key_1", "key_2"], "metadatas": [{"source": "1"}, {"source": "2"}]}
@patch("embedchain.vectordb.pinecone.pinecone")
def test_add_documents(self, pinecone_mock):
"""Test that documents can be added to the database."""
pinecone_client_mock = pinecone_mock.Index.return_value
embedding_function = mock.Mock()
base_embedder = BaseEmbedder()
base_embedder.set_embedding_fn(embedding_function)
embedding_function.return_value = [[0, 0, 0], [1, 1, 1]]
# Create a PineconeDb instance
def test_add(monkeypatch):
def mock_pinecone_db():
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
db = PineconeDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedding_model=base_embedder)
db.pinecone_index = MockPineconeIndex()
db._set_embedder(MockEmbedder())
return db
# Add some documents to the database
documents = ["This is a document.", "This is another document."]
metadatas = [{}, {}]
ids = ["doc1", "doc2"]
db.add(documents, metadatas, ids)
pinecone_db = mock_pinecone_db()
pinecone_db.add(["text_1", "text_2"], [{"key_1": "value_1"}, {"key_2": "value_2"}], ["key_1", "key_2"])
assert pinecone_db.count() == 2
expected_pinecone_upsert_args = [
{"id": "doc1", "values": [0, 0, 0], "metadata": {"text": "This is a document."}},
{"id": "doc2", "values": [1, 1, 1], "metadata": {"text": "This is another document."}},
]
# Assert that the Pinecone client was called to upsert the documents
pinecone_client_mock.upsert.assert_called_once_with(tuple(expected_pinecone_upsert_args))
pinecone_db.add(["text_3", "text_4"], [{"key_3": "value_3"}, {"key_4": "value_4"}], ["key_3", "key_4"])
assert pinecone_db.count() == 4
@patch("embedchain.vectordb.pinecone.pinecone")
def test_query_documents(self, pinecone_mock):
"""Test that documents can be queried from the database."""
pinecone_client_mock = pinecone_mock.Index.return_value
embedding_function = mock.Mock()
base_embedder = BaseEmbedder()
base_embedder.set_embedding_fn(embedding_function)
vectors = [[0, 0, 0]]
embedding_function.return_value = vectors
# Create a PineconeDB instance
def test_query(monkeypatch):
def mock_pinecone_db():
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
db = PineconeDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedding_model=base_embedder)
db.pinecone_index = MockPineconeIndex()
db._set_embedder(MockEmbedder())
return db
# Query the database for documents that are similar to "document"
input_query = ["document"]
n_results = 1
db.query(input_query, n_results, where={})
# Assert that the Pinecone client was called to query the database
pinecone_client_mock.query.assert_called_once_with(
vector=db.embedder.embedding_fn(input_query)[0], top_k=n_results, filter={}, include_metadata=True
)
@patch("embedchain.vectordb.pinecone.pinecone")
def test_reset(self, pinecone_mock):
"""Test that the database can be reset."""
# Create a PineconeDb instance
db = PineconeDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedding_model=BaseEmbedder())
# Reset the database
db.reset()
# Assert that the Pinecone client was called to delete the index
pinecone_mock.delete_index.assert_called_once_with(db.config.index_name)
# Assert that the index is recreated
pinecone_mock.Index.assert_called_with(db.config.index_name)
@patch("embedchain.vectordb.pinecone.pinecone")
def test_custom_index_name_if_it_exists(self, pinecone_mock):
"""Tests custom index name is used if it exists"""
pinecone_mock.list_indexes.return_value = ["custom_index_name"]
db_config = PineconeDBConfig(index_name="custom_index_name")
_ = PineconeDB(config=db_config)
pinecone_mock.list_indexes.assert_called_once()
pinecone_mock.create_index.assert_not_called()
pinecone_mock.Index.assert_called_with("custom_index_name")
@patch("embedchain.vectordb.pinecone.pinecone")
def test_custom_index_name_creation(self, pinecone_mock):
"""Test custom index name is created if it doesn't exists already"""
pinecone_mock.list_indexes.return_value = []
db_config = PineconeDBConfig(index_name="custom_index_name")
_ = PineconeDB(config=db_config)
pinecone_mock.list_indexes.assert_called_once()
pinecone_mock.create_index.assert_called_once()
pinecone_mock.Index.assert_called_with("custom_index_name")
@patch("embedchain.vectordb.pinecone.pinecone")
def test_default_index_name_is_used(self, pinecone_mock):
"""Test default index name is used if custom index name is not provided"""
db_config = PineconeDBConfig(collection_name="my-collection")
_ = PineconeDB(config=db_config)
pinecone_mock.list_indexes.assert_called_once()
pinecone_mock.create_index.assert_called_once()
pinecone_mock.Index.assert_called_with(f"{db_config.collection_name}-{db_config.vector_dimension}")
pinecone_db = mock_pinecone_db()
# without citations
results = pinecone_db.query(["text_1", "text_2"], n_results=2, where={})
assert results == ["text_1", "text_2"]
# with citations
results = pinecone_db.query(["text_1", "text_2"], n_results=2, where={}, citations=True)
assert results == [
("text_1", {"key": "value", "text": "text_1", "score": 0.1}),
("text_2", {"key": "value", "text": "text_2", "score": 0.2}),
]