enable using custom Pinecone index name (#1172)
Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
@@ -189,6 +189,11 @@ vectordb:
|
|||||||
|
|
||||||
</CodeGroup>
|
</CodeGroup>
|
||||||
|
|
||||||
|
<br />
|
||||||
|
<Note>
|
||||||
|
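You can optionally provide `index_name` as a config param in the yaml file to specify the Pinecone index name. If not provided, the index name defaults to `{collection_name}-{vector_dimension}`, lower-cased and with underscores replaced by hyphens. A custom `index_name` is used as-is (no normalization is applied), so it should contain only alphanumeric characters and `-`, which is all Pinecone allows in index names.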
|
||||||
|
</Note>
|
||||||
|
|
||||||
## Qdrant
|
## Qdrant
|
||||||
|
|
||||||
In order to use Qdrant as a vector database, set the environment variables `QDRANT_URL` and `QDRANT_API_KEY` which you can find on [Qdrant Dashboard](https://cloud.qdrant.io/).
|
In order to use Qdrant as a vector database, set the environment variables `QDRANT_URL` and `QDRANT_API_KEY` which you can find on [Qdrant Dashboard](https://cloud.qdrant.io/).
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ class PineconeDBConfig(BaseVectorDbConfig):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
collection_name: Optional[str] = None,
|
collection_name: Optional[str] = None,
|
||||||
|
index_name: Optional[str] = None,
|
||||||
dir: Optional[str] = None,
|
dir: Optional[str] = None,
|
||||||
vector_dimension: int = 1536,
|
vector_dimension: int = 1536,
|
||||||
metric: Optional[str] = "cosine",
|
metric: Optional[str] = "cosine",
|
||||||
@@ -17,4 +18,5 @@ class PineconeDBConfig(BaseVectorDbConfig):
|
|||||||
self.metric = metric
|
self.metric = metric
|
||||||
self.vector_dimension = vector_dimension
|
self.vector_dimension = vector_dimension
|
||||||
self.extra_params = extra_params
|
self.extra_params = extra_params
|
||||||
|
self.index_name = index_name or f"{collection_name}-{vector_dimension}".lower().replace("_", "-")
|
||||||
super().__init__(collection_name=collection_name, dir=dir)
|
super().__init__(collection_name=collection_name, dir=dir)
|
||||||
|
|||||||
@@ -53,20 +53,21 @@ class PineconeDB(BaseVectorDB):
|
|||||||
if not self.embedder:
|
if not self.embedder:
|
||||||
raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
|
raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
|
||||||
|
|
||||||
# Loads the Pinecone index or creates it if not present.
|
|
||||||
def _setup_pinecone_index(self):
|
def _setup_pinecone_index(self):
|
||||||
|
"""
|
||||||
|
Loads the Pinecone index or creates it if not present.
|
||||||
|
"""
|
||||||
pinecone.init(
|
pinecone.init(
|
||||||
api_key=os.environ.get("PINECONE_API_KEY"),
|
api_key=os.environ.get("PINECONE_API_KEY"),
|
||||||
environment=os.environ.get("PINECONE_ENV"),
|
environment=os.environ.get("PINECONE_ENV"),
|
||||||
**self.config.extra_params,
|
**self.config.extra_params,
|
||||||
)
|
)
|
||||||
self.index_name = self._get_index_name()
|
|
||||||
indexes = pinecone.list_indexes()
|
indexes = pinecone.list_indexes()
|
||||||
if indexes is None or self.index_name not in indexes:
|
if indexes is None or self.config.index_name not in indexes:
|
||||||
pinecone.create_index(
|
pinecone.create_index(
|
||||||
name=self.index_name, metric=self.config.metric, dimension=self.config.vector_dimension
|
name=self.config.index_name, metric=self.config.metric, dimension=self.config.vector_dimension
|
||||||
)
|
)
|
||||||
return pinecone.Index(self.index_name)
|
return pinecone.Index(self.config.index_name)
|
||||||
|
|
||||||
def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
|
def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
|
||||||
"""
|
"""
|
||||||
@@ -193,18 +194,9 @@ class PineconeDB(BaseVectorDB):
|
|||||||
Resets the database. Deletes all embeddings irreversibly.
|
Resets the database. Deletes all embeddings irreversibly.
|
||||||
"""
|
"""
|
||||||
# Delete all data from the database
|
# Delete all data from the database
|
||||||
pinecone.delete_index(self.index_name)
|
pinecone.delete_index(self.config.index_name)
|
||||||
self._setup_pinecone_index()
|
self._setup_pinecone_index()
|
||||||
|
|
||||||
# Pinecone only allows alphanumeric characters and "-" in the index name
|
|
||||||
def _get_index_name(self) -> str:
|
|
||||||
"""Get the Pinecone index for a collection
|
|
||||||
|
|
||||||
:return: Pinecone index
|
|
||||||
:rtype: str
|
|
||||||
"""
|
|
||||||
return f"{self.config.collection_name}-{self.config.vector_dimension}".lower().replace("_", "-")
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _generate_filter(where: dict):
|
def _generate_filter(where: dict):
|
||||||
query = {}
|
query = {}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
from embedchain import App
|
from embedchain import App
|
||||||
from embedchain.config import AppConfig
|
from embedchain.config import AppConfig
|
||||||
|
from embedchain.config.vectordb.pinecone import PineconeDBConfig
|
||||||
from embedchain.embedder.base import BaseEmbedder
|
from embedchain.embedder.base import BaseEmbedder
|
||||||
from embedchain.vectordb.pinecone import PineconeDB
|
from embedchain.vectordb.pinecone import PineconeDB
|
||||||
|
|
||||||
@@ -100,7 +101,39 @@ class TestPinecone:
|
|||||||
db.reset()
|
db.reset()
|
||||||
|
|
||||||
# Assert that the Pinecone client was called to delete the index
|
# Assert that the Pinecone client was called to delete the index
|
||||||
pinecone_mock.delete_index.assert_called_once_with(db.index_name)
|
pinecone_mock.delete_index.assert_called_once_with(db.config.index_name)
|
||||||
|
|
||||||
# Assert that the index is recreated
|
# Assert that the index is recreated
|
||||||
pinecone_mock.Index.assert_called_with(db.index_name)
|
pinecone_mock.Index.assert_called_with(db.config.index_name)
|
||||||
|
|
||||||
|
@patch("embedchain.vectordb.pinecone.pinecone")
|
||||||
|
def test_custom_index_name_if_it_exists(self, pinecone_mock):
|
||||||
|
"""Tests custom index name is used if it exists"""
|
||||||
|
pinecone_mock.list_indexes.return_value = ["custom_index_name"]
|
||||||
|
db_config = PineconeDBConfig(index_name="custom_index_name")
|
||||||
|
_ = PineconeDB(config=db_config)
|
||||||
|
|
||||||
|
pinecone_mock.list_indexes.assert_called_once()
|
||||||
|
pinecone_mock.create_index.assert_not_called()
|
||||||
|
pinecone_mock.Index.assert_called_with("custom_index_name")
|
||||||
|
|
||||||
|
@patch("embedchain.vectordb.pinecone.pinecone")
|
||||||
|
def test_custom_index_name_creation(self, pinecone_mock):
|
||||||
|
"""Test custom index name is created if it doesn't exist already"""
|
||||||
|
pinecone_mock.list_indexes.return_value = []
|
||||||
|
db_config = PineconeDBConfig(index_name="custom_index_name")
|
||||||
|
_ = PineconeDB(config=db_config)
|
||||||
|
|
||||||
|
pinecone_mock.list_indexes.assert_called_once()
|
||||||
|
pinecone_mock.create_index.assert_called_once()
|
||||||
|
pinecone_mock.Index.assert_called_with("custom_index_name")
|
||||||
|
|
||||||
|
@patch("embedchain.vectordb.pinecone.pinecone")
|
||||||
|
def test_default_index_name_is_used(self, pinecone_mock):
|
||||||
|
"""Test default index name is used if custom index name is not provided"""
|
||||||
|
db_config = PineconeDBConfig(collection_name="my-collection")
|
||||||
|
_ = PineconeDB(config=db_config)
|
||||||
|
|
||||||
|
pinecone_mock.list_indexes.assert_called_once()
|
||||||
|
pinecone_mock.create_index.assert_called_once()
|
||||||
|
pinecone_mock.Index.assert_called_with(f"{db_config.collection_name}-{db_config.vector_dimension}")
|
||||||
|
|||||||
Reference in New Issue
Block a user