[Improvement] update pinecone client v3 (#1200)

Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
Deven Patel
2024-01-26 09:08:37 +05:30
committed by GitHub
parent d2a5b50ff8
commit e75c05112e
6 changed files with 290 additions and 209 deletions

View File

@@ -167,7 +167,7 @@ Install pinecone related dependencies using the following command:
pip install --upgrade 'embedchain[pinecone]'
```
In order to use Pinecone as vector database, set the environment variables `PINECONE_API_KEY` and `PINECONE_ENV` which you can find on [Pinecone dashboard](https://app.pinecone.io/).
In order to use Pinecone as vector database, set the environment variable `PINECONE_API_KEY` which you can find on [Pinecone dashboard](https://app.pinecone.io/).
<CodeGroup>
@@ -175,23 +175,44 @@ In order to use Pinecone as vector database, set the environment variables `PINE
from embedchain import App
# load pinecone configuration from yaml file
app = App.from_config(config_path="config.yaml")
app = App.from_config(config_path="pod_config.yaml")
# or
app = App.from_config(config_path="serverless_config.yaml")
```
```yaml config.yaml
```yaml pod_config.yaml
vectordb:
provider: pinecone
config:
metric: cosine
vector_dimension: 1536
collection_name: my-pinecone-index
pod_config:
environment: gcp-starter
metadata_config:
indexed:
- "url"
- "hash"
```
```yaml serverless_config.yaml
vectordb:
provider: pinecone
config:
metric: cosine
vector_dimension: 1536
collection_name: my-pinecone-index
serverless_config:
cloud: aws
region: us-west-2
```
</CodeGroup>
<br />
<Note>
You can optionally provide `index_name` as a config param in yaml file to specify the index name. If not provided, the index name will be `{collection_name}-{vector_dimension}`.
You can find more information about Pinecone configuration [here](https://docs.pinecone.io/docs/manage-indexes#create-a-pod-based-index).
You can also optionally provide `index_name` as a config param in yaml file to specify the index name. If not provided, the index name will be `{collection_name}-{vector_dimension}`.
</Note>
## Qdrant

View File

@@ -1,3 +1,4 @@
import os
from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig
@@ -9,14 +10,29 @@ class PineconeDBConfig(BaseVectorDbConfig):
def __init__(
self,
collection_name: Optional[str] = None,
api_key: Optional[str] = None,
index_name: Optional[str] = None,
dir: Optional[str] = None,
vector_dimension: int = 1536,
metric: Optional[str] = "cosine",
pod_config: Optional[dict[str, any]] = None,
serverless_config: Optional[dict[str, any]] = None,
**extra_params: dict[str, any],
):
self.metric = metric
self.api_key = api_key
self.vector_dimension = vector_dimension
self.extra_params = extra_params
self.index_name = index_name or f"{collection_name}-{vector_dimension}".lower().replace("_", "-")
super().__init__(collection_name=collection_name, dir=dir)
if pod_config is None and serverless_config is None:
# If no config is provided, use the default pod spec config
pod_environment = os.environ.get("PINECONE_ENV", "gcp-starter")
self.pod_config = {"environment": pod_environment, "metadata_config": {"indexed": ["*"]}}
else:
self.pod_config = pod_config
self.serverless_config = serverless_config
if self.pod_config and self.serverless_config:
raise ValueError("Only one of pod_config or serverless_config can be provided.")
super().__init__(collection_name=collection_name, dir=None)

View File

@@ -42,7 +42,7 @@ class PineconeDB(BaseVectorDB):
"Please make sure the type is right and that you are passing an instance."
)
self.config = config
self.client = self._setup_pinecone_index()
self._setup_pinecone_index()
# Call parent init here because embedder is needed
super().__init__(config=self.config)
@@ -57,17 +57,26 @@ class PineconeDB(BaseVectorDB):
"""
Loads the Pinecone index or creates it if not present.
"""
pinecone.init(
api_key=os.environ.get("PINECONE_API_KEY"),
environment=os.environ.get("PINECONE_ENV"),
**self.config.extra_params,
)
indexes = pinecone.list_indexes()
api_key = self.config.api_key or os.environ.get("PINECONE_API_KEY")
if not api_key:
raise ValueError("Please set the PINECONE_API_KEY environment variable or pass it in config.")
self.client = pinecone.Pinecone(api_key=api_key, **self.config.extra_params)
indexes = self.client.list_indexes().names()
if indexes is None or self.config.index_name not in indexes:
pinecone.create_index(
name=self.config.index_name, metric=self.config.metric, dimension=self.config.vector_dimension
if self.config.pod_config:
spec = pinecone.PodSpec(**self.config.pod_config)
elif self.config.serverless_config:
spec = pinecone.ServerlessSpec(**self.config.serverless_config)
else:
raise ValueError("No pod_config or serverless_config found.")
self.client.create_index(
name=self.config.index_name,
metric=self.config.metric,
dimension=self.config.vector_dimension,
spec=spec,
)
return pinecone.Index(self.config.index_name)
self.pinecone_index = self.client.Index(self.config.index_name)
def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
"""
@@ -85,7 +94,7 @@ class PineconeDB(BaseVectorDB):
if ids is not None:
for i in range(0, len(ids), 1000):
result = self.client.fetch(ids=ids[i : i + 1000])
result = self.pinecone_index.fetch(ids=ids[i : i + 1000])
vectors = result.get("vectors")
batch_existing_ids = list(vectors.keys())
existing_ids.extend(batch_existing_ids)
@@ -125,7 +134,7 @@ class PineconeDB(BaseVectorDB):
)
for chunk in chunks(docs, self.BATCH_SIZE, desc="Adding chunks in batches"):
self.client.upsert(chunk, **kwargs)
self.pinecone_index.upsert(chunk, **kwargs)
def query(
self,
@@ -151,15 +160,19 @@ class PineconeDB(BaseVectorDB):
"""
query_vector = self.embedder.embedding_fn([input_query])[0]
query_filter = self._generate_filter(where)
data = self.client.query(
vector=query_vector, filter=query_filter, top_k=n_results, include_metadata=True, **kwargs
data = self.pinecone_index.query(
vector=query_vector,
filter=query_filter,
top_k=n_results,
include_metadata=True,
**kwargs,
)
contexts = []
for doc in data["matches"]:
metadata = doc["metadata"]
context = metadata["text"]
for doc in data.get("matches", []):
metadata = doc.get("metadata", {})
context = metadata.get("text")
if citations:
metadata["score"] = doc["score"]
metadata["score"] = doc.get("score")
contexts.append(tuple((context, metadata)))
else:
contexts.append(context)
@@ -183,7 +196,8 @@ class PineconeDB(BaseVectorDB):
:return: number of documents
:rtype: int
"""
return self.client.describe_index_stats()["total_vector_count"]
data = self.pinecone_index.describe_index_stats()
return data["total_vector_count"]
def _get_or_create_db(self):
"""Called during initialization"""
@@ -194,7 +208,7 @@ class PineconeDB(BaseVectorDB):
Resets the database. Deletes all embeddings irreversibly.
"""
# Delete all data from the database
pinecone.delete_index(self.config.index_name)
self.client.delete_index(self.config.index_name)
self._setup_pinecone_index()
@staticmethod
@@ -213,7 +227,7 @@ class PineconeDB(BaseVectorDB):
# Follow `https://docs.pinecone.io/docs/metadata-filtering#deleting-vectors-by-metadata-filter` for more details
db_filter = self._generate_filter(where)
try:
self.client.delete(filter=db_filter)
self.pinecone_index.delete(filter=db_filter)
except Exception as e:
print(f"Failed to delete from Pinecone: {e}")
return

70
poetry.lock generated
View File

@@ -1260,25 +1260,6 @@ files = [
{file = "distro-1.8.0.tar.gz", hash = "sha256:02e111d1dc6a50abb8eed6bf31c3e48ed8b0830d1ea2a1b78c61765c2513fdd8"},
]
[[package]]
name = "dnspython"
version = "2.4.2"
description = "DNS toolkit"
optional = true
python-versions = ">=3.8,<4.0"
files = [
{file = "dnspython-2.4.2-py3-none-any.whl", hash = "sha256:57c6fbaaeaaf39c891292012060beb141791735dbb4004798328fc2c467402d8"},
{file = "dnspython-2.4.2.tar.gz", hash = "sha256:8dcfae8c7460a2f84b4072e26f1c9f4101ca20c071649cb7c34e8b6a93d58984"},
]
[package.extras]
dnssec = ["cryptography (>=2.6,<42.0)"]
doh = ["h2 (>=4.1.0)", "httpcore (>=0.17.3)", "httpx (>=0.24.1)"]
doq = ["aioquic (>=0.9.20)"]
idna = ["idna (>=2.1,<4.0)"]
trio = ["trio (>=0.14,<0.23)"]
wmi = ["wmi (>=1.5.1,<2.0.0)"]
[[package]]
name = "docx2txt"
version = "0.8"
@@ -3204,24 +3185,6 @@ files = [
{file = "lit-17.0.2.tar.gz", hash = "sha256:d6a551eab550f81023c82a260cd484d63970d2be9fd7588111208e7d2ff62212"},
]
[[package]]
name = "loguru"
version = "0.7.2"
description = "Python logging made (stupidly) simple"
optional = true
python-versions = ">=3.5"
files = [
{file = "loguru-0.7.2-py3-none-any.whl", hash = "sha256:003d71e3d3ed35f0f8984898359d65b79e5b21943f78af86aa5491210429b8eb"},
{file = "loguru-0.7.2.tar.gz", hash = "sha256:e671a53522515f34fd406340ee968cb9ecafbc4b36c679da03c18fd8d0bd51ac"},
]
[package.dependencies]
colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""}
win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""}
[package.extras]
dev = ["Sphinx (==7.2.5)", "colorama (==0.4.5)", "colorama (==0.4.6)", "exceptiongroup (==1.1.3)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v1.4.1)", "mypy (==v1.5.1)", "pre-commit (==3.4.0)", "pytest (==6.1.2)", "pytest (==7.4.0)", "pytest-cov (==2.12.1)", "pytest-cov (==4.1.0)", "pytest-mypy-plugins (==1.9.3)", "pytest-mypy-plugins (==3.0.0)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.3.0)", "tox (==3.27.1)", "tox (==4.11.0)"]
[[package]]
name = "lxml"
version = "4.9.3"
@@ -4680,25 +4643,20 @@ tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "pa
[[package]]
name = "pinecone-client"
version = "2.2.4"
version = "3.0.1"
description = "Pinecone client and SDK"
optional = true
python-versions = ">=3.8"
python-versions = ">=3.8,<3.13"
files = [
{file = "pinecone-client-2.2.4.tar.gz", hash = "sha256:2c1cc1d6648b2be66e944db2ffa59166a37b9164d1135ad525d9cd8b1e298168"},
{file = "pinecone_client-2.2.4-py3-none-any.whl", hash = "sha256:5bf496c01c2f82f4e5c2dc977cc5062ecd7168b8ed90743b09afcc8c7eb242ec"},
{file = "pinecone_client-3.0.1-py3-none-any.whl", hash = "sha256:c9bb21c23a9088c6198c839be5538ed3f733d152d5fbeaafcc020c1b70b62c2d"},
{file = "pinecone_client-3.0.1.tar.gz", hash = "sha256:626a0055852c88f1462fc2e132f21d2b078f9a0a74c70b17fe07df3081c6615f"},
]
[package.dependencies]
dnspython = ">=2.0.0"
loguru = ">=0.5.0"
numpy = ">=1.22.0"
python-dateutil = ">=2.5.3"
pyyaml = ">=5.4"
requests = ">=2.19.0"
certifi = ">=2019.11.17"
tqdm = ">=4.64.1"
typing-extensions = ">=3.7.4"
urllib3 = ">=1.21.1"
urllib3 = ">=1.26.0"
[package.extras]
grpc = ["googleapis-common-protos (>=1.53.0)", "grpc-gateway-protoc-gen-openapiv2 (==0.1.0)", "grpcio (>=1.44.0)", "lz4 (>=3.1.3)", "protobuf (>=3.20.0,<3.21.0)"]
@@ -8019,20 +7977,6 @@ files = [
[package.extras]
test = ["pytest (>=6.0.0)", "setuptools (>=65)"]
[[package]]
name = "win32-setctime"
version = "1.1.0"
description = "A small Python utility to set file creation time on Windows"
optional = true
python-versions = ">=3.5"
files = [
{file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"},
{file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"},
]
[package.extras]
dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"]
[[package]]
name = "wrapt"
version = "1.15.0"
@@ -8316,4 +8260,4 @@ youtube = ["youtube-transcript-api", "yt_dlp"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<3.12"
content-hash = "bbcf32e87c0784d031fb6cf9bd89655375839da0660b8feb2026ffdd971623d7"
content-hash = "a16addd3362ae70c79b15677c6815f708677f11f636093f4e1f5084ba44b5a36"

View File

@@ -123,7 +123,7 @@ cohere = { version = "^4.27", optional = true }
together = { version = "^0.2.8", optional = true }
weaviate-client = { version = "^3.24.1", optional = true }
docx2txt = { version = "^0.8", optional = true }
pinecone-client = { version = "^2.2.4", optional = true }
pinecone-client = { version = "^3.0.0", optional = true }
qdrant-client = { version = "1.6.3", optional = true }
unstructured = {extras = ["local-inference", "all-docs"], version = "^0.10.18", optional = true}
huggingface_hub = { version = "^0.17.3", optional = true }

View File

@@ -1,139 +1,225 @@
from unittest import mock
from unittest.mock import patch
import pytest
from embedchain import App
from embedchain.config import AppConfig
from embedchain.config.vectordb.pinecone import PineconeDBConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.vectordb.pinecone import PineconeDB
class TestPinecone:
@patch("embedchain.vectordb.pinecone.pinecone")
def test_init(self, pinecone_mock):
"""Test that the PineconeDB can be initialized."""
# Create a PineconeDB instance
PineconeDB()
@pytest.fixture
def pinecone_pod_config():
return PineconeDBConfig(
collection_name="test_collection",
api_key="test_api_key",
vector_dimension=3,
pod_config={"environment": "test_environment", "metadata_config": {"indexed": ["*"]}},
)
# Assert that the Pinecone client was initialized
pinecone_mock.init.assert_called_once()
pinecone_mock.list_indexes.assert_called_once()
pinecone_mock.Index.assert_called_once()
@patch("embedchain.vectordb.pinecone.pinecone")
def test_set_embedder(self, pinecone_mock):
"""Test that the embedder can be set."""
@pytest.fixture
def pinecone_serverless_config():
return PineconeDBConfig(
collection_name="test_collection",
api_key="test_api_key",
vector_dimension=3,
serverless_config={
"cloud": "test_cloud",
"region": "test_region",
},
)
# Set the embedder
embedder = BaseEmbedder()
# Create a PineconeDB instance
def test_pinecone_init_without_config(monkeypatch):
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
pinecone_db = PineconeDB()
assert isinstance(pinecone_db, PineconeDB)
assert isinstance(pinecone_db.config, PineconeDBConfig)
assert pinecone_db.config.pod_config == {"environment": "gcp-starter", "metadata_config": {"indexed": ["*"]}}
monkeypatch.delenv("PINECONE_API_KEY")
def test_pinecone_init_with_config(pinecone_pod_config, pinecone_serverless_config, monkeypatch):
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
pinecone_db = PineconeDB(config=pinecone_pod_config)
assert isinstance(pinecone_db, PineconeDB)
assert isinstance(pinecone_db.config, PineconeDBConfig)
assert pinecone_db.config.pod_config == pinecone_pod_config.pod_config
pinecone_db = PineconeDB(config=pinecone_pod_config)
assert isinstance(pinecone_db, PineconeDB)
assert isinstance(pinecone_db.config, PineconeDBConfig)
assert pinecone_db.config.serverless_config == pinecone_pod_config.serverless_config
class MockListIndexes:
def names(self):
return ["test_collection"]
class MockPineconeIndex:
db = []
def __init__(*args, **kwargs):
pass
def upsert(self, chunk, **kwargs):
self.db.extend([c for c in chunk])
return
def delete(self, *args, **kwargs):
pass
def query(self, *args, **kwargs):
return {
"matches": [
{
"metadata": {
"key": "value",
"text": "text_1",
},
"score": 0.1,
},
{
"metadata": {
"key": "value",
"text": "text_2",
},
"score": 0.2,
},
]
}
def fetch(self, *args, **kwargs):
return {
"vectors": {
"key_1": {
"metadata": {
"source": "1",
}
},
"key_2": {
"metadata": {
"source": "2",
}
},
}
}
def describe_index_stats(self, *args, **kwargs):
return {"total_vector_count": len(self.db)}
class MockPineconeClient:
def __init__(*args, **kwargs):
pass
def list_indexes(self):
return MockListIndexes()
def create_index(self, *args, **kwargs):
pass
def Index(self, *args, **kwargs):
return MockPineconeIndex()
def delete_index(self, *args, **kwargs):
pass
class MockPinecone:
def __init__(*args, **kwargs):
pass
def Pinecone(*args, **kwargs):
return MockPineconeClient()
def PodSpec(*args, **kwargs):
pass
def ServerlessSpec(*args, **kwargs):
pass
class MockEmbedder:
def embedding_fn(self, documents):
return [[1, 1, 1] for d in documents]
def test_setup_pinecone_index(pinecone_pod_config, pinecone_serverless_config, monkeypatch):
monkeypatch.setattr("embedchain.vectordb.pinecone.pinecone", MockPinecone)
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
pinecone_db = PineconeDB(config=pinecone_pod_config)
pinecone_db._setup_pinecone_index()
assert pinecone_db.client is not None
assert pinecone_db.config.index_name == "test-collection-3"
assert pinecone_db.client.list_indexes().names() == ["test_collection"]
assert pinecone_db.pinecone_index is not None
pinecone_db = PineconeDB(config=pinecone_serverless_config)
pinecone_db._setup_pinecone_index()
assert pinecone_db.client is not None
assert pinecone_db.config.index_name == "test-collection-3"
assert pinecone_db.client.list_indexes().names() == ["test_collection"]
assert pinecone_db.pinecone_index is not None
def test_get(monkeypatch):
def mock_pinecone_db():
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
db = PineconeDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedding_model=embedder)
db.pinecone_index = MockPineconeIndex()
return db
# Assert that the embedder was set
assert db.embedder == embedder
pinecone_mock.init.assert_called_once()
pinecone_db = mock_pinecone_db()
ids = pinecone_db.get(["key_1", "key_2"])
assert ids == {"ids": ["key_1", "key_2"], "metadatas": [{"source": "1"}, {"source": "2"}]}
@patch("embedchain.vectordb.pinecone.pinecone")
def test_add_documents(self, pinecone_mock):
"""Test that documents can be added to the database."""
pinecone_client_mock = pinecone_mock.Index.return_value
embedding_function = mock.Mock()
base_embedder = BaseEmbedder()
base_embedder.set_embedding_fn(embedding_function)
embedding_function.return_value = [[0, 0, 0], [1, 1, 1]]
# Create a PineconeDb instance
def test_add(monkeypatch):
def mock_pinecone_db():
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
db = PineconeDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedding_model=base_embedder)
db.pinecone_index = MockPineconeIndex()
db._set_embedder(MockEmbedder())
return db
# Add some documents to the database
documents = ["This is a document.", "This is another document."]
metadatas = [{}, {}]
ids = ["doc1", "doc2"]
db.add(documents, metadatas, ids)
pinecone_db = mock_pinecone_db()
pinecone_db.add(["text_1", "text_2"], [{"key_1": "value_1"}, {"key_2": "value_2"}], ["key_1", "key_2"])
assert pinecone_db.count() == 2
expected_pinecone_upsert_args = [
{"id": "doc1", "values": [0, 0, 0], "metadata": {"text": "This is a document."}},
{"id": "doc2", "values": [1, 1, 1], "metadata": {"text": "This is another document."}},
]
# Assert that the Pinecone client was called to upsert the documents
pinecone_client_mock.upsert.assert_called_once_with(tuple(expected_pinecone_upsert_args))
pinecone_db.add(["text_3", "text_4"], [{"key_3": "value_3"}, {"key_4": "value_4"}], ["key_3", "key_4"])
assert pinecone_db.count() == 4
@patch("embedchain.vectordb.pinecone.pinecone")
def test_query_documents(self, pinecone_mock):
"""Test that documents can be queried from the database."""
pinecone_client_mock = pinecone_mock.Index.return_value
embedding_function = mock.Mock()
base_embedder = BaseEmbedder()
base_embedder.set_embedding_fn(embedding_function)
vectors = [[0, 0, 0]]
embedding_function.return_value = vectors
# Create a PineconeDB instance
def test_query(monkeypatch):
def mock_pinecone_db():
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
db = PineconeDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedding_model=base_embedder)
db.pinecone_index = MockPineconeIndex()
db._set_embedder(MockEmbedder())
return db
# Query the database for documents that are similar to "document"
input_query = ["document"]
n_results = 1
db.query(input_query, n_results, where={})
# Assert that the Pinecone client was called to query the database
pinecone_client_mock.query.assert_called_once_with(
vector=db.embedder.embedding_fn(input_query)[0], top_k=n_results, filter={}, include_metadata=True
)
@patch("embedchain.vectordb.pinecone.pinecone")
def test_reset(self, pinecone_mock):
"""Test that the database can be reset."""
# Create a PineconeDb instance
db = PineconeDB()
app_config = AppConfig(collect_metrics=False)
App(config=app_config, db=db, embedding_model=BaseEmbedder())
# Reset the database
db.reset()
# Assert that the Pinecone client was called to delete the index
pinecone_mock.delete_index.assert_called_once_with(db.config.index_name)
# Assert that the index is recreated
pinecone_mock.Index.assert_called_with(db.config.index_name)
@patch("embedchain.vectordb.pinecone.pinecone")
def test_custom_index_name_if_it_exists(self, pinecone_mock):
"""Tests custom index name is used if it exists"""
pinecone_mock.list_indexes.return_value = ["custom_index_name"]
db_config = PineconeDBConfig(index_name="custom_index_name")
_ = PineconeDB(config=db_config)
pinecone_mock.list_indexes.assert_called_once()
pinecone_mock.create_index.assert_not_called()
pinecone_mock.Index.assert_called_with("custom_index_name")
@patch("embedchain.vectordb.pinecone.pinecone")
def test_custom_index_name_creation(self, pinecone_mock):
        """Test custom index name is created if it doesn't exist already"""
pinecone_mock.list_indexes.return_value = []
db_config = PineconeDBConfig(index_name="custom_index_name")
_ = PineconeDB(config=db_config)
pinecone_mock.list_indexes.assert_called_once()
pinecone_mock.create_index.assert_called_once()
pinecone_mock.Index.assert_called_with("custom_index_name")
@patch("embedchain.vectordb.pinecone.pinecone")
def test_default_index_name_is_used(self, pinecone_mock):
"""Test default index name is used if custom index name is not provided"""
db_config = PineconeDBConfig(collection_name="my-collection")
_ = PineconeDB(config=db_config)
pinecone_mock.list_indexes.assert_called_once()
pinecone_mock.create_index.assert_called_once()
pinecone_mock.Index.assert_called_with(f"{db_config.collection_name}-{db_config.vector_dimension}")
pinecone_db = mock_pinecone_db()
# without citations
results = pinecone_db.query(["text_1", "text_2"], n_results=2, where={})
assert results == ["text_1", "text_2"]
# with citations
results = pinecone_db.query(["text_1", "text_2"], n_results=2, where={}, citations=True)
assert results == [
("text_1", {"key": "value", "text": "text_1", "score": 0.1}),
("text_2", {"key": "value", "text": "text_2", "score": 0.2}),
]