diff --git a/docs/components/vector-databases.mdx b/docs/components/vector-databases.mdx
index dbf86b40..80c77ab7 100644
--- a/docs/components/vector-databases.mdx
+++ b/docs/components/vector-databases.mdx
@@ -167,7 +167,7 @@ Install pinecone related dependencies using the following command:
pip install --upgrade 'embedchain[pinecone]'
```
-In order to use Pinecone as vector database, set the environment variables `PINECONE_API_KEY` and `PINECONE_ENV` which you can find on [Pinecone dashboard](https://app.pinecone.io/).
+In order to use Pinecone as a vector database, set the environment variable `PINECONE_API_KEY`, which you can find on the [Pinecone dashboard](https://app.pinecone.io/).
@@ -175,23 +175,44 @@ In order to use Pinecone as vector database, set the environment variables `PINE
from embedchain import App
# load pinecone configuration from yaml file
-app = App.from_config(config_path="config.yaml")
+app = App.from_config(config_path="pod_config.yaml")
+# or
+app = App.from_config(config_path="serverless_config.yaml")
```
-```yaml config.yaml
+```yaml pod_config.yaml
vectordb:
provider: pinecone
config:
metric: cosine
vector_dimension: 1536
collection_name: my-pinecone-index
+ pod_config:
+ environment: gcp-starter
+ metadata_config:
+ indexed:
+ - "url"
+ - "hash"
+```
+
+```yaml serverless_config.yaml
+vectordb:
+ provider: pinecone
+ config:
+ metric: cosine
+ vector_dimension: 1536
+ collection_name: my-pinecone-index
+ serverless_config:
+ cloud: aws
+ region: us-west-2
```
-You can optionally provide `index_name` as a config param in yaml file to specify the index name. If not provided, the index name will be `{collection_name}-{vector_dimension}`.
+You can find more information about Pinecone configuration [here](https://docs.pinecone.io/docs/manage-indexes#create-a-pod-based-index).
+You can also optionally provide `index_name` as a config param in yaml file to specify the index name. If not provided, the index name will be `{collection_name}-{vector_dimension}`.
## Qdrant
diff --git a/embedchain/config/vectordb/pinecone.py b/embedchain/config/vectordb/pinecone.py
index a07d3dd7..f377fcba 100644
--- a/embedchain/config/vectordb/pinecone.py
+++ b/embedchain/config/vectordb/pinecone.py
@@ -1,3 +1,4 @@
+import os
from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig
@@ -9,14 +10,29 @@ class PineconeDBConfig(BaseVectorDbConfig):
def __init__(
self,
collection_name: Optional[str] = None,
+ api_key: Optional[str] = None,
index_name: Optional[str] = None,
dir: Optional[str] = None,
vector_dimension: int = 1536,
metric: Optional[str] = "cosine",
+ pod_config: Optional[dict[str, any]] = None,
+ serverless_config: Optional[dict[str, any]] = None,
**extra_params: dict[str, any],
):
self.metric = metric
+ self.api_key = api_key
self.vector_dimension = vector_dimension
self.extra_params = extra_params
self.index_name = index_name or f"{collection_name}-{vector_dimension}".lower().replace("_", "-")
- super().__init__(collection_name=collection_name, dir=dir)
+ if pod_config is None and serverless_config is None:
+ # If no config is provided, use the default pod spec config
+ pod_environment = os.environ.get("PINECONE_ENV", "gcp-starter")
+ self.pod_config = {"environment": pod_environment, "metadata_config": {"indexed": ["*"]}}
+ else:
+ self.pod_config = pod_config
+ self.serverless_config = serverless_config
+
+ if self.pod_config and self.serverless_config:
+ raise ValueError("Only one of pod_config or serverless_config can be provided.")
+
+        super().__init__(collection_name=collection_name, dir=dir)
diff --git a/embedchain/vectordb/pinecone.py b/embedchain/vectordb/pinecone.py
index 92f3b911..e2f7c6fe 100644
--- a/embedchain/vectordb/pinecone.py
+++ b/embedchain/vectordb/pinecone.py
@@ -42,7 +42,7 @@ class PineconeDB(BaseVectorDB):
"Please make sure the type is right and that you are passing an instance."
)
self.config = config
- self.client = self._setup_pinecone_index()
+ self._setup_pinecone_index()
# Call parent init here because embedder is needed
super().__init__(config=self.config)
@@ -57,17 +57,26 @@ class PineconeDB(BaseVectorDB):
"""
Loads the Pinecone index or creates it if not present.
"""
- pinecone.init(
- api_key=os.environ.get("PINECONE_API_KEY"),
- environment=os.environ.get("PINECONE_ENV"),
- **self.config.extra_params,
- )
- indexes = pinecone.list_indexes()
+ api_key = self.config.api_key or os.environ.get("PINECONE_API_KEY")
+ if not api_key:
+ raise ValueError("Please set the PINECONE_API_KEY environment variable or pass it in config.")
+ self.client = pinecone.Pinecone(api_key=api_key, **self.config.extra_params)
+ indexes = self.client.list_indexes().names()
if indexes is None or self.config.index_name not in indexes:
- pinecone.create_index(
- name=self.config.index_name, metric=self.config.metric, dimension=self.config.vector_dimension
+ if self.config.pod_config:
+ spec = pinecone.PodSpec(**self.config.pod_config)
+ elif self.config.serverless_config:
+ spec = pinecone.ServerlessSpec(**self.config.serverless_config)
+ else:
+ raise ValueError("No pod_config or serverless_config found.")
+
+ self.client.create_index(
+ name=self.config.index_name,
+ metric=self.config.metric,
+ dimension=self.config.vector_dimension,
+ spec=spec,
)
- return pinecone.Index(self.config.index_name)
+ self.pinecone_index = self.client.Index(self.config.index_name)
def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
"""
@@ -85,7 +94,7 @@ class PineconeDB(BaseVectorDB):
if ids is not None:
for i in range(0, len(ids), 1000):
- result = self.client.fetch(ids=ids[i : i + 1000])
+ result = self.pinecone_index.fetch(ids=ids[i : i + 1000])
vectors = result.get("vectors")
batch_existing_ids = list(vectors.keys())
existing_ids.extend(batch_existing_ids)
@@ -125,7 +134,7 @@ class PineconeDB(BaseVectorDB):
)
for chunk in chunks(docs, self.BATCH_SIZE, desc="Adding chunks in batches"):
- self.client.upsert(chunk, **kwargs)
+ self.pinecone_index.upsert(chunk, **kwargs)
def query(
self,
@@ -151,15 +160,19 @@ class PineconeDB(BaseVectorDB):
"""
query_vector = self.embedder.embedding_fn([input_query])[0]
query_filter = self._generate_filter(where)
- data = self.client.query(
- vector=query_vector, filter=query_filter, top_k=n_results, include_metadata=True, **kwargs
+ data = self.pinecone_index.query(
+ vector=query_vector,
+ filter=query_filter,
+ top_k=n_results,
+ include_metadata=True,
+ **kwargs,
)
contexts = []
- for doc in data["matches"]:
- metadata = doc["metadata"]
- context = metadata["text"]
+ for doc in data.get("matches", []):
+ metadata = doc.get("metadata", {})
+ context = metadata.get("text")
if citations:
- metadata["score"] = doc["score"]
+ metadata["score"] = doc.get("score")
contexts.append(tuple((context, metadata)))
else:
contexts.append(context)
@@ -183,7 +196,8 @@ class PineconeDB(BaseVectorDB):
:return: number of documents
:rtype: int
"""
- return self.client.describe_index_stats()["total_vector_count"]
+ data = self.pinecone_index.describe_index_stats()
+ return data["total_vector_count"]
def _get_or_create_db(self):
"""Called during initialization"""
@@ -194,7 +208,7 @@ class PineconeDB(BaseVectorDB):
Resets the database. Deletes all embeddings irreversibly.
"""
# Delete all data from the database
- pinecone.delete_index(self.config.index_name)
+ self.client.delete_index(self.config.index_name)
self._setup_pinecone_index()
@staticmethod
@@ -213,7 +227,7 @@ class PineconeDB(BaseVectorDB):
# Follow `https://docs.pinecone.io/docs/metadata-filtering#deleting-vectors-by-metadata-filter` for more details
db_filter = self._generate_filter(where)
try:
- self.client.delete(filter=db_filter)
+ self.pinecone_index.delete(filter=db_filter)
except Exception as e:
print(f"Failed to delete from Pinecone: {e}")
return
diff --git a/poetry.lock b/poetry.lock
index 8935ff69..6422356b 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1260,25 +1260,6 @@ files = [
{file = "distro-1.8.0.tar.gz", hash = "sha256:02e111d1dc6a50abb8eed6bf31c3e48ed8b0830d1ea2a1b78c61765c2513fdd8"},
]
-[[package]]
-name = "dnspython"
-version = "2.4.2"
-description = "DNS toolkit"
-optional = true
-python-versions = ">=3.8,<4.0"
-files = [
- {file = "dnspython-2.4.2-py3-none-any.whl", hash = "sha256:57c6fbaaeaaf39c891292012060beb141791735dbb4004798328fc2c467402d8"},
- {file = "dnspython-2.4.2.tar.gz", hash = "sha256:8dcfae8c7460a2f84b4072e26f1c9f4101ca20c071649cb7c34e8b6a93d58984"},
-]
-
-[package.extras]
-dnssec = ["cryptography (>=2.6,<42.0)"]
-doh = ["h2 (>=4.1.0)", "httpcore (>=0.17.3)", "httpx (>=0.24.1)"]
-doq = ["aioquic (>=0.9.20)"]
-idna = ["idna (>=2.1,<4.0)"]
-trio = ["trio (>=0.14,<0.23)"]
-wmi = ["wmi (>=1.5.1,<2.0.0)"]
-
[[package]]
name = "docx2txt"
version = "0.8"
@@ -3204,24 +3185,6 @@ files = [
{file = "lit-17.0.2.tar.gz", hash = "sha256:d6a551eab550f81023c82a260cd484d63970d2be9fd7588111208e7d2ff62212"},
]
-[[package]]
-name = "loguru"
-version = "0.7.2"
-description = "Python logging made (stupidly) simple"
-optional = true
-python-versions = ">=3.5"
-files = [
- {file = "loguru-0.7.2-py3-none-any.whl", hash = "sha256:003d71e3d3ed35f0f8984898359d65b79e5b21943f78af86aa5491210429b8eb"},
- {file = "loguru-0.7.2.tar.gz", hash = "sha256:e671a53522515f34fd406340ee968cb9ecafbc4b36c679da03c18fd8d0bd51ac"},
-]
-
-[package.dependencies]
-colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""}
-win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""}
-
-[package.extras]
-dev = ["Sphinx (==7.2.5)", "colorama (==0.4.5)", "colorama (==0.4.6)", "exceptiongroup (==1.1.3)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v1.4.1)", "mypy (==v1.5.1)", "pre-commit (==3.4.0)", "pytest (==6.1.2)", "pytest (==7.4.0)", "pytest-cov (==2.12.1)", "pytest-cov (==4.1.0)", "pytest-mypy-plugins (==1.9.3)", "pytest-mypy-plugins (==3.0.0)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.3.0)", "tox (==3.27.1)", "tox (==4.11.0)"]
-
[[package]]
name = "lxml"
version = "4.9.3"
@@ -4680,25 +4643,20 @@ tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "pa
[[package]]
name = "pinecone-client"
-version = "2.2.4"
+version = "3.0.1"
description = "Pinecone client and SDK"
optional = true
-python-versions = ">=3.8"
+python-versions = ">=3.8,<3.13"
files = [
- {file = "pinecone-client-2.2.4.tar.gz", hash = "sha256:2c1cc1d6648b2be66e944db2ffa59166a37b9164d1135ad525d9cd8b1e298168"},
- {file = "pinecone_client-2.2.4-py3-none-any.whl", hash = "sha256:5bf496c01c2f82f4e5c2dc977cc5062ecd7168b8ed90743b09afcc8c7eb242ec"},
+ {file = "pinecone_client-3.0.1-py3-none-any.whl", hash = "sha256:c9bb21c23a9088c6198c839be5538ed3f733d152d5fbeaafcc020c1b70b62c2d"},
+ {file = "pinecone_client-3.0.1.tar.gz", hash = "sha256:626a0055852c88f1462fc2e132f21d2b078f9a0a74c70b17fe07df3081c6615f"},
]
[package.dependencies]
-dnspython = ">=2.0.0"
-loguru = ">=0.5.0"
-numpy = ">=1.22.0"
-python-dateutil = ">=2.5.3"
-pyyaml = ">=5.4"
-requests = ">=2.19.0"
+certifi = ">=2019.11.17"
tqdm = ">=4.64.1"
typing-extensions = ">=3.7.4"
-urllib3 = ">=1.21.1"
+urllib3 = ">=1.26.0"
[package.extras]
grpc = ["googleapis-common-protos (>=1.53.0)", "grpc-gateway-protoc-gen-openapiv2 (==0.1.0)", "grpcio (>=1.44.0)", "lz4 (>=3.1.3)", "protobuf (>=3.20.0,<3.21.0)"]
@@ -8019,20 +7977,6 @@ files = [
[package.extras]
test = ["pytest (>=6.0.0)", "setuptools (>=65)"]
-[[package]]
-name = "win32-setctime"
-version = "1.1.0"
-description = "A small Python utility to set file creation time on Windows"
-optional = true
-python-versions = ">=3.5"
-files = [
- {file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"},
- {file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"},
-]
-
-[package.extras]
-dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"]
-
[[package]]
name = "wrapt"
version = "1.15.0"
@@ -8316,4 +8260,4 @@ youtube = ["youtube-transcript-api", "yt_dlp"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<3.12"
-content-hash = "bbcf32e87c0784d031fb6cf9bd89655375839da0660b8feb2026ffdd971623d7"
+content-hash = "a16addd3362ae70c79b15677c6815f708677f11f636093f4e1f5084ba44b5a36"
diff --git a/pyproject.toml b/pyproject.toml
index 4af1da0c..7c015b27 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -123,7 +123,7 @@ cohere = { version = "^4.27", optional = true }
together = { version = "^0.2.8", optional = true }
weaviate-client = { version = "^3.24.1", optional = true }
docx2txt = { version = "^0.8", optional = true }
-pinecone-client = { version = "^2.2.4", optional = true }
+pinecone-client = { version = "^3.0.0", optional = true }
qdrant-client = { version = "1.6.3", optional = true }
unstructured = {extras = ["local-inference", "all-docs"], version = "^0.10.18", optional = true}
huggingface_hub = { version = "^0.17.3", optional = true }
diff --git a/tests/vectordb/test_pinecone.py b/tests/vectordb/test_pinecone.py
index 96869d2b..49319892 100644
--- a/tests/vectordb/test_pinecone.py
+++ b/tests/vectordb/test_pinecone.py
@@ -1,139 +1,225 @@
-from unittest import mock
-from unittest.mock import patch
+import pytest
-from embedchain import App
-from embedchain.config import AppConfig
from embedchain.config.vectordb.pinecone import PineconeDBConfig
-from embedchain.embedder.base import BaseEmbedder
from embedchain.vectordb.pinecone import PineconeDB
-class TestPinecone:
- @patch("embedchain.vectordb.pinecone.pinecone")
- def test_init(self, pinecone_mock):
- """Test that the PineconeDB can be initialized."""
- # Create a PineconeDB instance
- PineconeDB()
+@pytest.fixture
+def pinecone_pod_config():
+ return PineconeDBConfig(
+ collection_name="test_collection",
+ api_key="test_api_key",
+ vector_dimension=3,
+ pod_config={"environment": "test_environment", "metadata_config": {"indexed": ["*"]}},
+ )
- # Assert that the Pinecone client was initialized
- pinecone_mock.init.assert_called_once()
- pinecone_mock.list_indexes.assert_called_once()
- pinecone_mock.Index.assert_called_once()
- @patch("embedchain.vectordb.pinecone.pinecone")
- def test_set_embedder(self, pinecone_mock):
- """Test that the embedder can be set."""
+@pytest.fixture
+def pinecone_serverless_config():
+ return PineconeDBConfig(
+ collection_name="test_collection",
+ api_key="test_api_key",
+ vector_dimension=3,
+ serverless_config={
+ "cloud": "test_cloud",
+ "region": "test_region",
+ },
+ )
- # Set the embedder
- embedder = BaseEmbedder()
- # Create a PineconeDB instance
+def test_pinecone_init_without_config(monkeypatch):
+ monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
+ monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
+ monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
+ pinecone_db = PineconeDB()
+
+ assert isinstance(pinecone_db, PineconeDB)
+ assert isinstance(pinecone_db.config, PineconeDBConfig)
+ assert pinecone_db.config.pod_config == {"environment": "gcp-starter", "metadata_config": {"indexed": ["*"]}}
+ monkeypatch.delenv("PINECONE_API_KEY")
+
+
+def test_pinecone_init_with_config(pinecone_pod_config, pinecone_serverless_config, monkeypatch):
+ monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
+ monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
+ pinecone_db = PineconeDB(config=pinecone_pod_config)
+
+ assert isinstance(pinecone_db, PineconeDB)
+ assert isinstance(pinecone_db.config, PineconeDBConfig)
+
+ assert pinecone_db.config.pod_config == pinecone_pod_config.pod_config
+
+    pinecone_db = PineconeDB(config=pinecone_serverless_config)
+
+    assert isinstance(pinecone_db, PineconeDB)
+    assert isinstance(pinecone_db.config, PineconeDBConfig)
+
+    assert pinecone_db.config.serverless_config == pinecone_serverless_config.serverless_config
+
+
+class MockListIndexes:
+ def names(self):
+ return ["test_collection"]
+
+
+class MockPineconeIndex:
+    def __init__(self, *args, **kwargs):
+        # instance-level store so upserted state does not leak across tests
+        self.db = []
+
+
+ def upsert(self, chunk, **kwargs):
+ self.db.extend([c for c in chunk])
+ return
+
+ def delete(self, *args, **kwargs):
+ pass
+
+ def query(self, *args, **kwargs):
+ return {
+ "matches": [
+ {
+ "metadata": {
+ "key": "value",
+ "text": "text_1",
+ },
+ "score": 0.1,
+ },
+ {
+ "metadata": {
+ "key": "value",
+ "text": "text_2",
+ },
+ "score": 0.2,
+ },
+ ]
+ }
+
+ def fetch(self, *args, **kwargs):
+ return {
+ "vectors": {
+ "key_1": {
+ "metadata": {
+ "source": "1",
+ }
+ },
+ "key_2": {
+ "metadata": {
+ "source": "2",
+ }
+ },
+ }
+ }
+
+ def describe_index_stats(self, *args, **kwargs):
+ return {"total_vector_count": len(self.db)}
+
+
+class MockPineconeClient:
+ def __init__(*args, **kwargs):
+ pass
+
+ def list_indexes(self):
+ return MockListIndexes()
+
+ def create_index(self, *args, **kwargs):
+ pass
+
+ def Index(self, *args, **kwargs):
+ return MockPineconeIndex()
+
+ def delete_index(self, *args, **kwargs):
+ pass
+
+
+class MockPinecone:
+ def __init__(*args, **kwargs):
+ pass
+
+ def Pinecone(*args, **kwargs):
+ return MockPineconeClient()
+
+ def PodSpec(*args, **kwargs):
+ pass
+
+ def ServerlessSpec(*args, **kwargs):
+ pass
+
+
+class MockEmbedder:
+ def embedding_fn(self, documents):
+ return [[1, 1, 1] for d in documents]
+
+
+def test_setup_pinecone_index(pinecone_pod_config, pinecone_serverless_config, monkeypatch):
+ monkeypatch.setattr("embedchain.vectordb.pinecone.pinecone", MockPinecone)
+ monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
+ pinecone_db = PineconeDB(config=pinecone_pod_config)
+ pinecone_db._setup_pinecone_index()
+
+ assert pinecone_db.client is not None
+ assert pinecone_db.config.index_name == "test-collection-3"
+ assert pinecone_db.client.list_indexes().names() == ["test_collection"]
+ assert pinecone_db.pinecone_index is not None
+
+ pinecone_db = PineconeDB(config=pinecone_serverless_config)
+ pinecone_db._setup_pinecone_index()
+
+ assert pinecone_db.client is not None
+ assert pinecone_db.config.index_name == "test-collection-3"
+ assert pinecone_db.client.list_indexes().names() == ["test_collection"]
+ assert pinecone_db.pinecone_index is not None
+
+
+def test_get(monkeypatch):
+ def mock_pinecone_db():
+ monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
+ monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
+ monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
db = PineconeDB()
- app_config = AppConfig(collect_metrics=False)
- App(config=app_config, db=db, embedding_model=embedder)
+ db.pinecone_index = MockPineconeIndex()
+ return db
- # Assert that the embedder was set
- assert db.embedder == embedder
- pinecone_mock.init.assert_called_once()
+ pinecone_db = mock_pinecone_db()
+ ids = pinecone_db.get(["key_1", "key_2"])
+ assert ids == {"ids": ["key_1", "key_2"], "metadatas": [{"source": "1"}, {"source": "2"}]}
- @patch("embedchain.vectordb.pinecone.pinecone")
- def test_add_documents(self, pinecone_mock):
- """Test that documents can be added to the database."""
- pinecone_client_mock = pinecone_mock.Index.return_value
- embedding_function = mock.Mock()
- base_embedder = BaseEmbedder()
- base_embedder.set_embedding_fn(embedding_function)
- embedding_function.return_value = [[0, 0, 0], [1, 1, 1]]
-
- # Create a PineconeDb instance
+def test_add(monkeypatch):
+ def mock_pinecone_db():
+ monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
+ monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
+ monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
db = PineconeDB()
- app_config = AppConfig(collect_metrics=False)
- App(config=app_config, db=db, embedding_model=base_embedder)
+ db.pinecone_index = MockPineconeIndex()
+ db._set_embedder(MockEmbedder())
+ return db
- # Add some documents to the database
- documents = ["This is a document.", "This is another document."]
- metadatas = [{}, {}]
- ids = ["doc1", "doc2"]
- db.add(documents, metadatas, ids)
+ pinecone_db = mock_pinecone_db()
+ pinecone_db.add(["text_1", "text_2"], [{"key_1": "value_1"}, {"key_2": "value_2"}], ["key_1", "key_2"])
+ assert pinecone_db.count() == 2
- expected_pinecone_upsert_args = [
- {"id": "doc1", "values": [0, 0, 0], "metadata": {"text": "This is a document."}},
- {"id": "doc2", "values": [1, 1, 1], "metadata": {"text": "This is another document."}},
- ]
- # Assert that the Pinecone client was called to upsert the documents
- pinecone_client_mock.upsert.assert_called_once_with(tuple(expected_pinecone_upsert_args))
+ pinecone_db.add(["text_3", "text_4"], [{"key_3": "value_3"}, {"key_4": "value_4"}], ["key_3", "key_4"])
+ assert pinecone_db.count() == 4
- @patch("embedchain.vectordb.pinecone.pinecone")
- def test_query_documents(self, pinecone_mock):
- """Test that documents can be queried from the database."""
- pinecone_client_mock = pinecone_mock.Index.return_value
- embedding_function = mock.Mock()
- base_embedder = BaseEmbedder()
- base_embedder.set_embedding_fn(embedding_function)
- vectors = [[0, 0, 0]]
- embedding_function.return_value = vectors
- # Create a PineconeDB instance
+def test_query(monkeypatch):
+ def mock_pinecone_db():
+ monkeypatch.setenv("PINECONE_API_KEY", "test_api_key")
+ monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
+ monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
db = PineconeDB()
- app_config = AppConfig(collect_metrics=False)
- App(config=app_config, db=db, embedding_model=base_embedder)
+ db.pinecone_index = MockPineconeIndex()
+ db._set_embedder(MockEmbedder())
+ return db
- # Query the database for documents that are similar to "document"
- input_query = ["document"]
- n_results = 1
- db.query(input_query, n_results, where={})
-
- # Assert that the Pinecone client was called to query the database
- pinecone_client_mock.query.assert_called_once_with(
- vector=db.embedder.embedding_fn(input_query)[0], top_k=n_results, filter={}, include_metadata=True
- )
-
- @patch("embedchain.vectordb.pinecone.pinecone")
- def test_reset(self, pinecone_mock):
- """Test that the database can be reset."""
- # Create a PineconeDb instance
- db = PineconeDB()
- app_config = AppConfig(collect_metrics=False)
- App(config=app_config, db=db, embedding_model=BaseEmbedder())
-
- # Reset the database
- db.reset()
-
- # Assert that the Pinecone client was called to delete the index
- pinecone_mock.delete_index.assert_called_once_with(db.config.index_name)
-
- # Assert that the index is recreated
- pinecone_mock.Index.assert_called_with(db.config.index_name)
-
- @patch("embedchain.vectordb.pinecone.pinecone")
- def test_custom_index_name_if_it_exists(self, pinecone_mock):
- """Tests custom index name is used if it exists"""
- pinecone_mock.list_indexes.return_value = ["custom_index_name"]
- db_config = PineconeDBConfig(index_name="custom_index_name")
- _ = PineconeDB(config=db_config)
-
- pinecone_mock.list_indexes.assert_called_once()
- pinecone_mock.create_index.assert_not_called()
- pinecone_mock.Index.assert_called_with("custom_index_name")
-
- @patch("embedchain.vectordb.pinecone.pinecone")
- def test_custom_index_name_creation(self, pinecone_mock):
- """Test custom index name is created if it doesn't exists already"""
- pinecone_mock.list_indexes.return_value = []
- db_config = PineconeDBConfig(index_name="custom_index_name")
- _ = PineconeDB(config=db_config)
-
- pinecone_mock.list_indexes.assert_called_once()
- pinecone_mock.create_index.assert_called_once()
- pinecone_mock.Index.assert_called_with("custom_index_name")
-
- @patch("embedchain.vectordb.pinecone.pinecone")
- def test_default_index_name_is_used(self, pinecone_mock):
- """Test default index name is used if custom index name is not provided"""
- db_config = PineconeDBConfig(collection_name="my-collection")
- _ = PineconeDB(config=db_config)
-
- pinecone_mock.list_indexes.assert_called_once()
- pinecone_mock.create_index.assert_called_once()
- pinecone_mock.Index.assert_called_with(f"{db_config.collection_name}-{db_config.vector_dimension}")
+ pinecone_db = mock_pinecone_db()
+ # without citations
+ results = pinecone_db.query(["text_1", "text_2"], n_results=2, where={})
+ assert results == ["text_1", "text_2"]
+ # with citations
+ results = pinecone_db.query(["text_1", "text_2"], n_results=2, where={}, citations=True)
+ assert results == [
+ ("text_1", {"key": "value", "text": "text_1", "score": 0.1}),
+ ("text_2", {"key": "value", "text": "text_2", "score": 0.2}),
+ ]