[Improvement] update pinecone client v3 (#1200)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Author: Deven Patel
Committed by: GitHub
Date: 2024-01-26 09:08:37 +05:30
Parent: d2a5b50ff8
Commit: e75c05112e
6 changed files with 290 additions and 209 deletions


@@ -167,7 +167,7 @@ Install pinecone related dependencies using the following command:
 pip install --upgrade 'embedchain[pinecone]'
 ```
-In order to use Pinecone as vector database, set the environment variables `PINECONE_API_KEY` and `PINECONE_ENV` which you can find on [Pinecone dashboard](https://app.pinecone.io/).
+In order to use Pinecone as vector database, set the environment variable `PINECONE_API_KEY` which you can find on [Pinecone dashboard](https://app.pinecone.io/).
 <CodeGroup>
@@ -175,23 +175,44 @@ In order to use Pinecone as vector database, set the environment variables `PINE
 from embedchain import App
 # load pinecone configuration from yaml file
-app = App.from_config(config_path="config.yaml")
+app = App.from_config(config_path="pod_config.yaml")
+# or
+app = App.from_config(config_path="serverless_config.yaml")
 ```
-```yaml config.yaml
+```yaml pod_config.yaml
 vectordb:
   provider: pinecone
   config:
     metric: cosine
     vector_dimension: 1536
     collection_name: my-pinecone-index
+    pod_config:
+      environment: gcp-starter
+      metadata_config:
+        indexed:
+          - "url"
+          - "hash"
+```
+```yaml serverless_config.yaml
+vectordb:
+  provider: pinecone
+  config:
+    metric: cosine
+    vector_dimension: 1536
+    collection_name: my-pinecone-index
+    serverless_config:
+      cloud: aws
+      region: us-west-2
 ```
 </CodeGroup>
 <br />
 <Note>
-You can optionally provide `index_name` as a config param in yaml file to specify the index name. If not provided, the index name will be `{collection_name}-{vector_dimension}`.
+You can find more information about Pinecone configuration [here](https://docs.pinecone.io/docs/manage-indexes#create-a-pod-based-index).
+You can also optionally provide `index_name` as a config param in yaml file to specify the index name. If not provided, the index name will be `{collection_name}-{vector_dimension}`.
 </Note>
 ## Qdrant
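The two YAML files above map one-to-one onto Python dicts. Below is a minimal sketch of the same pod and serverless setups passed programmatically, assuming your embedchain version accepts a `config` dict in `App.from_config`; the API key value is a placeholder.

```python
import os

from embedchain import App

os.environ["PINECONE_API_KEY"] = "your-api-key"  # placeholder, not a real key

# Pod-based index, mirroring pod_config.yaml above
pod_app = App.from_config(config={
    "vectordb": {
        "provider": "pinecone",
        "config": {
            "metric": "cosine",
            "vector_dimension": 1536,
            "collection_name": "my-pinecone-index",
            "pod_config": {"environment": "gcp-starter", "metadata_config": {"indexed": ["url", "hash"]}},
        },
    }
})

# Serverless index, mirroring serverless_config.yaml above
serverless_app = App.from_config(config={
    "vectordb": {
        "provider": "pinecone",
        "config": {
            "metric": "cosine",
            "vector_dimension": 1536,
            "collection_name": "my-pinecone-index",
            "serverless_config": {"cloud": "aws", "region": "us-west-2"},
        },
    }
})
```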


@@ -1,3 +1,4 @@
+import os
 from typing import Optional
 
 from embedchain.config.vectordb.base import BaseVectorDbConfig
@@ -9,14 +10,29 @@ class PineconeDBConfig(BaseVectorDbConfig):
     def __init__(
         self,
         collection_name: Optional[str] = None,
+        api_key: Optional[str] = None,
         index_name: Optional[str] = None,
         dir: Optional[str] = None,
         vector_dimension: int = 1536,
         metric: Optional[str] = "cosine",
+        pod_config: Optional[dict[str, any]] = None,
+        serverless_config: Optional[dict[str, any]] = None,
         **extra_params: dict[str, any],
     ):
         self.metric = metric
+        self.api_key = api_key
         self.vector_dimension = vector_dimension
         self.extra_params = extra_params
         self.index_name = index_name or f"{collection_name}-{vector_dimension}".lower().replace("_", "-")
-        super().__init__(collection_name=collection_name, dir=dir)
+        if pod_config is None and serverless_config is None:
+            # If no config is provided, use the default pod spec config
+            pod_environment = os.environ.get("PINECONE_ENV", "gcp-starter")
+            self.pod_config = {"environment": pod_environment, "metadata_config": {"indexed": ["*"]}}
+        else:
+            self.pod_config = pod_config
+        self.serverless_config = serverless_config
+        if self.pod_config and self.serverless_config:
+            raise ValueError("Only one of pod_config or serverless_config can be provided.")
+        super().__init__(collection_name=collection_name, dir=None)
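A short usage sketch of this constructor, based only on the logic in the hunk above and assuming embedchain is installed with the pinecone extra: with neither spec supplied, the config falls back to a pod spec built from `PINECONE_ENV` (defaulting to `gcp-starter`), and supplying both specs raises.

```python
from embedchain.config.vectordb.pinecone import PineconeDBConfig

# Default: no pod_config/serverless_config -> pod spec from PINECONE_ENV or "gcp-starter"
config = PineconeDBConfig(collection_name="my_collection", vector_dimension=1536)
assert config.index_name == "my-collection-1536"  # lowercased, underscores replaced with dashes
assert config.pod_config["metadata_config"] == {"indexed": ["*"]}

# Supplying both specs is rejected
try:
    PineconeDBConfig(
        collection_name="my_collection",
        pod_config={"environment": "gcp-starter"},
        serverless_config={"cloud": "aws", "region": "us-west-2"},
    )
except ValueError as err:
    print(err)  # Only one of pod_config or serverless_config can be provided.
```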


@@ -42,7 +42,7 @@ class PineconeDB(BaseVectorDB):
                 "Please make sure the type is right and that you are passing an instance."
             )
         self.config = config
-        self.client = self._setup_pinecone_index()
+        self._setup_pinecone_index()
         # Call parent init here because embedder is needed
         super().__init__(config=self.config)
@@ -57,17 +57,26 @@ class PineconeDB(BaseVectorDB):
         """
         Loads the Pinecone index or creates it if not present.
         """
-        pinecone.init(
-            api_key=os.environ.get("PINECONE_API_KEY"),
-            environment=os.environ.get("PINECONE_ENV"),
-            **self.config.extra_params,
-        )
-        indexes = pinecone.list_indexes()
+        api_key = self.config.api_key or os.environ.get("PINECONE_API_KEY")
+        if not api_key:
+            raise ValueError("Please set the PINECONE_API_KEY environment variable or pass it in config.")
+        self.client = pinecone.Pinecone(api_key=api_key, **self.config.extra_params)
+        indexes = self.client.list_indexes().names()
         if indexes is None or self.config.index_name not in indexes:
-            pinecone.create_index(
-                name=self.config.index_name, metric=self.config.metric, dimension=self.config.vector_dimension
+            if self.config.pod_config:
+                spec = pinecone.PodSpec(**self.config.pod_config)
+            elif self.config.serverless_config:
+                spec = pinecone.ServerlessSpec(**self.config.serverless_config)
+            else:
+                raise ValueError("No pod_config or serverless_config found.")
+            self.client.create_index(
+                name=self.config.index_name,
+                metric=self.config.metric,
+                dimension=self.config.vector_dimension,
+                spec=spec,
             )
-        return pinecone.Index(self.config.index_name)
+        self.pinecone_index = self.client.Index(self.config.index_name)
 
     def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
         """
@@ -85,7 +94,7 @@ class PineconeDB(BaseVectorDB):
         if ids is not None:
             for i in range(0, len(ids), 1000):
-                result = self.client.fetch(ids=ids[i : i + 1000])
+                result = self.pinecone_index.fetch(ids=ids[i : i + 1000])
                 vectors = result.get("vectors")
                 batch_existing_ids = list(vectors.keys())
                 existing_ids.extend(batch_existing_ids)
@@ -125,7 +134,7 @@ class PineconeDB(BaseVectorDB):
         )
         for chunk in chunks(docs, self.BATCH_SIZE, desc="Adding chunks in batches"):
-            self.client.upsert(chunk, **kwargs)
+            self.pinecone_index.upsert(chunk, **kwargs)
 
     def query(
         self,
@@ -151,15 +160,19 @@ class PineconeDB(BaseVectorDB):
         """
         query_vector = self.embedder.embedding_fn([input_query])[0]
         query_filter = self._generate_filter(where)
-        data = self.client.query(
-            vector=query_vector, filter=query_filter, top_k=n_results, include_metadata=True, **kwargs
+        data = self.pinecone_index.query(
+            vector=query_vector,
+            filter=query_filter,
+            top_k=n_results,
+            include_metadata=True,
+            **kwargs,
         )
         contexts = []
-        for doc in data["matches"]:
-            metadata = doc["metadata"]
-            context = metadata["text"]
+        for doc in data.get("matches", []):
+            metadata = doc.get("metadata", {})
+            context = metadata.get("text")
             if citations:
-                metadata["score"] = doc["score"]
+                metadata["score"] = doc.get("score")
                 contexts.append(tuple((context, metadata)))
             else:
                 contexts.append(context)
@@ -183,7 +196,8 @@ class PineconeDB(BaseVectorDB):
         :return: number of documents
         :rtype: int
         """
-        return self.client.describe_index_stats()["total_vector_count"]
+        data = self.pinecone_index.describe_index_stats()
+        return data["total_vector_count"]
 
     def _get_or_create_db(self):
         """Called during initialization"""
@@ -194,7 +208,7 @@ class PineconeDB(BaseVectorDB):
         Resets the database. Deletes all embeddings irreversibly.
         """
         # Delete all data from the database
-        pinecone.delete_index(self.config.index_name)
+        self.client.delete_index(self.config.index_name)
         self._setup_pinecone_index()
 
     @staticmethod
@@ -213,7 +227,7 @@ class PineconeDB(BaseVectorDB):
         # Follow `https://docs.pinecone.io/docs/metadata-filtering#deleting-vectors-by-metadata-filter` for more details
         db_filter = self._generate_filter(where)
         try:
-            self.client.delete(filter=db_filter)
+            self.pinecone_index.delete(filter=db_filter)
         except Exception as e:
             print(f"Failed to delete from Pinecone: {e}")
             return
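For reference, a standalone sketch of the pinecone-client v2 to v3 pattern the updated `_setup_pinecone_index` follows; the index name, dimension, and region below are placeholders, not values used by this repo. `pinecone.init(...)` and the module-level `pinecone.Index(...)` are replaced by a `Pinecone` client object that owns index management and hands out per-index handles.

```python
import os

import pinecone

client = pinecone.Pinecone(api_key=os.environ["PINECONE_API_KEY"])

index_name = "example-index"  # placeholder
if index_name not in client.list_indexes().names():
    client.create_index(
        name=index_name,
        metric="cosine",
        dimension=1536,
        spec=pinecone.ServerlessSpec(cloud="aws", region="us-west-2"),  # or pinecone.PodSpec(...)
    )

index = client.Index(index_name)  # per-index handle; replaces module-level pinecone.Index(...)
index.upsert([("id-1", [0.1] * 1536, {"text": "hello"})])
result = index.query(vector=[0.1] * 1536, top_k=1, include_metadata=True)
print(result.get("matches", []))
```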

poetry.lock (generated)

@@ -1260,25 +1260,6 @@ files = [
     {file = "distro-1.8.0.tar.gz", hash = "sha256:02e111d1dc6a50abb8eed6bf31c3e48ed8b0830d1ea2a1b78c61765c2513fdd8"},
 ]
-[[package]]
-name = "dnspython"
-version = "2.4.2"
-description = "DNS toolkit"
-optional = true
-python-versions = ">=3.8,<4.0"
-files = [
-    {file = "dnspython-2.4.2-py3-none-any.whl", hash = "sha256:57c6fbaaeaaf39c891292012060beb141791735dbb4004798328fc2c467402d8"},
-    {file = "dnspython-2.4.2.tar.gz", hash = "sha256:8dcfae8c7460a2f84b4072e26f1c9f4101ca20c071649cb7c34e8b6a93d58984"},
-]
-[package.extras]
-dnssec = ["cryptography (>=2.6,<42.0)"]
-doh = ["h2 (>=4.1.0)", "httpcore (>=0.17.3)", "httpx (>=0.24.1)"]
-doq = ["aioquic (>=0.9.20)"]
-idna = ["idna (>=2.1,<4.0)"]
-trio = ["trio (>=0.14,<0.23)"]
-wmi = ["wmi (>=1.5.1,<2.0.0)"]
 [[package]]
 name = "docx2txt"
 version = "0.8"
@@ -3204,24 +3185,6 @@ files = [
     {file = "lit-17.0.2.tar.gz", hash = "sha256:d6a551eab550f81023c82a260cd484d63970d2be9fd7588111208e7d2ff62212"},
 ]
-[[package]]
-name = "loguru"
-version = "0.7.2"
-description = "Python logging made (stupidly) simple"
-optional = true
-python-versions = ">=3.5"
-files = [
-    {file = "loguru-0.7.2-py3-none-any.whl", hash = "sha256:003d71e3d3ed35f0f8984898359d65b79e5b21943f78af86aa5491210429b8eb"},
-    {file = "loguru-0.7.2.tar.gz", hash = "sha256:e671a53522515f34fd406340ee968cb9ecafbc4b36c679da03c18fd8d0bd51ac"},
-]
-[package.dependencies]
-colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""}
-win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""}
-[package.extras]
-dev = ["Sphinx (==7.2.5)", "colorama (==0.4.5)", "colorama (==0.4.6)", "exceptiongroup (==1.1.3)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v1.4.1)", "mypy (==v1.5.1)", "pre-commit (==3.4.0)", "pytest (==6.1.2)", "pytest (==7.4.0)", "pytest-cov (==2.12.1)", "pytest-cov (==4.1.0)", "pytest-mypy-plugins (==1.9.3)", "pytest-mypy-plugins (==3.0.0)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.3.0)", "tox (==3.27.1)", "tox (==4.11.0)"]
 [[package]]
 name = "lxml"
 version = "4.9.3"
@@ -4680,25 +4643,20 @@ tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "pa
 [[package]]
 name = "pinecone-client"
-version = "2.2.4"
+version = "3.0.1"
 description = "Pinecone client and SDK"
 optional = true
-python-versions = ">=3.8"
+python-versions = ">=3.8,<3.13"
 files = [
-    {file = "pinecone-client-2.2.4.tar.gz", hash = "sha256:2c1cc1d6648b2be66e944db2ffa59166a37b9164d1135ad525d9cd8b1e298168"},
-    {file = "pinecone_client-2.2.4-py3-none-any.whl", hash = "sha256:5bf496c01c2f82f4e5c2dc977cc5062ecd7168b8ed90743b09afcc8c7eb242ec"},
+    {file = "pinecone_client-3.0.1-py3-none-any.whl", hash = "sha256:c9bb21c23a9088c6198c839be5538ed3f733d152d5fbeaafcc020c1b70b62c2d"},
+    {file = "pinecone_client-3.0.1.tar.gz", hash = "sha256:626a0055852c88f1462fc2e132f21d2b078f9a0a74c70b17fe07df3081c6615f"},
 ]
 [package.dependencies]
-dnspython = ">=2.0.0"
-loguru = ">=0.5.0"
-numpy = ">=1.22.0"
-python-dateutil = ">=2.5.3"
-pyyaml = ">=5.4"
-requests = ">=2.19.0"
+certifi = ">=2019.11.17"
 tqdm = ">=4.64.1"
 typing-extensions = ">=3.7.4"
-urllib3 = ">=1.21.1"
+urllib3 = ">=1.26.0"
 [package.extras]
 grpc = ["googleapis-common-protos (>=1.53.0)", "grpc-gateway-protoc-gen-openapiv2 (==0.1.0)", "grpcio (>=1.44.0)", "lz4 (>=3.1.3)", "protobuf (>=3.20.0,<3.21.0)"]
@@ -8019,20 +7977,6 @@ files = [
 [package.extras]
 test = ["pytest (>=6.0.0)", "setuptools (>=65)"]
-[[package]]
-name = "win32-setctime"
-version = "1.1.0"
-description = "A small Python utility to set file creation time on Windows"
-optional = true
-python-versions = ">=3.5"
-files = [
-    {file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"},
-    {file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"},
-]
-[package.extras]
-dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"]
 [[package]]
 name = "wrapt"
 version = "1.15.0"
@@ -8316,4 +8260,4 @@ youtube = ["youtube-transcript-api", "yt_dlp"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<3.12"
-content-hash = "bbcf32e87c0784d031fb6cf9bd89655375839da0660b8feb2026ffdd971623d7"
+content-hash = "a16addd3362ae70c79b15677c6815f708677f11f636093f4e1f5084ba44b5a36"
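The lock now pins pinecone-client 3.0.1 and drops its old transitive dependencies (dnspython, loguru, win32-setctime). A minimal sketch of a sanity check, assuming it is run inside the environment installed from this lock file:

```python
from importlib.metadata import version

installed = version("pinecone-client")
assert installed.startswith("3."), f"expected pinecone-client 3.x, found {installed}"
print(installed)
```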


@@ -123,7 +123,7 @@ cohere = { version = "^4.27", optional = true }
 together = { version = "^0.2.8", optional = true }
 weaviate-client = { version = "^3.24.1", optional = true }
 docx2txt = { version = "^0.8", optional = true }
-pinecone-client = { version = "^2.2.4", optional = true }
+pinecone-client = { version = "^3.0.0", optional = true }
 qdrant-client = { version = "1.6.3", optional = true }
 unstructured = {extras = ["local-inference", "all-docs"], version = "^0.10.18", optional = true}
 huggingface_hub = { version = "^0.17.3", optional = true }


@@ -1,139 +1,225 @@
-from unittest import mock
-from unittest.mock import patch
-
-from embedchain import App
-from embedchain.config import AppConfig
-from embedchain.config.vectordb.pinecone import PineconeDBConfig
-from embedchain.embedder.base import BaseEmbedder
-from embedchain.vectordb.pinecone import PineconeDB
-
-
-class TestPinecone:
-    @patch("embedchain.vectordb.pinecone.pinecone")
-    def test_init(self, pinecone_mock):
-        """Test that the PineconeDB can be initialized."""
-        # Create a PineconeDB instance
-        PineconeDB()
-
-        # Assert that the Pinecone client was initialized
-        pinecone_mock.init.assert_called_once()
-        pinecone_mock.list_indexes.assert_called_once()
-        pinecone_mock.Index.assert_called_once()
-
-    @patch("embedchain.vectordb.pinecone.pinecone")
-    def test_set_embedder(self, pinecone_mock):
-        """Test that the embedder can be set."""
-        # Set the embedder
-        embedder = BaseEmbedder()
-
-        # Create a PineconeDB instance
-        db = PineconeDB()
-        app_config = AppConfig(collect_metrics=False)
-        App(config=app_config, db=db, embedding_model=embedder)
-
-        # Assert that the embedder was set
-        assert db.embedder == embedder
-        pinecone_mock.init.assert_called_once()
-
-    @patch("embedchain.vectordb.pinecone.pinecone")
-    def test_add_documents(self, pinecone_mock):
-        """Test that documents can be added to the database."""
-        pinecone_client_mock = pinecone_mock.Index.return_value
-
-        embedding_function = mock.Mock()
-        base_embedder = BaseEmbedder()
-        base_embedder.set_embedding_fn(embedding_function)
-        embedding_function.return_value = [[0, 0, 0], [1, 1, 1]]
-
-        # Create a PineconeDb instance
-        db = PineconeDB()
-        app_config = AppConfig(collect_metrics=False)
-        App(config=app_config, db=db, embedding_model=base_embedder)
-
-        # Add some documents to the database
-        documents = ["This is a document.", "This is another document."]
-        metadatas = [{}, {}]
-        ids = ["doc1", "doc2"]
-        db.add(documents, metadatas, ids)
-
-        expected_pinecone_upsert_args = [
-            {"id": "doc1", "values": [0, 0, 0], "metadata": {"text": "This is a document."}},
-            {"id": "doc2", "values": [1, 1, 1], "metadata": {"text": "This is another document."}},
-        ]
-
-        # Assert that the Pinecone client was called to upsert the documents
-        pinecone_client_mock.upsert.assert_called_once_with(tuple(expected_pinecone_upsert_args))
-
-    @patch("embedchain.vectordb.pinecone.pinecone")
-    def test_query_documents(self, pinecone_mock):
-        """Test that documents can be queried from the database."""
-        pinecone_client_mock = pinecone_mock.Index.return_value
-
-        embedding_function = mock.Mock()
-        base_embedder = BaseEmbedder()
-        base_embedder.set_embedding_fn(embedding_function)
-        vectors = [[0, 0, 0]]
-        embedding_function.return_value = vectors
-
-        # Create a PineconeDB instance
-        db = PineconeDB()
-        app_config = AppConfig(collect_metrics=False)
-        App(config=app_config, db=db, embedding_model=base_embedder)
-
-        # Query the database for documents that are similar to "document"
-        input_query = ["document"]
-        n_results = 1
-        db.query(input_query, n_results, where={})
-
-        # Assert that the Pinecone client was called to query the database
-        pinecone_client_mock.query.assert_called_once_with(
-            vector=db.embedder.embedding_fn(input_query)[0], top_k=n_results, filter={}, include_metadata=True
-        )
-
-    @patch("embedchain.vectordb.pinecone.pinecone")
-    def test_reset(self, pinecone_mock):
-        """Test that the database can be reset."""
-        # Create a PineconeDb instance
-        db = PineconeDB()
-        app_config = AppConfig(collect_metrics=False)
-        App(config=app_config, db=db, embedding_model=BaseEmbedder())
-
-        # Reset the database
-        db.reset()
-
-        # Assert that the Pinecone client was called to delete the index
-        pinecone_mock.delete_index.assert_called_once_with(db.config.index_name)
-
-        # Assert that the index is recreated
-        pinecone_mock.Index.assert_called_with(db.config.index_name)
-
-    @patch("embedchain.vectordb.pinecone.pinecone")
-    def test_custom_index_name_if_it_exists(self, pinecone_mock):
-        """Tests custom index name is used if it exists"""
-        pinecone_mock.list_indexes.return_value = ["custom_index_name"]
-        db_config = PineconeDBConfig(index_name="custom_index_name")
-        _ = PineconeDB(config=db_config)
-
-        pinecone_mock.list_indexes.assert_called_once()
-        pinecone_mock.create_index.assert_not_called()
-        pinecone_mock.Index.assert_called_with("custom_index_name")
-
-    @patch("embedchain.vectordb.pinecone.pinecone")
-    def test_custom_index_name_creation(self, pinecone_mock):
-        """Test custom index name is created if it doesn't exists already"""
-        pinecone_mock.list_indexes.return_value = []
-        db_config = PineconeDBConfig(index_name="custom_index_name")
-        _ = PineconeDB(config=db_config)
-
-        pinecone_mock.list_indexes.assert_called_once()
-        pinecone_mock.create_index.assert_called_once()
-        pinecone_mock.Index.assert_called_with("custom_index_name")
-
-    @patch("embedchain.vectordb.pinecone.pinecone")
-    def test_default_index_name_is_used(self, pinecone_mock):
-        """Test default index name is used if custom index name is not provided"""
-        db_config = PineconeDBConfig(collection_name="my-collection")
-        _ = PineconeDB(config=db_config)
-
-        pinecone_mock.list_indexes.assert_called_once()
-        pinecone_mock.create_index.assert_called_once()
-        pinecone_mock.Index.assert_called_with(f"{db_config.collection_name}-{db_config.vector_dimension}")
+import pytest
+
+from embedchain.config.vectordb.pinecone import PineconeDBConfig
+from embedchain.vectordb.pinecone import PineconeDB
+
+
+@pytest.fixture
+def pinecone_pod_config():
+    return PineconeDBConfig(
+        collection_name="test_collection",
+        api_key="test_api_key",
+        vector_dimension=3,
+        pod_config={"environment": "test_environment", "metadata_config": {"indexed": ["*"]}},
+    )
+
+
+@pytest.fixture
+def pinecone_serverless_config():
+    return PineconeDBConfig(
+        collection_name="test_collection",
+        api_key="test_api_key",
+        vector_dimension=3,
+        serverless_config={
+            "cloud": "test_cloud",
+            "region": "test_region",
+        },
+    )
+
+
+def test_pinecone_init_without_config(monkeypatch):
+    monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
+    monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
+    monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
+    pinecone_db = PineconeDB()
+
+    assert isinstance(pinecone_db, PineconeDB)
+    assert isinstance(pinecone_db.config, PineconeDBConfig)
+    assert pinecone_db.config.pod_config == {"environment": "gcp-starter", "metadata_config": {"indexed": ["*"]}}
+    monkeypatch.delenv("PINECONE_API_KEY")
+
+
+def test_pinecone_init_with_config(pinecone_pod_config, pinecone_serverless_config, monkeypatch):
+    monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
+    monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
+    pinecone_db = PineconeDB(config=pinecone_pod_config)
+
+    assert isinstance(pinecone_db, PineconeDB)
+    assert isinstance(pinecone_db.config, PineconeDBConfig)
+    assert pinecone_db.config.pod_config == pinecone_pod_config.pod_config
+
+    pinecone_db = PineconeDB(config=pinecone_pod_config)
+
+    assert isinstance(pinecone_db, PineconeDB)
+    assert isinstance(pinecone_db.config, PineconeDBConfig)
+    assert pinecone_db.config.serverless_config == pinecone_pod_config.serverless_config
+
+
+class MockListIndexes:
+    def names(self):
+        return ["test_collection"]
+
+
+class MockPineconeIndex:
+    db = []
+
+    def __init__(*args, **kwargs):
+        pass
+
+    def upsert(self, chunk, **kwargs):
+        self.db.extend([c for c in chunk])
+        return
+
+    def delete(self, *args, **kwargs):
+        pass
+
+    def query(self, *args, **kwargs):
+        return {
+            "matches": [
+                {
+                    "metadata": {
+                        "key": "value",
+                        "text": "text_1",
+                    },
+                    "score": 0.1,
+                },
+                {
+                    "metadata": {
+                        "key": "value",
+                        "text": "text_2",
+                    },
+                    "score": 0.2,
+                },
+            ]
+        }
+
+    def fetch(self, *args, **kwargs):
+        return {
+            "vectors": {
+                "key_1": {
+                    "metadata": {
+                        "source": "1",
+                    }
+                },
+                "key_2": {
+                    "metadata": {
+                        "source": "2",
+                    }
+                },
+            }
+        }
+
+    def describe_index_stats(self, *args, **kwargs):
+        return {"total_vector_count": len(self.db)}
+
+
+class MockPineconeClient:
+    def __init__(*args, **kwargs):
+        pass
+
+    def list_indexes(self):
+        return MockListIndexes()
+
+    def create_index(self, *args, **kwargs):
+        pass
+
+    def Index(self, *args, **kwargs):
+        return MockPineconeIndex()
+
+    def delete_index(self, *args, **kwargs):
+        pass
+
+
+class MockPinecone:
+    def __init__(*args, **kwargs):
+        pass
+
+    def Pinecone(*args, **kwargs):
+        return MockPineconeClient()
+
+    def PodSpec(*args, **kwargs):
+        pass
+
+    def ServerlessSpec(*args, **kwargs):
+        pass
+
+
+class MockEmbedder:
+    def embedding_fn(self, documents):
+        return [[1, 1, 1] for d in documents]
+
+
+def test_setup_pinecone_index(pinecone_pod_config, pinecone_serverless_config, monkeypatch):
+    monkeypatch.setattr("embedchain.vectordb.pinecone.pinecone", MockPinecone)
+    monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
+    pinecone_db = PineconeDB(config=pinecone_pod_config)
+    pinecone_db._setup_pinecone_index()
+    assert pinecone_db.client is not None
+    assert pinecone_db.config.index_name == "test-collection-3"
+    assert pinecone_db.client.list_indexes().names() == ["test_collection"]
+    assert pinecone_db.pinecone_index is not None
+
+    pinecone_db = PineconeDB(config=pinecone_serverless_config)
+    pinecone_db._setup_pinecone_index()
+    assert pinecone_db.client is not None
+    assert pinecone_db.config.index_name == "test-collection-3"
+    assert pinecone_db.client.list_indexes().names() == ["test_collection"]
+    assert pinecone_db.pinecone_index is not None
+
+
+def test_get(monkeypatch):
+    def mock_pinecone_db():
+        monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
+        monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
+        monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
+        db = PineconeDB()
+        db.pinecone_index = MockPineconeIndex()
+        return db
+
+    pinecone_db = mock_pinecone_db()
+    ids = pinecone_db.get(["key_1", "key_2"])
+    assert ids == {"ids": ["key_1", "key_2"], "metadatas": [{"source": "1"}, {"source": "2"}]}
+
+
+def test_add(monkeypatch):
+    def mock_pinecone_db():
+        monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
+        monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
+        monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
+        db = PineconeDB()
+        db.pinecone_index = MockPineconeIndex()
+        db._set_embedder(MockEmbedder())
+        return db
+
+    pinecone_db = mock_pinecone_db()
+    pinecone_db.add(["text_1", "text_2"], [{"key_1": "value_1"}, {"key_2": "value_2"}], ["key_1", "key_2"])
+    assert pinecone_db.count() == 2
+
+    pinecone_db.add(["text_3", "text_4"], [{"key_3": "value_3"}, {"key_4": "value_4"}], ["key_3", "key_4"])
+    assert pinecone_db.count() == 4
+
+
+def test_query(monkeypatch):
+    def mock_pinecone_db():
+        monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
+        monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
+        monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
+        db = PineconeDB()
+        db.pinecone_index = MockPineconeIndex()
+        db._set_embedder(MockEmbedder())
+        return db
+
+    pinecone_db = mock_pinecone_db()
+    # without citations
+    results = pinecone_db.query(["text_1", "text_2"], n_results=2, where={})
+    assert results == ["text_1", "text_2"]
+    # with citations
+    results = pinecone_db.query(["text_1", "text_2"], n_results=2, where={}, citations=True)
+    assert results == [
+        ("text_1", {"key": "value", "text": "text_1", "score": 0.1}),
+        ("text_2", {"key": "value", "text": "text_2", "score": 0.2}),
+    ]
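The rewritten tests lean on pytest's `monkeypatch` fixture instead of `unittest.mock.patch`: environment variables and attributes set through it are restored automatically at test teardown, which is why `PINECONE_API_KEY` can be set per test without leaking between tests. A generic, self-contained sketch of that pattern (not part of the changed files; the variable name is hypothetical):

```python
import os


def test_env_is_isolated(monkeypatch):
    # monkeypatch.setenv undoes itself when the test finishes
    monkeypatch.setenv("EXAMPLE_API_KEY", "test_value")  # hypothetical variable name
    assert os.environ["EXAMPLE_API_KEY"] == "test_value"


def test_env_not_leaked():
    # runs after the test above in file order; assumes EXAMPLE_API_KEY is not set in the outer shell
    assert "EXAMPLE_API_KEY" not in os.environ
```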