From e75c05112e899d3986d991eae30d9f8003f0d2d9 Mon Sep 17 00:00:00 2001 From: Deven Patel Date: Fri, 26 Jan 2024 09:08:37 +0530 Subject: [PATCH] [Improvement] update pinecone client v3 (#1200) Co-authored-by: Deven Patel --- docs/components/vector-databases.mdx | 29 ++- embedchain/config/vectordb/pinecone.py | 18 +- embedchain/vectordb/pinecone.py | 56 +++-- poetry.lock | 70 +----- pyproject.toml | 2 +- tests/vectordb/test_pinecone.py | 324 ++++++++++++++++--------- 6 files changed, 290 insertions(+), 209 deletions(-) diff --git a/docs/components/vector-databases.mdx b/docs/components/vector-databases.mdx index dbf86b40..80c77ab7 100644 --- a/docs/components/vector-databases.mdx +++ b/docs/components/vector-databases.mdx @@ -167,7 +167,7 @@ Install pinecone related dependencies using the following command: pip install --upgrade 'embedchain[pinecone]' ``` -In order to use Pinecone as vector database, set the environment variables `PINECONE_API_KEY` and `PINECONE_ENV` which you can find on [Pinecone dashboard](https://app.pinecone.io/). +In order to use Pinecone as vector database, set the environment variable `PINECONE_API_KEY` which you can find on [Pinecone dashboard](https://app.pinecone.io/). @@ -175,23 +175,44 @@ In order to use Pinecone as vector database, set the environment variables `PINE from embedchain import App # load pinecone configuration from yaml file -app = App.from_config(config_path="config.yaml") +app = App.from_config(config_path="pod_config.yaml") +# or +app = App.from_config(config_path="serverless_config.yaml") ``` -```yaml config.yaml +```yaml pod_config.yaml vectordb: provider: pinecone config: metric: cosine vector_dimension: 1536 collection_name: my-pinecone-index + pod_config: + environment: gcp-starter + metadata_config: + indexed: + - "url" + - "hash" +``` + +```yaml serverless_config.yaml +vectordb: + provider: pinecone + config: + metric: cosine + vector_dimension: 1536 + collection_name: my-pinecone-index + serverless_config: + cloud: aws + region: us-west-2 ```
-You can optionally provide `index_name` as a config param in yaml file to specify the index name. If not provided, the index name will be `{collection_name}-{vector_dimension}`. +You can find more information about Pinecone configuration [here](https://docs.pinecone.io/docs/manage-indexes#create-a-pod-based-index). +You can also optionally provide `index_name` as a config param in yaml file to specify the index name. If not provided, the index name will be `{collection_name}-{vector_dimension}`. ## Qdrant diff --git a/embedchain/config/vectordb/pinecone.py b/embedchain/config/vectordb/pinecone.py index a07d3dd7..f377fcba 100644 --- a/embedchain/config/vectordb/pinecone.py +++ b/embedchain/config/vectordb/pinecone.py @@ -1,3 +1,4 @@ +import os from typing import Optional from embedchain.config.vectordb.base import BaseVectorDbConfig @@ -9,14 +10,29 @@ class PineconeDBConfig(BaseVectorDbConfig): def __init__( self, collection_name: Optional[str] = None, + api_key: Optional[str] = None, index_name: Optional[str] = None, dir: Optional[str] = None, vector_dimension: int = 1536, metric: Optional[str] = "cosine", + pod_config: Optional[dict[str, any]] = None, + serverless_config: Optional[dict[str, any]] = None, **extra_params: dict[str, any], ): self.metric = metric + self.api_key = api_key self.vector_dimension = vector_dimension self.extra_params = extra_params self.index_name = index_name or f"{collection_name}-{vector_dimension}".lower().replace("_", "-") - super().__init__(collection_name=collection_name, dir=dir) + if pod_config is None and serverless_config is None: + # If no config is provided, use the default pod spec config + pod_environment = os.environ.get("PINECONE_ENV", "gcp-starter") + self.pod_config = {"environment": pod_environment, "metadata_config": {"indexed": ["*"]}} + else: + self.pod_config = pod_config + self.serverless_config = serverless_config + + if self.pod_config and self.serverless_config: + raise ValueError("Only one of pod_config or serverless_config can be provided.") + + super().__init__(collection_name=collection_name, dir=None) diff --git a/embedchain/vectordb/pinecone.py b/embedchain/vectordb/pinecone.py index 92f3b911..e2f7c6fe 100644 --- a/embedchain/vectordb/pinecone.py +++ b/embedchain/vectordb/pinecone.py @@ -42,7 +42,7 @@ class PineconeDB(BaseVectorDB): "Please make sure the type is right and that you are passing an instance." ) self.config = config - self.client = self._setup_pinecone_index() + self._setup_pinecone_index() # Call parent init here because embedder is needed super().__init__(config=self.config) @@ -57,17 +57,26 @@ class PineconeDB(BaseVectorDB): """ Loads the Pinecone index or creates it if not present. """ - pinecone.init( - api_key=os.environ.get("PINECONE_API_KEY"), - environment=os.environ.get("PINECONE_ENV"), - **self.config.extra_params, - ) - indexes = pinecone.list_indexes() + api_key = self.config.api_key or os.environ.get("PINECONE_API_KEY") + if not api_key: + raise ValueError("Please set the PINECONE_API_KEY environment variable or pass it in config.") + self.client = pinecone.Pinecone(api_key=api_key, **self.config.extra_params) + indexes = self.client.list_indexes().names() if indexes is None or self.config.index_name not in indexes: - pinecone.create_index( - name=self.config.index_name, metric=self.config.metric, dimension=self.config.vector_dimension + if self.config.pod_config: + spec = pinecone.PodSpec(**self.config.pod_config) + elif self.config.serverless_config: + spec = pinecone.ServerlessSpec(**self.config.serverless_config) + else: + raise ValueError("No pod_config or serverless_config found.") + + self.client.create_index( + name=self.config.index_name, + metric=self.config.metric, + dimension=self.config.vector_dimension, + spec=spec, ) - return pinecone.Index(self.config.index_name) + self.pinecone_index = self.client.Index(self.config.index_name) def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None): """ @@ -85,7 +94,7 @@ class PineconeDB(BaseVectorDB): if ids is not None: for i in range(0, len(ids), 1000): - result = self.client.fetch(ids=ids[i : i + 1000]) + result = self.pinecone_index.fetch(ids=ids[i : i + 1000]) vectors = result.get("vectors") batch_existing_ids = list(vectors.keys()) existing_ids.extend(batch_existing_ids) @@ -125,7 +134,7 @@ class PineconeDB(BaseVectorDB): ) for chunk in chunks(docs, self.BATCH_SIZE, desc="Adding chunks in batches"): - self.client.upsert(chunk, **kwargs) + self.pinecone_index.upsert(chunk, **kwargs) def query( self, @@ -151,15 +160,19 @@ class PineconeDB(BaseVectorDB): """ query_vector = self.embedder.embedding_fn([input_query])[0] query_filter = self._generate_filter(where) - data = self.client.query( - vector=query_vector, filter=query_filter, top_k=n_results, include_metadata=True, **kwargs + data = self.pinecone_index.query( + vector=query_vector, + filter=query_filter, + top_k=n_results, + include_metadata=True, + **kwargs, ) contexts = [] - for doc in data["matches"]: - metadata = doc["metadata"] - context = metadata["text"] + for doc in data.get("matches", []): + metadata = doc.get("metadata", {}) + context = metadata.get("text") if citations: - metadata["score"] = doc["score"] + metadata["score"] = doc.get("score") contexts.append(tuple((context, metadata))) else: contexts.append(context) @@ -183,7 +196,8 @@ class PineconeDB(BaseVectorDB): :return: number of documents :rtype: int """ - return self.client.describe_index_stats()["total_vector_count"] + data = self.pinecone_index.describe_index_stats() + return data["total_vector_count"] def _get_or_create_db(self): """Called during initialization""" @@ -194,7 +208,7 @@ class PineconeDB(BaseVectorDB): Resets the database. Deletes all embeddings irreversibly. """ # Delete all data from the database - pinecone.delete_index(self.config.index_name) + self.client.delete_index(self.config.index_name) self._setup_pinecone_index() @staticmethod @@ -213,7 +227,7 @@ class PineconeDB(BaseVectorDB): # Follow `https://docs.pinecone.io/docs/metadata-filtering#deleting-vectors-by-metadata-filter` for more details db_filter = self._generate_filter(where) try: - self.client.delete(filter=db_filter) + self.pinecone_index.delete(filter=db_filter) except Exception as e: print(f"Failed to delete from Pinecone: {e}") return diff --git a/poetry.lock b/poetry.lock index 8935ff69..6422356b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1260,25 +1260,6 @@ files = [ {file = "distro-1.8.0.tar.gz", hash = "sha256:02e111d1dc6a50abb8eed6bf31c3e48ed8b0830d1ea2a1b78c61765c2513fdd8"}, ] -[[package]] -name = "dnspython" -version = "2.4.2" -description = "DNS toolkit" -optional = true -python-versions = ">=3.8,<4.0" -files = [ - {file = "dnspython-2.4.2-py3-none-any.whl", hash = "sha256:57c6fbaaeaaf39c891292012060beb141791735dbb4004798328fc2c467402d8"}, - {file = "dnspython-2.4.2.tar.gz", hash = "sha256:8dcfae8c7460a2f84b4072e26f1c9f4101ca20c071649cb7c34e8b6a93d58984"}, -] - -[package.extras] -dnssec = ["cryptography (>=2.6,<42.0)"] -doh = ["h2 (>=4.1.0)", "httpcore (>=0.17.3)", "httpx (>=0.24.1)"] -doq = ["aioquic (>=0.9.20)"] -idna = ["idna (>=2.1,<4.0)"] -trio = ["trio (>=0.14,<0.23)"] -wmi = ["wmi (>=1.5.1,<2.0.0)"] - [[package]] name = "docx2txt" version = "0.8" @@ -3204,24 +3185,6 @@ files = [ {file = "lit-17.0.2.tar.gz", hash = "sha256:d6a551eab550f81023c82a260cd484d63970d2be9fd7588111208e7d2ff62212"}, ] -[[package]] -name = "loguru" -version = "0.7.2" -description = "Python logging made (stupidly) simple" -optional = true -python-versions = ">=3.5" -files = [ - {file = "loguru-0.7.2-py3-none-any.whl", hash = "sha256:003d71e3d3ed35f0f8984898359d65b79e5b21943f78af86aa5491210429b8eb"}, - {file = "loguru-0.7.2.tar.gz", hash = "sha256:e671a53522515f34fd406340ee968cb9ecafbc4b36c679da03c18fd8d0bd51ac"}, -] - -[package.dependencies] -colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""} -win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""} - -[package.extras] -dev = ["Sphinx (==7.2.5)", "colorama (==0.4.5)", "colorama (==0.4.6)", "exceptiongroup (==1.1.3)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v1.4.1)", "mypy (==v1.5.1)", "pre-commit (==3.4.0)", "pytest (==6.1.2)", "pytest (==7.4.0)", "pytest-cov (==2.12.1)", "pytest-cov (==4.1.0)", "pytest-mypy-plugins (==1.9.3)", "pytest-mypy-plugins (==3.0.0)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.3.0)", "tox (==3.27.1)", "tox (==4.11.0)"] - [[package]] name = "lxml" version = "4.9.3" @@ -4680,25 +4643,20 @@ tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "pa [[package]] name = "pinecone-client" -version = "2.2.4" +version = "3.0.1" description = "Pinecone client and SDK" optional = true -python-versions = ">=3.8" +python-versions = ">=3.8,<3.13" files = [ - {file = "pinecone-client-2.2.4.tar.gz", hash = "sha256:2c1cc1d6648b2be66e944db2ffa59166a37b9164d1135ad525d9cd8b1e298168"}, - {file = "pinecone_client-2.2.4-py3-none-any.whl", hash = "sha256:5bf496c01c2f82f4e5c2dc977cc5062ecd7168b8ed90743b09afcc8c7eb242ec"}, + {file = "pinecone_client-3.0.1-py3-none-any.whl", hash = "sha256:c9bb21c23a9088c6198c839be5538ed3f733d152d5fbeaafcc020c1b70b62c2d"}, + {file = "pinecone_client-3.0.1.tar.gz", hash = "sha256:626a0055852c88f1462fc2e132f21d2b078f9a0a74c70b17fe07df3081c6615f"}, ] [package.dependencies] -dnspython = ">=2.0.0" -loguru = ">=0.5.0" -numpy = ">=1.22.0" -python-dateutil = ">=2.5.3" -pyyaml = ">=5.4" -requests = ">=2.19.0" +certifi = ">=2019.11.17" tqdm = ">=4.64.1" typing-extensions = ">=3.7.4" -urllib3 = ">=1.21.1" +urllib3 = ">=1.26.0" [package.extras] grpc = ["googleapis-common-protos (>=1.53.0)", "grpc-gateway-protoc-gen-openapiv2 (==0.1.0)", "grpcio (>=1.44.0)", "lz4 (>=3.1.3)", "protobuf (>=3.20.0,<3.21.0)"] @@ -8019,20 +7977,6 @@ files = [ [package.extras] test = ["pytest (>=6.0.0)", "setuptools (>=65)"] -[[package]] -name = "win32-setctime" -version = "1.1.0" -description = "A small Python utility to set file creation time on Windows" -optional = true -python-versions = ">=3.5" -files = [ - {file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"}, - {file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"}, -] - -[package.extras] -dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] - [[package]] name = "wrapt" version = "1.15.0" @@ -8316,4 +8260,4 @@ youtube = ["youtube-transcript-api", "yt_dlp"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.12" -content-hash = "bbcf32e87c0784d031fb6cf9bd89655375839da0660b8feb2026ffdd971623d7" +content-hash = "a16addd3362ae70c79b15677c6815f708677f11f636093f4e1f5084ba44b5a36" diff --git a/pyproject.toml b/pyproject.toml index 4af1da0c..7c015b27 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,7 +123,7 @@ cohere = { version = "^4.27", optional = true } together = { version = "^0.2.8", optional = true } weaviate-client = { version = "^3.24.1", optional = true } docx2txt = { version = "^0.8", optional = true } -pinecone-client = { version = "^2.2.4", optional = true } +pinecone-client = { version = "^3.0.0", optional = true } qdrant-client = { version = "1.6.3", optional = true } unstructured = {extras = ["local-inference", "all-docs"], version = "^0.10.18", optional = true} huggingface_hub = { version = "^0.17.3", optional = true } diff --git a/tests/vectordb/test_pinecone.py b/tests/vectordb/test_pinecone.py index 96869d2b..49319892 100644 --- a/tests/vectordb/test_pinecone.py +++ b/tests/vectordb/test_pinecone.py @@ -1,139 +1,225 @@ -from unittest import mock -from unittest.mock import patch +import pytest -from embedchain import App -from embedchain.config import AppConfig from embedchain.config.vectordb.pinecone import PineconeDBConfig -from embedchain.embedder.base import BaseEmbedder from embedchain.vectordb.pinecone import PineconeDB -class TestPinecone: - @patch("embedchain.vectordb.pinecone.pinecone") - def test_init(self, pinecone_mock): - """Test that the PineconeDB can be initialized.""" - # Create a PineconeDB instance - PineconeDB() +@pytest.fixture +def pinecone_pod_config(): + return PineconeDBConfig( + collection_name="test_collection", + api_key="test_api_key", + vector_dimension=3, + pod_config={"environment": "test_environment", "metadata_config": {"indexed": ["*"]}}, + ) - # Assert that the Pinecone client was initialized - pinecone_mock.init.assert_called_once() - pinecone_mock.list_indexes.assert_called_once() - pinecone_mock.Index.assert_called_once() - @patch("embedchain.vectordb.pinecone.pinecone") - def test_set_embedder(self, pinecone_mock): - """Test that the embedder can be set.""" +@pytest.fixture +def pinecone_serverless_config(): + return PineconeDBConfig( + collection_name="test_collection", + api_key="test_api_key", + vector_dimension=3, + serverless_config={ + "cloud": "test_cloud", + "region": "test_region", + }, + ) - # Set the embedder - embedder = BaseEmbedder() - # Create a PineconeDB instance +def test_pinecone_init_without_config(monkeypatch): + monkeypatch.setenv("PINECONE_API_KEY", "test_api_key") + monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x) + monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x) + pinecone_db = PineconeDB() + + assert isinstance(pinecone_db, PineconeDB) + assert isinstance(pinecone_db.config, PineconeDBConfig) + assert pinecone_db.config.pod_config == {"environment": "gcp-starter", "metadata_config": {"indexed": ["*"]}} + monkeypatch.delenv("PINECONE_API_KEY") + + +def test_pinecone_init_with_config(pinecone_pod_config, pinecone_serverless_config, monkeypatch): + monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x) + monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x) + pinecone_db = PineconeDB(config=pinecone_pod_config) + + assert isinstance(pinecone_db, PineconeDB) + assert isinstance(pinecone_db.config, PineconeDBConfig) + + assert pinecone_db.config.pod_config == pinecone_pod_config.pod_config + + pinecone_db = PineconeDB(config=pinecone_pod_config) + + assert isinstance(pinecone_db, PineconeDB) + assert isinstance(pinecone_db.config, PineconeDBConfig) + + assert pinecone_db.config.serverless_config == pinecone_pod_config.serverless_config + + +class MockListIndexes: + def names(self): + return ["test_collection"] + + +class MockPineconeIndex: + db = [] + + def __init__(*args, **kwargs): + pass + + def upsert(self, chunk, **kwargs): + self.db.extend([c for c in chunk]) + return + + def delete(self, *args, **kwargs): + pass + + def query(self, *args, **kwargs): + return { + "matches": [ + { + "metadata": { + "key": "value", + "text": "text_1", + }, + "score": 0.1, + }, + { + "metadata": { + "key": "value", + "text": "text_2", + }, + "score": 0.2, + }, + ] + } + + def fetch(self, *args, **kwargs): + return { + "vectors": { + "key_1": { + "metadata": { + "source": "1", + } + }, + "key_2": { + "metadata": { + "source": "2", + } + }, + } + } + + def describe_index_stats(self, *args, **kwargs): + return {"total_vector_count": len(self.db)} + + +class MockPineconeClient: + def __init__(*args, **kwargs): + pass + + def list_indexes(self): + return MockListIndexes() + + def create_index(self, *args, **kwargs): + pass + + def Index(self, *args, **kwargs): + return MockPineconeIndex() + + def delete_index(self, *args, **kwargs): + pass + + +class MockPinecone: + def __init__(*args, **kwargs): + pass + + def Pinecone(*args, **kwargs): + return MockPineconeClient() + + def PodSpec(*args, **kwargs): + pass + + def ServerlessSpec(*args, **kwargs): + pass + + +class MockEmbedder: + def embedding_fn(self, documents): + return [[1, 1, 1] for d in documents] + + +def test_setup_pinecone_index(pinecone_pod_config, pinecone_serverless_config, monkeypatch): + monkeypatch.setattr("embedchain.vectordb.pinecone.pinecone", MockPinecone) + monkeypatch.setenv("PINECONE_API_KEY", "test_api_key") + pinecone_db = PineconeDB(config=pinecone_pod_config) + pinecone_db._setup_pinecone_index() + + assert pinecone_db.client is not None + assert pinecone_db.config.index_name == "test-collection-3" + assert pinecone_db.client.list_indexes().names() == ["test_collection"] + assert pinecone_db.pinecone_index is not None + + pinecone_db = PineconeDB(config=pinecone_serverless_config) + pinecone_db._setup_pinecone_index() + + assert pinecone_db.client is not None + assert pinecone_db.config.index_name == "test-collection-3" + assert pinecone_db.client.list_indexes().names() == ["test_collection"] + assert pinecone_db.pinecone_index is not None + + +def test_get(monkeypatch): + def mock_pinecone_db(): + monkeypatch.setenv("PINECONE_API_KEY", "test_api_key") + monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x) + monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x) db = PineconeDB() - app_config = AppConfig(collect_metrics=False) - App(config=app_config, db=db, embedding_model=embedder) + db.pinecone_index = MockPineconeIndex() + return db - # Assert that the embedder was set - assert db.embedder == embedder - pinecone_mock.init.assert_called_once() + pinecone_db = mock_pinecone_db() + ids = pinecone_db.get(["key_1", "key_2"]) + assert ids == {"ids": ["key_1", "key_2"], "metadatas": [{"source": "1"}, {"source": "2"}]} - @patch("embedchain.vectordb.pinecone.pinecone") - def test_add_documents(self, pinecone_mock): - """Test that documents can be added to the database.""" - pinecone_client_mock = pinecone_mock.Index.return_value - embedding_function = mock.Mock() - base_embedder = BaseEmbedder() - base_embedder.set_embedding_fn(embedding_function) - embedding_function.return_value = [[0, 0, 0], [1, 1, 1]] - - # Create a PineconeDb instance +def test_add(monkeypatch): + def mock_pinecone_db(): + monkeypatch.setenv("PINECONE_API_KEY", "test_api_key") + monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x) + monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x) db = PineconeDB() - app_config = AppConfig(collect_metrics=False) - App(config=app_config, db=db, embedding_model=base_embedder) + db.pinecone_index = MockPineconeIndex() + db._set_embedder(MockEmbedder()) + return db - # Add some documents to the database - documents = ["This is a document.", "This is another document."] - metadatas = [{}, {}] - ids = ["doc1", "doc2"] - db.add(documents, metadatas, ids) + pinecone_db = mock_pinecone_db() + pinecone_db.add(["text_1", "text_2"], [{"key_1": "value_1"}, {"key_2": "value_2"}], ["key_1", "key_2"]) + assert pinecone_db.count() == 2 - expected_pinecone_upsert_args = [ - {"id": "doc1", "values": [0, 0, 0], "metadata": {"text": "This is a document."}}, - {"id": "doc2", "values": [1, 1, 1], "metadata": {"text": "This is another document."}}, - ] - # Assert that the Pinecone client was called to upsert the documents - pinecone_client_mock.upsert.assert_called_once_with(tuple(expected_pinecone_upsert_args)) + pinecone_db.add(["text_3", "text_4"], [{"key_3": "value_3"}, {"key_4": "value_4"}], ["key_3", "key_4"]) + assert pinecone_db.count() == 4 - @patch("embedchain.vectordb.pinecone.pinecone") - def test_query_documents(self, pinecone_mock): - """Test that documents can be queried from the database.""" - pinecone_client_mock = pinecone_mock.Index.return_value - embedding_function = mock.Mock() - base_embedder = BaseEmbedder() - base_embedder.set_embedding_fn(embedding_function) - vectors = [[0, 0, 0]] - embedding_function.return_value = vectors - # Create a PineconeDB instance +def test_query(monkeypatch): + def mock_pinecone_db(): + monkeypatch.setenv("PINECONE_API_KEY", "test_api_key") + monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x) + monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x) db = PineconeDB() - app_config = AppConfig(collect_metrics=False) - App(config=app_config, db=db, embedding_model=base_embedder) + db.pinecone_index = MockPineconeIndex() + db._set_embedder(MockEmbedder()) + return db - # Query the database for documents that are similar to "document" - input_query = ["document"] - n_results = 1 - db.query(input_query, n_results, where={}) - - # Assert that the Pinecone client was called to query the database - pinecone_client_mock.query.assert_called_once_with( - vector=db.embedder.embedding_fn(input_query)[0], top_k=n_results, filter={}, include_metadata=True - ) - - @patch("embedchain.vectordb.pinecone.pinecone") - def test_reset(self, pinecone_mock): - """Test that the database can be reset.""" - # Create a PineconeDb instance - db = PineconeDB() - app_config = AppConfig(collect_metrics=False) - App(config=app_config, db=db, embedding_model=BaseEmbedder()) - - # Reset the database - db.reset() - - # Assert that the Pinecone client was called to delete the index - pinecone_mock.delete_index.assert_called_once_with(db.config.index_name) - - # Assert that the index is recreated - pinecone_mock.Index.assert_called_with(db.config.index_name) - - @patch("embedchain.vectordb.pinecone.pinecone") - def test_custom_index_name_if_it_exists(self, pinecone_mock): - """Tests custom index name is used if it exists""" - pinecone_mock.list_indexes.return_value = ["custom_index_name"] - db_config = PineconeDBConfig(index_name="custom_index_name") - _ = PineconeDB(config=db_config) - - pinecone_mock.list_indexes.assert_called_once() - pinecone_mock.create_index.assert_not_called() - pinecone_mock.Index.assert_called_with("custom_index_name") - - @patch("embedchain.vectordb.pinecone.pinecone") - def test_custom_index_name_creation(self, pinecone_mock): - """Test custom index name is created if it doesn't exists already""" - pinecone_mock.list_indexes.return_value = [] - db_config = PineconeDBConfig(index_name="custom_index_name") - _ = PineconeDB(config=db_config) - - pinecone_mock.list_indexes.assert_called_once() - pinecone_mock.create_index.assert_called_once() - pinecone_mock.Index.assert_called_with("custom_index_name") - - @patch("embedchain.vectordb.pinecone.pinecone") - def test_default_index_name_is_used(self, pinecone_mock): - """Test default index name is used if custom index name is not provided""" - db_config = PineconeDBConfig(collection_name="my-collection") - _ = PineconeDB(config=db_config) - - pinecone_mock.list_indexes.assert_called_once() - pinecone_mock.create_index.assert_called_once() - pinecone_mock.Index.assert_called_with(f"{db_config.collection_name}-{db_config.vector_dimension}") + pinecone_db = mock_pinecone_db() + # without citations + results = pinecone_db.query(["text_1", "text_2"], n_results=2, where={}) + assert results == ["text_1", "text_2"] + # with citations + results = pinecone_db.query(["text_1", "text_2"], n_results=2, where={}, citations=True) + assert results == [ + ("text_1", {"key": "value", "text": "text_1", "score": 0.1}), + ("text_2", {"key": "value", "text": "text_2", "score": 0.2}), + ]