[Improvement] update pinecone client v3 (#1200)
Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
@@ -167,7 +167,7 @@ Install pinecone related dependencies using the following command:
|
||||
pip install --upgrade 'embedchain[pinecone]'
|
||||
```
|
||||
|
||||
In order to use Pinecone as vector database, set the environment variables `PINECONE_API_KEY` and `PINECONE_ENV` which you can find on [Pinecone dashboard](https://app.pinecone.io/).
|
||||
To use Pinecone as the vector database, set the `PINECONE_API_KEY` environment variable to your API key, which you can find on the [Pinecone dashboard](https://app.pinecone.io/).
|
||||
|
||||
<CodeGroup>
|
||||
|
||||
@@ -175,23 +175,44 @@ In order to use Pinecone as vector database, set the environment variables `PINE
|
||||
from embedchain import App
|
||||
|
||||
# load pinecone configuration from yaml file
|
||||
app = App.from_config(config_path="config.yaml")
|
||||
app = App.from_config(config_path="pod_config.yaml")
|
||||
# or
|
||||
app = App.from_config(config_path="serverless_config.yaml")
|
||||
```
|
||||
|
||||
```yaml config.yaml
|
||||
```yaml pod_config.yaml
|
||||
vectordb:
|
||||
provider: pinecone
|
||||
config:
|
||||
metric: cosine
|
||||
vector_dimension: 1536
|
||||
collection_name: my-pinecone-index
|
||||
pod_config:
|
||||
environment: gcp-starter
|
||||
metadata_config:
|
||||
indexed:
|
||||
- "url"
|
||||
- "hash"
|
||||
```
|
||||
|
||||
```yaml serverless_config.yaml
|
||||
vectordb:
|
||||
provider: pinecone
|
||||
config:
|
||||
metric: cosine
|
||||
vector_dimension: 1536
|
||||
collection_name: my-pinecone-index
|
||||
serverless_config:
|
||||
cloud: aws
|
||||
region: us-west-2
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
<br />
|
||||
<Note>
|
||||
You can optionally provide `index_name` as a config param in yaml file to specify the index name. If not provided, the index name will be `{collection_name}-{vector_dimension}`.
|
||||
You can find more information about Pinecone configuration [here](https://docs.pinecone.io/docs/manage-indexes#create-a-pod-based-index).
|
||||
You can also optionally set `index_name` in the YAML config file to specify the index name. If it is not provided, the index name defaults to `{collection_name}-{vector_dimension}`, lowercased and with underscores replaced by hyphens.
|
||||
</Note>
|
||||
|
||||
## Qdrant
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
@@ -9,14 +10,29 @@ class PineconeDBConfig(BaseVectorDbConfig):
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
index_name: Optional[str] = None,
|
||||
dir: Optional[str] = None,
|
||||
vector_dimension: int = 1536,
|
||||
metric: Optional[str] = "cosine",
|
||||
pod_config: Optional[dict[str, any]] = None,
|
||||
serverless_config: Optional[dict[str, any]] = None,
|
||||
**extra_params: dict[str, any],
|
||||
):
|
||||
self.metric = metric
|
||||
self.api_key = api_key
|
||||
self.vector_dimension = vector_dimension
|
||||
self.extra_params = extra_params
|
||||
self.index_name = index_name or f"{collection_name}-{vector_dimension}".lower().replace("_", "-")
|
||||
super().__init__(collection_name=collection_name, dir=dir)
|
||||
if pod_config is None and serverless_config is None:
|
||||
# If no config is provided, use the default pod spec config
|
||||
pod_environment = os.environ.get("PINECONE_ENV", "gcp-starter")
|
||||
self.pod_config = {"environment": pod_environment, "metadata_config": {"indexed": ["*"]}}
|
||||
else:
|
||||
self.pod_config = pod_config
|
||||
self.serverless_config = serverless_config
|
||||
|
||||
if self.pod_config and self.serverless_config:
|
||||
raise ValueError("Only one of pod_config or serverless_config can be provided.")
|
||||
|
||||
super().__init__(collection_name=collection_name, dir=None)
|
||||
|
||||
@@ -42,7 +42,7 @@ class PineconeDB(BaseVectorDB):
|
||||
"Please make sure the type is right and that you are passing an instance."
|
||||
)
|
||||
self.config = config
|
||||
self.client = self._setup_pinecone_index()
|
||||
self._setup_pinecone_index()
|
||||
# Call parent init here because embedder is needed
|
||||
super().__init__(config=self.config)
|
||||
|
||||
@@ -57,17 +57,26 @@ class PineconeDB(BaseVectorDB):
|
||||
"""
|
||||
Loads the Pinecone index or creates it if not present.
|
||||
"""
|
||||
pinecone.init(
|
||||
api_key=os.environ.get("PINECONE_API_KEY"),
|
||||
environment=os.environ.get("PINECONE_ENV"),
|
||||
**self.config.extra_params,
|
||||
)
|
||||
indexes = pinecone.list_indexes()
|
||||
api_key = self.config.api_key or os.environ.get("PINECONE_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError("Please set the PINECONE_API_KEY environment variable or pass it in config.")
|
||||
self.client = pinecone.Pinecone(api_key=api_key, **self.config.extra_params)
|
||||
indexes = self.client.list_indexes().names()
|
||||
if indexes is None or self.config.index_name not in indexes:
|
||||
pinecone.create_index(
|
||||
name=self.config.index_name, metric=self.config.metric, dimension=self.config.vector_dimension
|
||||
if self.config.pod_config:
|
||||
spec = pinecone.PodSpec(**self.config.pod_config)
|
||||
elif self.config.serverless_config:
|
||||
spec = pinecone.ServerlessSpec(**self.config.serverless_config)
|
||||
else:
|
||||
raise ValueError("No pod_config or serverless_config found.")
|
||||
|
||||
self.client.create_index(
|
||||
name=self.config.index_name,
|
||||
metric=self.config.metric,
|
||||
dimension=self.config.vector_dimension,
|
||||
spec=spec,
|
||||
)
|
||||
return pinecone.Index(self.config.index_name)
|
||||
self.pinecone_index = self.client.Index(self.config.index_name)
|
||||
|
||||
def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
|
||||
"""
|
||||
@@ -85,7 +94,7 @@ class PineconeDB(BaseVectorDB):
|
||||
|
||||
if ids is not None:
|
||||
for i in range(0, len(ids), 1000):
|
||||
result = self.client.fetch(ids=ids[i : i + 1000])
|
||||
result = self.pinecone_index.fetch(ids=ids[i : i + 1000])
|
||||
vectors = result.get("vectors")
|
||||
batch_existing_ids = list(vectors.keys())
|
||||
existing_ids.extend(batch_existing_ids)
|
||||
@@ -125,7 +134,7 @@ class PineconeDB(BaseVectorDB):
|
||||
)
|
||||
|
||||
for chunk in chunks(docs, self.BATCH_SIZE, desc="Adding chunks in batches"):
|
||||
self.client.upsert(chunk, **kwargs)
|
||||
self.pinecone_index.upsert(chunk, **kwargs)
|
||||
|
||||
def query(
|
||||
self,
|
||||
@@ -151,15 +160,19 @@ class PineconeDB(BaseVectorDB):
|
||||
"""
|
||||
query_vector = self.embedder.embedding_fn([input_query])[0]
|
||||
query_filter = self._generate_filter(where)
|
||||
data = self.client.query(
|
||||
vector=query_vector, filter=query_filter, top_k=n_results, include_metadata=True, **kwargs
|
||||
data = self.pinecone_index.query(
|
||||
vector=query_vector,
|
||||
filter=query_filter,
|
||||
top_k=n_results,
|
||||
include_metadata=True,
|
||||
**kwargs,
|
||||
)
|
||||
contexts = []
|
||||
for doc in data["matches"]:
|
||||
metadata = doc["metadata"]
|
||||
context = metadata["text"]
|
||||
for doc in data.get("matches", []):
|
||||
metadata = doc.get("metadata", {})
|
||||
context = metadata.get("text")
|
||||
if citations:
|
||||
metadata["score"] = doc["score"]
|
||||
metadata["score"] = doc.get("score")
|
||||
contexts.append(tuple((context, metadata)))
|
||||
else:
|
||||
contexts.append(context)
|
||||
@@ -183,7 +196,8 @@ class PineconeDB(BaseVectorDB):
|
||||
:return: number of documents
|
||||
:rtype: int
|
||||
"""
|
||||
return self.client.describe_index_stats()["total_vector_count"]
|
||||
data = self.pinecone_index.describe_index_stats()
|
||||
return data["total_vector_count"]
|
||||
|
||||
def _get_or_create_db(self):
|
||||
"""Called during initialization"""
|
||||
@@ -194,7 +208,7 @@ class PineconeDB(BaseVectorDB):
|
||||
Resets the database. Deletes all embeddings irreversibly.
|
||||
"""
|
||||
# Delete all data from the database
|
||||
pinecone.delete_index(self.config.index_name)
|
||||
self.client.delete_index(self.config.index_name)
|
||||
self._setup_pinecone_index()
|
||||
|
||||
@staticmethod
|
||||
@@ -213,7 +227,7 @@ class PineconeDB(BaseVectorDB):
|
||||
# Follow `https://docs.pinecone.io/docs/metadata-filtering#deleting-vectors-by-metadata-filter` for more details
|
||||
db_filter = self._generate_filter(where)
|
||||
try:
|
||||
self.client.delete(filter=db_filter)
|
||||
self.pinecone_index.delete(filter=db_filter)
|
||||
except Exception as e:
|
||||
print(f"Failed to delete from Pinecone: {e}")
|
||||
return
|
||||
|
||||
70
poetry.lock
generated
70
poetry.lock
generated
@@ -1260,25 +1260,6 @@ files = [
|
||||
{file = "distro-1.8.0.tar.gz", hash = "sha256:02e111d1dc6a50abb8eed6bf31c3e48ed8b0830d1ea2a1b78c61765c2513fdd8"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dnspython"
|
||||
version = "2.4.2"
|
||||
description = "DNS toolkit"
|
||||
optional = true
|
||||
python-versions = ">=3.8,<4.0"
|
||||
files = [
|
||||
{file = "dnspython-2.4.2-py3-none-any.whl", hash = "sha256:57c6fbaaeaaf39c891292012060beb141791735dbb4004798328fc2c467402d8"},
|
||||
{file = "dnspython-2.4.2.tar.gz", hash = "sha256:8dcfae8c7460a2f84b4072e26f1c9f4101ca20c071649cb7c34e8b6a93d58984"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
dnssec = ["cryptography (>=2.6,<42.0)"]
|
||||
doh = ["h2 (>=4.1.0)", "httpcore (>=0.17.3)", "httpx (>=0.24.1)"]
|
||||
doq = ["aioquic (>=0.9.20)"]
|
||||
idna = ["idna (>=2.1,<4.0)"]
|
||||
trio = ["trio (>=0.14,<0.23)"]
|
||||
wmi = ["wmi (>=1.5.1,<2.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "docx2txt"
|
||||
version = "0.8"
|
||||
@@ -3204,24 +3185,6 @@ files = [
|
||||
{file = "lit-17.0.2.tar.gz", hash = "sha256:d6a551eab550f81023c82a260cd484d63970d2be9fd7588111208e7d2ff62212"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "loguru"
|
||||
version = "0.7.2"
|
||||
description = "Python logging made (stupidly) simple"
|
||||
optional = true
|
||||
python-versions = ">=3.5"
|
||||
files = [
|
||||
{file = "loguru-0.7.2-py3-none-any.whl", hash = "sha256:003d71e3d3ed35f0f8984898359d65b79e5b21943f78af86aa5491210429b8eb"},
|
||||
{file = "loguru-0.7.2.tar.gz", hash = "sha256:e671a53522515f34fd406340ee968cb9ecafbc4b36c679da03c18fd8d0bd51ac"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""}
|
||||
win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""}
|
||||
|
||||
[package.extras]
|
||||
dev = ["Sphinx (==7.2.5)", "colorama (==0.4.5)", "colorama (==0.4.6)", "exceptiongroup (==1.1.3)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v1.4.1)", "mypy (==v1.5.1)", "pre-commit (==3.4.0)", "pytest (==6.1.2)", "pytest (==7.4.0)", "pytest-cov (==2.12.1)", "pytest-cov (==4.1.0)", "pytest-mypy-plugins (==1.9.3)", "pytest-mypy-plugins (==3.0.0)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.3.0)", "tox (==3.27.1)", "tox (==4.11.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "lxml"
|
||||
version = "4.9.3"
|
||||
@@ -4680,25 +4643,20 @@ tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "pa
|
||||
|
||||
[[package]]
|
||||
name = "pinecone-client"
|
||||
version = "2.2.4"
|
||||
version = "3.0.1"
|
||||
description = "Pinecone client and SDK"
|
||||
optional = true
|
||||
python-versions = ">=3.8"
|
||||
python-versions = ">=3.8,<3.13"
|
||||
files = [
|
||||
{file = "pinecone-client-2.2.4.tar.gz", hash = "sha256:2c1cc1d6648b2be66e944db2ffa59166a37b9164d1135ad525d9cd8b1e298168"},
|
||||
{file = "pinecone_client-2.2.4-py3-none-any.whl", hash = "sha256:5bf496c01c2f82f4e5c2dc977cc5062ecd7168b8ed90743b09afcc8c7eb242ec"},
|
||||
{file = "pinecone_client-3.0.1-py3-none-any.whl", hash = "sha256:c9bb21c23a9088c6198c839be5538ed3f733d152d5fbeaafcc020c1b70b62c2d"},
|
||||
{file = "pinecone_client-3.0.1.tar.gz", hash = "sha256:626a0055852c88f1462fc2e132f21d2b078f9a0a74c70b17fe07df3081c6615f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
dnspython = ">=2.0.0"
|
||||
loguru = ">=0.5.0"
|
||||
numpy = ">=1.22.0"
|
||||
python-dateutil = ">=2.5.3"
|
||||
pyyaml = ">=5.4"
|
||||
requests = ">=2.19.0"
|
||||
certifi = ">=2019.11.17"
|
||||
tqdm = ">=4.64.1"
|
||||
typing-extensions = ">=3.7.4"
|
||||
urllib3 = ">=1.21.1"
|
||||
urllib3 = ">=1.26.0"
|
||||
|
||||
[package.extras]
|
||||
grpc = ["googleapis-common-protos (>=1.53.0)", "grpc-gateway-protoc-gen-openapiv2 (==0.1.0)", "grpcio (>=1.44.0)", "lz4 (>=3.1.3)", "protobuf (>=3.20.0,<3.21.0)"]
|
||||
@@ -8019,20 +7977,6 @@ files = [
|
||||
[package.extras]
|
||||
test = ["pytest (>=6.0.0)", "setuptools (>=65)"]
|
||||
|
||||
[[package]]
|
||||
name = "win32-setctime"
|
||||
version = "1.1.0"
|
||||
description = "A small Python utility to set file creation time on Windows"
|
||||
optional = true
|
||||
python-versions = ">=3.5"
|
||||
files = [
|
||||
{file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"},
|
||||
{file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"]
|
||||
|
||||
[[package]]
|
||||
name = "wrapt"
|
||||
version = "1.15.0"
|
||||
@@ -8316,4 +8260,4 @@ youtube = ["youtube-transcript-api", "yt_dlp"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.9,<3.12"
|
||||
content-hash = "bbcf32e87c0784d031fb6cf9bd89655375839da0660b8feb2026ffdd971623d7"
|
||||
content-hash = "a16addd3362ae70c79b15677c6815f708677f11f636093f4e1f5084ba44b5a36"
|
||||
|
||||
@@ -123,7 +123,7 @@ cohere = { version = "^4.27", optional = true }
|
||||
together = { version = "^0.2.8", optional = true }
|
||||
weaviate-client = { version = "^3.24.1", optional = true }
|
||||
docx2txt = { version = "^0.8", optional = true }
|
||||
pinecone-client = { version = "^2.2.4", optional = true }
|
||||
pinecone-client = { version = "^3.0.0", optional = true }
|
||||
qdrant-client = { version = "1.6.3", optional = true }
|
||||
unstructured = {extras = ["local-inference", "all-docs"], version = "^0.10.18", optional = true}
|
||||
huggingface_hub = { version = "^0.17.3", optional = true }
|
||||
|
||||
@@ -1,139 +1,225 @@
|
||||
from unittest import mock
|
||||
from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig
|
||||
from embedchain.config.vectordb.pinecone import PineconeDBConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.vectordb.pinecone import PineconeDB
|
||||
|
||||
|
||||
class TestPinecone:
|
||||
@patch("embedchain.vectordb.pinecone.pinecone")
|
||||
def test_init(self, pinecone_mock):
|
||||
"""Test that the PineconeDB can be initialized."""
|
||||
# Create a PineconeDB instance
|
||||
PineconeDB()
|
||||
@pytest.fixture
|
||||
def pinecone_pod_config():
|
||||
return PineconeDBConfig(
|
||||
collection_name="test_collection",
|
||||
api_key="test_api_key",
|
||||
vector_dimension=3,
|
||||
pod_config={"environment": "test_environment", "metadata_config": {"indexed": ["*"]}},
|
||||
)
|
||||
|
||||
# Assert that the Pinecone client was initialized
|
||||
pinecone_mock.init.assert_called_once()
|
||||
pinecone_mock.list_indexes.assert_called_once()
|
||||
pinecone_mock.Index.assert_called_once()
|
||||
|
||||
@patch("embedchain.vectordb.pinecone.pinecone")
|
||||
def test_set_embedder(self, pinecone_mock):
|
||||
"""Test that the embedder can be set."""
|
||||
@pytest.fixture
|
||||
def pinecone_serverless_config():
|
||||
return PineconeDBConfig(
|
||||
collection_name="test_collection",
|
||||
api_key="test_api_key",
|
||||
vector_dimension=3,
|
||||
serverless_config={
|
||||
"cloud": "test_cloud",
|
||||
"region": "test_region",
|
||||
},
|
||||
)
|
||||
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
|
||||
# Create a PineconeDB instance
|
||||
def test_pinecone_init_without_config(monkeypatch):
|
||||
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
|
||||
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
|
||||
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
|
||||
pinecone_db = PineconeDB()
|
||||
|
||||
assert isinstance(pinecone_db, PineconeDB)
|
||||
assert isinstance(pinecone_db.config, PineconeDBConfig)
|
||||
assert pinecone_db.config.pod_config == {"environment": "gcp-starter", "metadata_config": {"indexed": ["*"]}}
|
||||
monkeypatch.delenv("PINECONE_API_KEY")
|
||||
|
||||
|
||||
def test_pinecone_init_with_config(pinecone_pod_config, pinecone_serverless_config, monkeypatch):
    """Check that PineconeDB keeps the pod or serverless config it is given.

    ``_setup_pinecone_index`` and ``_get_or_create_db`` are replaced with
    pass-through lambdas so that constructing ``PineconeDB`` does not run
    the real index setup.
    """
    monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
    monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)

    # Pod-based configuration.
    pinecone_db = PineconeDB(config=pinecone_pod_config)

    assert isinstance(pinecone_db, PineconeDB)
    assert isinstance(pinecone_db.config, PineconeDBConfig)
    assert pinecone_db.config.pod_config == pinecone_pod_config.pod_config

    # Serverless configuration: build from the serverless fixture (not the pod
    # one again), otherwise the comparison below would be None == None.
    pinecone_db = PineconeDB(config=pinecone_serverless_config)

    assert isinstance(pinecone_db, PineconeDB)
    assert isinstance(pinecone_db.config, PineconeDBConfig)
    assert pinecone_db.config.serverless_config is not None
    assert pinecone_db.config.serverless_config == pinecone_serverless_config.serverless_config
|
||||
|
||||
|
||||
class MockListIndexes:
|
||||
def names(self):
|
||||
return ["test_collection"]
|
||||
|
||||
|
||||
class MockPineconeIndex:
    """In-memory stand-in for a Pinecone index, used by the tests.

    Upserted items are kept in a per-instance list. ``query`` and ``fetch``
    ignore their arguments and return fixed canned responses.
    """

    def __init__(self, *args, **kwargs):
        # Per-instance storage: a class-level list would be shared by every
        # mock index, letting one test's upserts leak into another's count.
        self.db = []

    def upsert(self, chunk, **kwargs):
        """Append every item of ``chunk`` to the in-memory store."""
        self.db.extend(chunk)

    def delete(self, *args, **kwargs):
        """Accept any arguments and do nothing."""

    def query(self, *args, **kwargs):
        """Return two canned matches, each with ``metadata`` and a ``score``."""
        return {
            "matches": [
                {
                    "metadata": {
                        "key": "value",
                        "text": "text_1",
                    },
                    "score": 0.1,
                },
                {
                    "metadata": {
                        "key": "value",
                        "text": "text_2",
                    },
                    "score": 0.2,
                },
            ]
        }

    def fetch(self, *args, **kwargs):
        """Return two canned vectors keyed ``key_1`` and ``key_2``."""
        return {
            "vectors": {
                "key_1": {
                    "metadata": {
                        "source": "1",
                    }
                },
                "key_2": {
                    "metadata": {
                        "source": "2",
                    }
                },
            }
        }

    def describe_index_stats(self, *args, **kwargs):
        """Report the number of items upserted into this instance."""
        return {"total_vector_count": len(self.db)}
|
||||
|
||||
|
||||
class MockPineconeClient:
|
||||
def __init__(*args, **kwargs):
|
||||
pass
|
||||
|
||||
def list_indexes(self):
|
||||
return MockListIndexes()
|
||||
|
||||
def create_index(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def Index(self, *args, **kwargs):
|
||||
return MockPineconeIndex()
|
||||
|
||||
def delete_index(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class MockPinecone:
|
||||
def __init__(*args, **kwargs):
|
||||
pass
|
||||
|
||||
def Pinecone(*args, **kwargs):
|
||||
return MockPineconeClient()
|
||||
|
||||
def PodSpec(*args, **kwargs):
|
||||
pass
|
||||
|
||||
def ServerlessSpec(*args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class MockEmbedder:
|
||||
def embedding_fn(self, documents):
|
||||
return [[1, 1, 1] for d in documents]
|
||||
|
||||
|
||||
def test_setup_pinecone_index(pinecone_pod_config, pinecone_serverless_config, monkeypatch):
|
||||
monkeypatch.setattr("embedchain.vectordb.pinecone.pinecone", MockPinecone)
|
||||
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
|
||||
pinecone_db = PineconeDB(config=pinecone_pod_config)
|
||||
pinecone_db._setup_pinecone_index()
|
||||
|
||||
assert pinecone_db.client is not None
|
||||
assert pinecone_db.config.index_name == "test-collection-3"
|
||||
assert pinecone_db.client.list_indexes().names() == ["test_collection"]
|
||||
assert pinecone_db.pinecone_index is not None
|
||||
|
||||
pinecone_db = PineconeDB(config=pinecone_serverless_config)
|
||||
pinecone_db._setup_pinecone_index()
|
||||
|
||||
assert pinecone_db.client is not None
|
||||
assert pinecone_db.config.index_name == "test-collection-3"
|
||||
assert pinecone_db.client.list_indexes().names() == ["test_collection"]
|
||||
assert pinecone_db.pinecone_index is not None
|
||||
|
||||
|
||||
def test_get(monkeypatch):
|
||||
def mock_pinecone_db():
|
||||
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
|
||||
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
|
||||
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
|
||||
db = PineconeDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
db.pinecone_index = MockPineconeIndex()
|
||||
return db
|
||||
|
||||
# Assert that the embedder was set
|
||||
assert db.embedder == embedder
|
||||
pinecone_mock.init.assert_called_once()
|
||||
pinecone_db = mock_pinecone_db()
|
||||
ids = pinecone_db.get(["key_1", "key_2"])
|
||||
assert ids == {"ids": ["key_1", "key_2"], "metadatas": [{"source": "1"}, {"source": "2"}]}
|
||||
|
||||
@patch("embedchain.vectordb.pinecone.pinecone")
|
||||
def test_add_documents(self, pinecone_mock):
|
||||
"""Test that documents can be added to the database."""
|
||||
pinecone_client_mock = pinecone_mock.Index.return_value
|
||||
|
||||
embedding_function = mock.Mock()
|
||||
base_embedder = BaseEmbedder()
|
||||
base_embedder.set_embedding_fn(embedding_function)
|
||||
embedding_function.return_value = [[0, 0, 0], [1, 1, 1]]
|
||||
|
||||
# Create a PineconeDb instance
|
||||
def test_add(monkeypatch):
|
||||
def mock_pinecone_db():
|
||||
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
|
||||
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
|
||||
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
|
||||
db = PineconeDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedding_model=base_embedder)
|
||||
db.pinecone_index = MockPineconeIndex()
|
||||
db._set_embedder(MockEmbedder())
|
||||
return db
|
||||
|
||||
# Add some documents to the database
|
||||
documents = ["This is a document.", "This is another document."]
|
||||
metadatas = [{}, {}]
|
||||
ids = ["doc1", "doc2"]
|
||||
db.add(documents, metadatas, ids)
|
||||
pinecone_db = mock_pinecone_db()
|
||||
pinecone_db.add(["text_1", "text_2"], [{"key_1": "value_1"}, {"key_2": "value_2"}], ["key_1", "key_2"])
|
||||
assert pinecone_db.count() == 2
|
||||
|
||||
expected_pinecone_upsert_args = [
|
||||
{"id": "doc1", "values": [0, 0, 0], "metadata": {"text": "This is a document."}},
|
||||
{"id": "doc2", "values": [1, 1, 1], "metadata": {"text": "This is another document."}},
|
||||
]
|
||||
# Assert that the Pinecone client was called to upsert the documents
|
||||
pinecone_client_mock.upsert.assert_called_once_with(tuple(expected_pinecone_upsert_args))
|
||||
pinecone_db.add(["text_3", "text_4"], [{"key_3": "value_3"}, {"key_4": "value_4"}], ["key_3", "key_4"])
|
||||
assert pinecone_db.count() == 4
|
||||
|
||||
@patch("embedchain.vectordb.pinecone.pinecone")
|
||||
def test_query_documents(self, pinecone_mock):
|
||||
"""Test that documents can be queried from the database."""
|
||||
pinecone_client_mock = pinecone_mock.Index.return_value
|
||||
|
||||
embedding_function = mock.Mock()
|
||||
base_embedder = BaseEmbedder()
|
||||
base_embedder.set_embedding_fn(embedding_function)
|
||||
vectors = [[0, 0, 0]]
|
||||
embedding_function.return_value = vectors
|
||||
# Create a PineconeDB instance
|
||||
def test_query(monkeypatch):
|
||||
def mock_pinecone_db():
|
||||
monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
|
||||
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
|
||||
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
|
||||
db = PineconeDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedding_model=base_embedder)
|
||||
db.pinecone_index = MockPineconeIndex()
|
||||
db._set_embedder(MockEmbedder())
|
||||
return db
|
||||
|
||||
# Query the database for documents that are similar to "document"
|
||||
input_query = ["document"]
|
||||
n_results = 1
|
||||
db.query(input_query, n_results, where={})
|
||||
|
||||
# Assert that the Pinecone client was called to query the database
|
||||
pinecone_client_mock.query.assert_called_once_with(
|
||||
vector=db.embedder.embedding_fn(input_query)[0], top_k=n_results, filter={}, include_metadata=True
|
||||
)
|
||||
|
||||
@patch("embedchain.vectordb.pinecone.pinecone")
|
||||
def test_reset(self, pinecone_mock):
|
||||
"""Test that the database can be reset."""
|
||||
# Create a PineconeDb instance
|
||||
db = PineconeDB()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedding_model=BaseEmbedder())
|
||||
|
||||
# Reset the database
|
||||
db.reset()
|
||||
|
||||
# Assert that the Pinecone client was called to delete the index
|
||||
pinecone_mock.delete_index.assert_called_once_with(db.config.index_name)
|
||||
|
||||
# Assert that the index is recreated
|
||||
pinecone_mock.Index.assert_called_with(db.config.index_name)
|
||||
|
||||
@patch("embedchain.vectordb.pinecone.pinecone")
|
||||
def test_custom_index_name_if_it_exists(self, pinecone_mock):
|
||||
"""Tests custom index name is used if it exists"""
|
||||
pinecone_mock.list_indexes.return_value = ["custom_index_name"]
|
||||
db_config = PineconeDBConfig(index_name="custom_index_name")
|
||||
_ = PineconeDB(config=db_config)
|
||||
|
||||
pinecone_mock.list_indexes.assert_called_once()
|
||||
pinecone_mock.create_index.assert_not_called()
|
||||
pinecone_mock.Index.assert_called_with("custom_index_name")
|
||||
|
||||
@patch("embedchain.vectordb.pinecone.pinecone")
|
||||
def test_custom_index_name_creation(self, pinecone_mock):
|
||||
"""Test custom index name is created if it doesn't exists already"""
|
||||
pinecone_mock.list_indexes.return_value = []
|
||||
db_config = PineconeDBConfig(index_name="custom_index_name")
|
||||
_ = PineconeDB(config=db_config)
|
||||
|
||||
pinecone_mock.list_indexes.assert_called_once()
|
||||
pinecone_mock.create_index.assert_called_once()
|
||||
pinecone_mock.Index.assert_called_with("custom_index_name")
|
||||
|
||||
@patch("embedchain.vectordb.pinecone.pinecone")
|
||||
def test_default_index_name_is_used(self, pinecone_mock):
|
||||
"""Test default index name is used if custom index name is not provided"""
|
||||
db_config = PineconeDBConfig(collection_name="my-collection")
|
||||
_ = PineconeDB(config=db_config)
|
||||
|
||||
pinecone_mock.list_indexes.assert_called_once()
|
||||
pinecone_mock.create_index.assert_called_once()
|
||||
pinecone_mock.Index.assert_called_with(f"{db_config.collection_name}-{db_config.vector_dimension}")
|
||||
pinecone_db = mock_pinecone_db()
|
||||
# without citations
|
||||
results = pinecone_db.query(["text_1", "text_2"], n_results=2, where={})
|
||||
assert results == ["text_1", "text_2"]
|
||||
# with citations
|
||||
results = pinecone_db.query(["text_1", "text_2"], n_results=2, where={}, citations=True)
|
||||
assert results == [
|
||||
("text_1", {"key": "value", "text": "text_1", "score": 0.1}),
|
||||
("text_2", {"key": "value", "text": "text_2", "score": 0.2}),
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user