From a7a61fae1d654f54cdb4895ac6f61699129ac335 Mon Sep 17 00:00:00 2001
From: Rupesh Bansal
Date: Sun, 15 Oct 2023 14:24:07 +0530
Subject: [PATCH] [Feature] Pinecone Vector DB support (#723)

---
 embedchain/config/vectordb/pinecone.py |  18 +++
 embedchain/embedchain.py               |   2 -
 embedchain/vectordb/elasticsearch.py   |   6 +-
 embedchain/vectordb/pineconedb.py      | 180 +++++++++++++++++++++++++
 pyproject.toml                         |   2 +
 tests/vectordb/test_pinecone_db.py     | 106 +++++++++++++++
 6 files changed, 308 insertions(+), 6 deletions(-)
 create mode 100644 embedchain/config/vectordb/pinecone.py
 create mode 100644 embedchain/vectordb/pineconedb.py
 create mode 100644 tests/vectordb/test_pinecone_db.py

diff --git a/embedchain/config/vectordb/pinecone.py b/embedchain/config/vectordb/pinecone.py
new file mode 100644
index 00000000..2e1334e3
--- /dev/null
+++ b/embedchain/config/vectordb/pinecone.py
@@ -0,0 +1,18 @@
+from typing import Optional
+
+from embedchain.config.vectordb.base import BaseVectorDbConfig
+from embedchain.helper.json_serializable import register_deserializable
+
+
+@register_deserializable
+class PineconeDbConfig(BaseVectorDbConfig):
+    def __init__(
+        self,
+        collection_name: Optional[str] = None,
+        dir: Optional[str] = None,
+        dimension: Optional[int] = 1536,
+        metric: Optional[str] = "cosine",
+    ):
+        self.dimension = dimension
+        self.metric = metric
+        super().__init__(collection_name=collection_name, dir=dir)
diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index ccb61511..0a52c168 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -339,7 +339,6 @@ class EmbedChain(JSONSerializable):
         metadatas = embeddings_data["metadatas"]
         ids = embeddings_data["ids"]
         new_doc_id = embeddings_data["doc_id"]
-
         if existing_doc_id and existing_doc_id == new_doc_id:
             print("Doc content has not changed. Skipping creating chunks and embeddings")
             return [], [], [], 0
@@ -404,7 +403,6 @@ class EmbedChain(JSONSerializable):
             skip_embedding=(chunker.data_type == DataType.IMAGES),
         )
         count_new_chunks = self.db.count() - chunks_before_addition
-
         print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
         return list(documents), metadatas, ids, count_new_chunks

     def _format_result(self, results):
diff --git a/embedchain/vectordb/elasticsearch.py b/embedchain/vectordb/elasticsearch.py
index 02b45ae5..f3b2d293 100644
--- a/embedchain/vectordb/elasticsearch.py
+++ b/embedchain/vectordb/elasticsearch.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, List, Optional, Set
+from typing import Any, Dict, List, Optional

 try:
     from elasticsearch import Elasticsearch
@@ -74,9 +74,7 @@ class ElasticsearchDB(BaseVectorDB):
     def _get_or_create_collection(self, name):
         """Note: nothing to return here. Discuss later"""

-    def get(
-        self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None
-    ) -> Set[str]:
+    def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
         """
         Get existing doc ids present in vector database
diff --git a/embedchain/vectordb/pineconedb.py b/embedchain/vectordb/pineconedb.py
new file mode 100644
index 00000000..2aa12f1b
--- /dev/null
+++ b/embedchain/vectordb/pineconedb.py
@@ -0,0 +1,180 @@
+import copy
+import os
+from typing import Dict, List, Optional
+
+try:
+    import pinecone
+except ImportError:
+    raise ImportError(
+        "Pinecone requires extra dependencies. Install with `pip install --upgrade embedchain[pinecone]`"
+    ) from None
+
+from embedchain.config.vectordb.pinecone import PineconeDbConfig
+from embedchain.helper.json_serializable import register_deserializable
+from embedchain.vectordb.base import BaseVectorDB
+
+
+@register_deserializable
+class PineconeDb(BaseVectorDB):
+    """
+    Pinecone as vector database
+    """
+
+    BATCH_SIZE = 100
+
+    def __init__(
+        self,
+        config: Optional[PineconeDbConfig] = None,
+    ):
+        """Pinecone as vector database.
+
+        :param config: Pinecone database config, defaults to None
+        :type config: PineconeDbConfig, optional
+        :raises TypeError: config is not a `PineconeDbConfig` instance
+        """
+        if config is None:
+            self.config = PineconeDbConfig()
+        else:
+            if not isinstance(config, PineconeDbConfig):
+                raise TypeError(
+                    "config is not a `PineconeDbConfig` instance. "
+                    "Please make sure the type is right and that you are passing an instance."
+                )
+            self.config = config
+        self.client = self._setup_pinecone_index()
+        # Call parent init here because embedder is needed
+        super().__init__(config=self.config)
+
+    def _initialize(self):
+        """
+        This method is needed because the `embedder` attribute needs to be set externally before it can be initialized.
+        """
+        if not self.embedder:
+            raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
+
+    # Loads the Pinecone index or creates it if not present.
+    def _setup_pinecone_index(self):
+        pinecone.init(
+            api_key=os.environ.get("PINECONE_API_KEY"),
+            environment=os.environ.get("PINECONE_ENV"),
+        )
+        self.index_name = self._get_index_name()
+        indexes = pinecone.list_indexes()
+        if indexes is None or self.index_name not in indexes:
+            pinecone.create_index(name=self.index_name, metric=self.config.metric, dimension=self.config.dimension)
+        return pinecone.Index(self.index_name)
+
+    def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
+        """
+        Get existing doc ids present in vector database
+
+        :param ids: list of doc ids to check for existence
+        :type ids: List[str]
+        :param where: to filter data
+        :type where: Dict[str, any]
+        :return: dict with the list of existing doc ids under the "ids" key
+        :rtype: Dict[str, List[str]]
+        """
+        existing_ids = list()
+        if ids is not None:
+            for i in range(0, len(ids), 1000):
+                result = self.client.fetch(ids=ids[i : i + 1000])
+                batch_existing_ids = list(result.get("vectors").keys())
+                existing_ids.extend(batch_existing_ids)

+        return {"ids": existing_ids}
+
+    def add(
+        self,
+        embeddings: List[List[float]],
+        documents: List[str],
+        metadatas: List[object],
+        ids: List[str],
+        skip_embedding: bool,
+    ):
+        """Add data to the vector database
+
+        :param documents: list of texts to add
+        :type documents: List[str]
+        :param metadatas: list of metadata associated with docs
+        :type metadatas: List[object]
+        :param ids: ids of docs
+        :type ids: List[str]
+        """
+        docs = []
+        if embeddings is None:
+            embeddings = self.embedder.embedding_fn(documents)
+        for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
+            metadata["text"] = text
+            docs.append(
+                {
+                    "id": id,
+                    "values": embedding,
+                    "metadata": copy.deepcopy(metadata),
+                }
+            )
+
+        for i in range(0, len(docs), self.BATCH_SIZE):
+            self.client.upsert(docs[i : i + self.BATCH_SIZE])
+
+    def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
+        """
+        Query contents from the vector database based on vector similarity
+
+        :param input_query: list of query strings
+        :type input_query: List[str]
+        :param n_results: number of similar documents to fetch from the database
+        :type n_results: int
+        :param where: Optional filter to apply to the data
+        :type where: Dict[str, any]
+        :return: Database contents that are the result of the query
+        :rtype: List[str]
+        """
+        if not skip_embedding:
+            query_vector = self.embedder.embedding_fn([input_query])[0]
+        else:
+            query_vector = input_query
+        contents = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True)
+        embeddings = list(map(lambda content: content["metadata"]["text"], contents["matches"]))
+        return embeddings
+
+    def set_collection_name(self, name: str):
+        """
+        Set the name of the collection. A collection is an isolated space for vectors.
+
+        :param name: Name of the collection.
+        :type name: str
+        """
+        if not isinstance(name, str):
+            raise TypeError("Collection name must be a string")
+        self.config.collection_name = name
+
+    def count(self) -> int:
+        """
+        Count number of documents/chunks embedded in the database.
+
+        :return: number of documents
+        :rtype: int
+        """
+        return self.client.describe_index_stats()["total_vector_count"]
+
+    def _get_or_create_db(self):
+        """Called during initialization"""
+        return self.client
+
+    def reset(self):
+        """
+        Resets the database. Deletes all embeddings irreversibly.
+        """
+        # Delete all data from the database
+        pinecone.delete_index(self.index_name)
+        self._setup_pinecone_index()
+
+    # Pinecone only allows alphanumeric characters and "-" in the index name
+    def _get_index_name(self) -> str:
+        """Get the Pinecone index name for a collection
+
+        :return: Pinecone index name
+        :rtype: str
+        """
+        return f"{self.config.collection_name}-{self.config.dimension}".lower().replace("_", "-")
diff --git a/pyproject.toml b/pyproject.toml
index 734e9eb8..32b4a977 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -112,6 +112,7 @@ discord = { version = "^2.3.2", optional = true }
 slack-sdk = { version = "3.21.3", optional = true }
 cohere = { version = "^4.27", optional= true }
 docx2txt = "^0.8"
+pinecone-client = "^2.2.4"
 unstructured = {extras = ["local-inference"], version = "^0.10.18"}
 pillow = { version = "10.0.1", optional = true }
 torchvision = { version = ">=0.15.1, !=0.15.2", optional = true }
@@ -142,6 +143,7 @@ poe = ["fastapi-poe"]
 discord = ["discord"]
 slack = ["slack-sdk", "flask"]
 whatsapp = ["twilio", "flask"]
+pinecone = ["pinecone-client"]
 images = ["torch", "ftfy", "regex", "pillow", "torchvision"]
 huggingface_hub=["huggingface_hub"]
 cohere = ["cohere"]
diff --git a/tests/vectordb/test_pinecone_db.py b/tests/vectordb/test_pinecone_db.py
new file mode 100644
index 00000000..f17252c0
--- /dev/null
+++ b/tests/vectordb/test_pinecone_db.py
@@ -0,0 +1,106 @@
+from unittest import mock
+from unittest.mock import patch
+
+from embedchain import App
+from embedchain.config import AppConfig
+from embedchain.embedder.base import BaseEmbedder
+from embedchain.vectordb.pineconedb import PineconeDb
+
+
+class TestPineconeDb:
+    @patch("embedchain.vectordb.pineconedb.pinecone")
+    def test_init(self, pinecone_mock):
+        """Test that the PineconeDb can be initialized."""
+        # Create a PineconeDb instance
+        PineconeDb()
+
+        # Assert that the Pinecone client was initialized
+        pinecone_mock.init.assert_called_once()
+        pinecone_mock.list_indexes.assert_called_once()
+        pinecone_mock.Index.assert_called_once()
+
+    @patch("embedchain.vectordb.pineconedb.pinecone")
+    def test_set_embedder(self, pinecone_mock):
+        """Test that the embedder can be set."""
+
+        # Set the embedder
+        embedder = BaseEmbedder()
+
+        # Create a PineconeDb instance
+        db = PineconeDb()
+        app_config = AppConfig(collect_metrics=False)
+        App(config=app_config, db=db, embedder=embedder)
+
+        # Assert that the embedder was set
+        assert db.embedder == embedder
+        pinecone_mock.init.assert_called_once()
+
+    @patch("embedchain.vectordb.pineconedb.pinecone")
+    def test_add_documents(self, pinecone_mock):
+        """Test that documents can be added to the database."""
+        pinecone_client_mock = pinecone_mock.Index.return_value
+
+        embedding_function = mock.Mock()
+        base_embedder = BaseEmbedder()
+        base_embedder.set_embedding_fn(embedding_function)
+        vectors = [[0, 0, 0], [1, 1, 1]]
+        embedding_function.return_value = vectors
+        # Create a PineconeDb instance
+        db = PineconeDb()
+        app_config = AppConfig(collect_metrics=False)
+        App(config=app_config, db=db, embedder=base_embedder)
+
+        # Add some documents to the database
+        documents = ["This is a document.", "This is another document."]
+        metadatas = [{}, {}]
+        ids = ["doc1", "doc2"]
+        db.add(vectors, documents, metadatas, ids, True)
+
+        expected_pinecone_upsert_args = [
+            {"id": "doc1", "metadata": {"text": "This is a document."}, "values": [0, 0, 0]},
+            {"id": "doc2", "metadata": {"text": "This is another document."}, "values": [1, 1, 1]},
+        ]
+        # Assert that the Pinecone client was called to upsert the documents
+        pinecone_client_mock.upsert.assert_called_once_with(expected_pinecone_upsert_args)
+
+    @patch("embedchain.vectordb.pineconedb.pinecone")
+    def test_query_documents(self, pinecone_mock):
+        """Test that documents can be queried from the database."""
+        pinecone_client_mock = pinecone_mock.Index.return_value
+
+        embedding_function = mock.Mock()
+        base_embedder = BaseEmbedder()
+        base_embedder.set_embedding_fn(embedding_function)
+        vectors = [[0, 0, 0]]
+        embedding_function.return_value = vectors
+        # Create a PineconeDb instance
+        db = PineconeDb()
+        app_config = AppConfig(collect_metrics=False)
+        App(config=app_config, db=db, embedder=base_embedder)
+
+        # Query the database for documents that are similar to "document"
+        input_query = ["document"]
+        n_results = 1
+        db.query(input_query, n_results, where={}, skip_embedding=False)
+
+        # Assert that the Pinecone client was called to query the database
+        pinecone_client_mock.query.assert_called_once_with(
+            vector=db.embedder.embedding_fn(input_query)[0], top_k=n_results, filter={}, include_metadata=True
+        )
+
+    @patch("embedchain.vectordb.pineconedb.pinecone")
+    def test_reset(self, pinecone_mock):
+        """Test that the database can be reset."""
+        # Create a PineconeDb instance
+        db = PineconeDb()
+        app_config = AppConfig(collect_metrics=False)
+        App(config=app_config, db=db, embedder=BaseEmbedder())
+
+        # Reset the database
+        db.reset()
+
+        # Assert that the Pinecone client was called to delete the index
+        pinecone_mock.delete_index.assert_called_once_with(db.index_name)
+
+        # Assert that the index is recreated
+        pinecone_mock.Index.assert_called_with(db.index_name)
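
Usage sketch (not part of the patch): the snippet below mirrors how the tests wire PineconeDb into an App. It assumes PINECONE_API_KEY and PINECONE_ENV are exported (they are read by _setup_pinecone_index), that credentials for the App's default embedder (e.g. OPENAI_API_KEY) are available, and that the collection name and URL are illustrative placeholders only.

from embedchain import App
from embedchain.config import AppConfig
from embedchain.config.vectordb.pinecone import PineconeDbConfig
from embedchain.vectordb.pineconedb import PineconeDb

# The index name becomes "my-docs-1536"; Pinecone only allows alphanumerics and "-".
db = PineconeDb(config=PineconeDbConfig(collection_name="my-docs", dimension=1536, metric="cosine"))
app = App(config=AppConfig(collect_metrics=False), db=db)

app.add("https://www.example.com/")            # chunk, embed, and upsert into the Pinecone index
answer = app.query("What does the page say?")  # similarity search against the same index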