diff --git a/docs/components/vector-databases.mdx b/docs/components/vector-databases.mdx index b039cfb1..dbf86b40 100644 --- a/docs/components/vector-databases.mdx +++ b/docs/components/vector-databases.mdx @@ -189,6 +189,11 @@ vectordb: +
+ +You can optionally provide `index_name` as a config param in yaml file to specify the index name. If not provided, the index name will be `{collection_name}-{vector_dimension}`. + + ## Qdrant In order to use Qdrant as a vector database, set the environment variables `QDRANT_URL` and `QDRANT_API_KEY` which you can find on [Qdrant Dashboard](https://cloud.qdrant.io/). diff --git a/embedchain/config/vectordb/pinecone.py b/embedchain/config/vectordb/pinecone.py index efb98c79..a07d3dd7 100644 --- a/embedchain/config/vectordb/pinecone.py +++ b/embedchain/config/vectordb/pinecone.py @@ -9,6 +9,7 @@ class PineconeDBConfig(BaseVectorDbConfig): def __init__( self, collection_name: Optional[str] = None, + index_name: Optional[str] = None, dir: Optional[str] = None, vector_dimension: int = 1536, metric: Optional[str] = "cosine", @@ -17,4 +18,5 @@ class PineconeDBConfig(BaseVectorDbConfig): self.metric = metric self.vector_dimension = vector_dimension self.extra_params = extra_params + self.index_name = index_name or f"{collection_name}-{vector_dimension}".lower().replace("_", "-") super().__init__(collection_name=collection_name, dir=dir) diff --git a/embedchain/vectordb/pinecone.py b/embedchain/vectordb/pinecone.py index e0e21558..92f3b911 100644 --- a/embedchain/vectordb/pinecone.py +++ b/embedchain/vectordb/pinecone.py @@ -53,20 +53,21 @@ class PineconeDB(BaseVectorDB): if not self.embedder: raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.") - # Loads the Pinecone index or creates it if not present. def _setup_pinecone_index(self): + """ + Loads the Pinecone index or creates it if not present. + """ pinecone.init( api_key=os.environ.get("PINECONE_API_KEY"), environment=os.environ.get("PINECONE_ENV"), **self.config.extra_params, ) - self.index_name = self._get_index_name() indexes = pinecone.list_indexes() - if indexes is None or self.index_name not in indexes: + if indexes is None or self.config.index_name not in indexes: pinecone.create_index( - name=self.index_name, metric=self.config.metric, dimension=self.config.vector_dimension + name=self.config.index_name, metric=self.config.metric, dimension=self.config.vector_dimension ) - return pinecone.Index(self.index_name) + return pinecone.Index(self.config.index_name) def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None): """ @@ -193,18 +194,9 @@ class PineconeDB(BaseVectorDB): Resets the database. Deletes all embeddings irreversibly. """ # Delete all data from the database - pinecone.delete_index(self.index_name) + pinecone.delete_index(self.config.index_name) self._setup_pinecone_index() - # Pinecone only allows alphanumeric characters and "-" in the index name - def _get_index_name(self) -> str: - """Get the Pinecone index for a collection - - :return: Pinecone index - :rtype: str - """ - return f"{self.config.collection_name}-{self.config.vector_dimension}".lower().replace("_", "-") - @staticmethod def _generate_filter(where: dict): query = {} diff --git a/tests/vectordb/test_pinecone.py b/tests/vectordb/test_pinecone.py index 8cb08788..96869d2b 100644 --- a/tests/vectordb/test_pinecone.py +++ b/tests/vectordb/test_pinecone.py @@ -3,6 +3,7 @@ from unittest.mock import patch from embedchain import App from embedchain.config import AppConfig +from embedchain.config.vectordb.pinecone import PineconeDBConfig from embedchain.embedder.base import BaseEmbedder from embedchain.vectordb.pinecone import PineconeDB @@ -100,7 +101,39 @@ class TestPinecone: db.reset() # Assert that the Pinecone client was called to delete the index - pinecone_mock.delete_index.assert_called_once_with(db.index_name) + pinecone_mock.delete_index.assert_called_once_with(db.config.index_name) # Assert that the index is recreated - pinecone_mock.Index.assert_called_with(db.index_name) + pinecone_mock.Index.assert_called_with(db.config.index_name) + + @patch("embedchain.vectordb.pinecone.pinecone") + def test_custom_index_name_if_it_exists(self, pinecone_mock): + """Tests custom index name is used if it exists""" + pinecone_mock.list_indexes.return_value = ["custom_index_name"] + db_config = PineconeDBConfig(index_name="custom_index_name") + _ = PineconeDB(config=db_config) + + pinecone_mock.list_indexes.assert_called_once() + pinecone_mock.create_index.assert_not_called() + pinecone_mock.Index.assert_called_with("custom_index_name") + + @patch("embedchain.vectordb.pinecone.pinecone") + def test_custom_index_name_creation(self, pinecone_mock): + """Test custom index name is created if it doesn't exists already""" + pinecone_mock.list_indexes.return_value = [] + db_config = PineconeDBConfig(index_name="custom_index_name") + _ = PineconeDB(config=db_config) + + pinecone_mock.list_indexes.assert_called_once() + pinecone_mock.create_index.assert_called_once() + pinecone_mock.Index.assert_called_with("custom_index_name") + + @patch("embedchain.vectordb.pinecone.pinecone") + def test_default_index_name_is_used(self, pinecone_mock): + """Test default index name is used if custom index name is not provided""" + db_config = PineconeDBConfig(collection_name="my-collection") + _ = PineconeDB(config=db_config) + + pinecone_mock.list_indexes.assert_called_once() + pinecone_mock.create_index.assert_called_once() + pinecone_mock.Index.assert_called_with(f"{db_config.collection_name}-{db_config.vector_dimension}")