enable using custom Pinecone index name (#1172)

Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
Peter Jausovec
2024-01-25 00:00:10 -08:00
committed by GitHub
parent b7d365119c
commit 446d0975aa
4 changed files with 49 additions and 17 deletions

View File

@@ -189,6 +189,11 @@ vectordb:
</CodeGroup> </CodeGroup>
<br />
<Note>
You can optionally provide `index_name` as a config param in the YAML file to specify the Pinecone index name. If it is not provided, the index name defaults to `{collection_name}-{vector_dimension}`, converted to lowercase with underscores replaced by hyphens.
</Note>
## Qdrant ## Qdrant
In order to use Qdrant as a vector database, set the environment variables `QDRANT_URL` and `QDRANT_API_KEY` which you can find on [Qdrant Dashboard](https://cloud.qdrant.io/). In order to use Qdrant as a vector database, set the environment variables `QDRANT_URL` and `QDRANT_API_KEY` which you can find on [Qdrant Dashboard](https://cloud.qdrant.io/).

View File

@@ -9,6 +9,7 @@ class PineconeDBConfig(BaseVectorDbConfig):
def __init__( def __init__(
self, self,
collection_name: Optional[str] = None, collection_name: Optional[str] = None,
index_name: Optional[str] = None,
dir: Optional[str] = None, dir: Optional[str] = None,
vector_dimension: int = 1536, vector_dimension: int = 1536,
metric: Optional[str] = "cosine", metric: Optional[str] = "cosine",
@@ -17,4 +18,5 @@ class PineconeDBConfig(BaseVectorDbConfig):
self.metric = metric self.metric = metric
self.vector_dimension = vector_dimension self.vector_dimension = vector_dimension
self.extra_params = extra_params self.extra_params = extra_params
self.index_name = index_name or f"{collection_name}-{vector_dimension}".lower().replace("_", "-")
super().__init__(collection_name=collection_name, dir=dir) super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -53,20 +53,21 @@ class PineconeDB(BaseVectorDB):
if not self.embedder: if not self.embedder:
raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.") raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
# Loads the Pinecone index or creates it if not present.
def _setup_pinecone_index(self): def _setup_pinecone_index(self):
"""
Loads the Pinecone index or creates it if not present.
"""
# Initialise the Pinecone client from the PINECONE_API_KEY and PINECONE_ENV
# environment variables; the config's extra_params are forwarded to
# pinecone.init() unchanged as keyword arguments.
pinecone.init( pinecone.init(
api_key=os.environ.get("PINECONE_API_KEY"), api_key=os.environ.get("PINECONE_API_KEY"),
environment=os.environ.get("PINECONE_ENV"), environment=os.environ.get("PINECONE_ENV"),
**self.config.extra_params, **self.config.extra_params,
) )
self.index_name = self._get_index_name()
# Create the index only when list_indexes() returns None or does not contain
# the index name; metric and dimension for a new index come from the config.
indexes = pinecone.list_indexes() indexes = pinecone.list_indexes()
if indexes is None or self.index_name not in indexes: if indexes is None or self.config.index_name not in indexes:
pinecone.create_index( pinecone.create_index(
name=self.index_name, metric=self.config.metric, dimension=self.config.vector_dimension name=self.config.index_name, metric=self.config.metric, dimension=self.config.vector_dimension
) )
# Return a handle to the index, whether it already existed or was just created.
return pinecone.Index(self.index_name) return pinecone.Index(self.config.index_name)
def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None): def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
""" """
@@ -193,18 +194,9 @@ class PineconeDB(BaseVectorDB):
Resets the database. Deletes all embeddings irreversibly. Resets the database. Deletes all embeddings irreversibly.
""" """
# Delete all data from the database # Delete all data from the database
pinecone.delete_index(self.index_name) pinecone.delete_index(self.config.index_name)
self._setup_pinecone_index() self._setup_pinecone_index()
# Pinecone only allows alphanumeric characters and "-" in the index name
def _get_index_name(self) -> str:
    """Build the Pinecone index name for the configured collection.

    The name is ``<collection_name>-<vector_dimension>``, lower-cased, with
    every underscore replaced by a hyphen.

    :return: Pinecone index name
    :rtype: str
    """
    settings = self.config
    raw_name = "-".join((str(settings.collection_name), str(settings.vector_dimension)))
    return raw_name.lower().replace("_", "-")
@staticmethod @staticmethod
def _generate_filter(where: dict): def _generate_filter(where: dict):
query = {} query = {}

View File

@@ -3,6 +3,7 @@ from unittest.mock import patch
from embedchain import App from embedchain import App
from embedchain.config import AppConfig from embedchain.config import AppConfig
from embedchain.config.vectordb.pinecone import PineconeDBConfig
from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.base import BaseEmbedder
from embedchain.vectordb.pinecone import PineconeDB from embedchain.vectordb.pinecone import PineconeDB
@@ -100,7 +101,39 @@ class TestPinecone:
db.reset() db.reset()
# Assert that the Pinecone client was called to delete the index # Assert that the Pinecone client was called to delete the index
pinecone_mock.delete_index.assert_called_once_with(db.index_name) pinecone_mock.delete_index.assert_called_once_with(db.config.index_name)
# Assert that the index is recreated # Assert that the index is recreated
pinecone_mock.Index.assert_called_with(db.index_name) pinecone_mock.Index.assert_called_with(db.config.index_name)
@patch("embedchain.vectordb.pinecone.pinecone")
def test_custom_index_name_if_it_exists(self, pinecone_mock):
    """An index that already exists under the custom name is reused, not re-created."""
    custom_name = "custom_index_name"
    # Pinecone reports the custom index as already present.
    pinecone_mock.list_indexes.return_value = [custom_name]

    PineconeDB(config=PineconeDBConfig(index_name=custom_name))

    pinecone_mock.list_indexes.assert_called_once()
    pinecone_mock.create_index.assert_not_called()
    pinecone_mock.Index.assert_called_with(custom_name)
@patch("embedchain.vectordb.pinecone.pinecone")
def test_custom_index_name_creation(self, pinecone_mock):
    """An index with the custom name is created when it doesn't exist already."""
    custom_name = "custom_index_name"
    # Pinecone reports no indexes at all, so the custom one must be created.
    pinecone_mock.list_indexes.return_value = []

    PineconeDB(config=PineconeDBConfig(index_name=custom_name))

    pinecone_mock.list_indexes.assert_called_once()
    pinecone_mock.create_index.assert_called_once()
    pinecone_mock.Index.assert_called_with(custom_name)
@patch("embedchain.vectordb.pinecone.pinecone")
def test_default_index_name_is_used(self, pinecone_mock):
    """Test the normalised default index name is used if no custom index name is provided"""
    # State the "index does not exist yet" precondition explicitly instead of
    # relying on MagicMock's default ``__contains__`` returning False.
    pinecone_mock.list_indexes.return_value = []

    # Upper-case letters and an underscore exercise the normalisation applied
    # to the default name (lower-cased, "_" replaced by "-").
    db_config = PineconeDBConfig(collection_name="My_Collection")
    _ = PineconeDB(config=db_config)

    expected_index_name = f"my-collection-{db_config.vector_dimension}"
    assert db_config.index_name == expected_index_name

    pinecone_mock.list_indexes.assert_called_once()
    # The index must be created under the default name with the configured
    # metric and dimension, not merely "created once".
    pinecone_mock.create_index.assert_called_once_with(
        name=expected_index_name,
        metric=db_config.metric,
        dimension=db_config.vector_dimension,
    )
    pinecone_mock.Index.assert_called_with(expected_index_name)