[Feature] Pinecone Vector DB support (#723)
This commit is contained in:
18
embedchain/config/vectordb/pinecone.py
Normal file
18
embedchain/config/vectordb/pinecone.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||||
|
from embedchain.helper.json_serializable import register_deserializable
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
|
class PineconeDbConfig(BaseVectorDbConfig):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
collection_name: Optional[str] = None,
|
||||||
|
dir: Optional[str] = None,
|
||||||
|
dimension: Optional[int] = 1536,
|
||||||
|
metric: Optional[str] = "cosine",
|
||||||
|
):
|
||||||
|
self.dimension = dimension
|
||||||
|
self.metric = metric
|
||||||
|
super().__init__(collection_name=collection_name, dir=dir)
|
||||||
@@ -339,7 +339,6 @@ class EmbedChain(JSONSerializable):
|
|||||||
metadatas = embeddings_data["metadatas"]
|
metadatas = embeddings_data["metadatas"]
|
||||||
ids = embeddings_data["ids"]
|
ids = embeddings_data["ids"]
|
||||||
new_doc_id = embeddings_data["doc_id"]
|
new_doc_id = embeddings_data["doc_id"]
|
||||||
|
|
||||||
if existing_doc_id and existing_doc_id == new_doc_id:
|
if existing_doc_id and existing_doc_id == new_doc_id:
|
||||||
print("Doc content has not changed. Skipping creating chunks and embeddings")
|
print("Doc content has not changed. Skipping creating chunks and embeddings")
|
||||||
return [], [], [], 0
|
return [], [], [], 0
|
||||||
@@ -404,7 +403,6 @@ class EmbedChain(JSONSerializable):
|
|||||||
skip_embedding=(chunker.data_type == DataType.IMAGES),
|
skip_embedding=(chunker.data_type == DataType.IMAGES),
|
||||||
)
|
)
|
||||||
count_new_chunks = self.db.count() - chunks_before_addition
|
count_new_chunks = self.db.count() - chunks_before_addition
|
||||||
print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
|
|
||||||
return list(documents), metadatas, ids, count_new_chunks
|
return list(documents), metadatas, ids, count_new_chunks
|
||||||
|
|
||||||
def _format_result(self, results):
|
def _format_result(self, results):
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional, Set
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from elasticsearch import Elasticsearch
|
from elasticsearch import Elasticsearch
|
||||||
@@ -74,9 +74,7 @@ class ElasticsearchDB(BaseVectorDB):
|
|||||||
def _get_or_create_collection(self, name):
|
def _get_or_create_collection(self, name):
|
||||||
"""Note: nothing to return here. Discuss later"""
|
"""Note: nothing to return here. Discuss later"""
|
||||||
|
|
||||||
def get(
|
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
|
||||||
self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None
|
|
||||||
) -> Set[str]:
|
|
||||||
"""
|
"""
|
||||||
Get existing doc ids present in vector database
|
Get existing doc ids present in vector database
|
||||||
|
|
||||||
|
|||||||
180
embedchain/vectordb/pineconedb.py
Normal file
180
embedchain/vectordb/pineconedb.py
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
import copy
|
||||||
|
import os
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
try:
|
||||||
|
import pinecone
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Pinecone requires extra dependencies. Install with `pip install --upgrade embedchain[pinecone]`"
|
||||||
|
) from None
|
||||||
|
|
||||||
|
from embedchain.config.vectordb.pinecone import PineconeDbConfig
|
||||||
|
from embedchain.helper.json_serializable import register_deserializable
|
||||||
|
from embedchain.vectordb.base import BaseVectorDB
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
|
class PineconeDb(BaseVectorDB):
|
||||||
|
BATCH_SIZE = 100
|
||||||
|
|
||||||
|
"""
|
||||||
|
Pinecone as vector database
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Optional[PineconeDbConfig] = None,
|
||||||
|
):
|
||||||
|
"""Pinecone as vector database.
|
||||||
|
|
||||||
|
:param config: Pinecone database config, defaults to None
|
||||||
|
:type config: PineconeDbConfig, optional
|
||||||
|
:raises ValueError: No config provided
|
||||||
|
"""
|
||||||
|
if config is None:
|
||||||
|
self.config = PineconeDbConfig()
|
||||||
|
else:
|
||||||
|
if not isinstance(config, PineconeDbConfig):
|
||||||
|
raise TypeError(
|
||||||
|
"config is not a `PineconeDbConfig` instance. "
|
||||||
|
"Please make sure the type is right and that you are passing an instance."
|
||||||
|
)
|
||||||
|
self.config = config
|
||||||
|
self.client = self._setup_pinecone_index()
|
||||||
|
# Call parent init here because embedder is needed
|
||||||
|
super().__init__(config=self.config)
|
||||||
|
|
||||||
|
def _initialize(self):
|
||||||
|
"""
|
||||||
|
This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
|
||||||
|
"""
|
||||||
|
if not self.embedder:
|
||||||
|
raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
|
||||||
|
|
||||||
|
# Loads the Pinecone index or creates it if not present.
|
||||||
|
def _setup_pinecone_index(self):
|
||||||
|
pinecone.init(
|
||||||
|
api_key=os.environ.get("PINECONE_API_KEY"),
|
||||||
|
environment=os.environ.get("PINECONE_ENV"),
|
||||||
|
)
|
||||||
|
self.index_name = self._get_index_name()
|
||||||
|
indexes = pinecone.list_indexes()
|
||||||
|
if indexes is None or self.index_name not in indexes:
|
||||||
|
pinecone.create_index(name=self.index_name, metric=self.config.metric, dimension=self.config.dimension)
|
||||||
|
return pinecone.Index(self.index_name)
|
||||||
|
|
||||||
|
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
|
||||||
|
"""
|
||||||
|
Get existing doc ids present in vector database
|
||||||
|
|
||||||
|
:param ids: _list of doc ids to check for existence
|
||||||
|
:type ids: List[str]
|
||||||
|
:param where: to filter data
|
||||||
|
:type where: Dict[str, any]
|
||||||
|
:return: ids
|
||||||
|
:rtype: Set[str]
|
||||||
|
"""
|
||||||
|
existing_ids = list()
|
||||||
|
if ids is not None:
|
||||||
|
for i in range(0, len(ids), 1000):
|
||||||
|
result = self.client.fetch(ids=ids[i : i + 1000])
|
||||||
|
batch_existing_ids = list(result.get("vectors").keys())
|
||||||
|
existing_ids.extend(batch_existing_ids)
|
||||||
|
|
||||||
|
return {"ids": existing_ids}
|
||||||
|
|
||||||
|
def add(
|
||||||
|
self,
|
||||||
|
embeddings: List[List[float]],
|
||||||
|
documents: List[str],
|
||||||
|
metadatas: List[object],
|
||||||
|
ids: List[str],
|
||||||
|
skip_embedding: bool,
|
||||||
|
):
|
||||||
|
"""add data in vector database
|
||||||
|
|
||||||
|
:param documents: list of texts to add
|
||||||
|
:type documents: List[str]
|
||||||
|
:param metadatas: list of metadata associated with docs
|
||||||
|
:type metadatas: List[object]
|
||||||
|
:param ids: ids of docs
|
||||||
|
:type ids: List[str]
|
||||||
|
"""
|
||||||
|
docs = []
|
||||||
|
if embeddings is None:
|
||||||
|
embeddings = self.embedder.embedding_fn(documents)
|
||||||
|
for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
|
||||||
|
metadata["text"] = text
|
||||||
|
docs.append(
|
||||||
|
{
|
||||||
|
"id": id,
|
||||||
|
"values": embedding,
|
||||||
|
"metadata": copy.deepcopy(metadata),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(0, len(docs), self.BATCH_SIZE):
|
||||||
|
self.client.upsert(docs[i : i + self.BATCH_SIZE])
|
||||||
|
|
||||||
|
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
|
||||||
|
"""
|
||||||
|
query contents from vector database based on vector similarity
|
||||||
|
|
||||||
|
:param input_query: list of query string
|
||||||
|
:type input_query: List[str]
|
||||||
|
:param n_results: no of similar documents to fetch from database
|
||||||
|
:type n_results: int
|
||||||
|
:param where: Optional. to filter data
|
||||||
|
:type where: Dict[str, any]
|
||||||
|
:return: Database contents that are the result of the query
|
||||||
|
:rtype: List[str]
|
||||||
|
"""
|
||||||
|
if not skip_embedding:
|
||||||
|
query_vector = self.embedder.embedding_fn([input_query])[0]
|
||||||
|
else:
|
||||||
|
query_vector = input_query
|
||||||
|
contents = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True)
|
||||||
|
embeddings = list(map(lambda content: content["metadata"]["text"], contents["matches"]))
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
def set_collection_name(self, name: str):
|
||||||
|
"""
|
||||||
|
Set the name of the collection. A collection is an isolated space for vectors.
|
||||||
|
|
||||||
|
:param name: Name of the collection.
|
||||||
|
:type name: str
|
||||||
|
"""
|
||||||
|
if not isinstance(name, str):
|
||||||
|
raise TypeError("Collection name must be a string")
|
||||||
|
self.config.collection_name = name
|
||||||
|
|
||||||
|
def count(self) -> int:
|
||||||
|
"""
|
||||||
|
Count number of documents/chunks embedded in the database.
|
||||||
|
|
||||||
|
:return: number of documents
|
||||||
|
:rtype: int
|
||||||
|
"""
|
||||||
|
return self.client.describe_index_stats()["total_vector_count"]
|
||||||
|
|
||||||
|
def _get_or_create_db(self):
|
||||||
|
"""Called during initialization"""
|
||||||
|
return self.client
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""
|
||||||
|
Resets the database. Deletes all embeddings irreversibly.
|
||||||
|
"""
|
||||||
|
# Delete all data from the database
|
||||||
|
pinecone.delete_index(self.index_name)
|
||||||
|
self._setup_pinecone_index()
|
||||||
|
|
||||||
|
# Pinecone only allows alphanumeric characters and "-" in the index name
|
||||||
|
def _get_index_name(self) -> str:
|
||||||
|
"""Get the Pinecone index for a collection
|
||||||
|
|
||||||
|
:return: Pinecone index
|
||||||
|
:rtype: str
|
||||||
|
"""
|
||||||
|
return f"{self.config.collection_name}-{self.config.dimension}".lower().replace("_", "-")
|
||||||
@@ -112,6 +112,7 @@ discord = { version = "^2.3.2", optional = true }
|
|||||||
slack-sdk = { version = "3.21.3", optional = true }
|
slack-sdk = { version = "3.21.3", optional = true }
|
||||||
cohere = { version = "^4.27", optional= true }
|
cohere = { version = "^4.27", optional= true }
|
||||||
docx2txt = "^0.8"
|
docx2txt = "^0.8"
|
||||||
|
pinecone-client = "^2.2.4"
|
||||||
unstructured = {extras = ["local-inference"], version = "^0.10.18"}
|
unstructured = {extras = ["local-inference"], version = "^0.10.18"}
|
||||||
pillow = { version = "10.0.1", optional = true }
|
pillow = { version = "10.0.1", optional = true }
|
||||||
torchvision = { version = ">=0.15.1, !=0.15.2", optional = true }
|
torchvision = { version = ">=0.15.1, !=0.15.2", optional = true }
|
||||||
@@ -142,6 +143,7 @@ poe = ["fastapi-poe"]
|
|||||||
discord = ["discord"]
|
discord = ["discord"]
|
||||||
slack = ["slack-sdk", "flask"]
|
slack = ["slack-sdk", "flask"]
|
||||||
whatsapp = ["twilio", "flask"]
|
whatsapp = ["twilio", "flask"]
|
||||||
|
pinecone = ["pinecone-client"]
|
||||||
images = ["torch", "ftfy", "regex", "pillow", "torchvision"]
|
images = ["torch", "ftfy", "regex", "pillow", "torchvision"]
|
||||||
huggingface_hub=["huggingface_hub"]
|
huggingface_hub=["huggingface_hub"]
|
||||||
cohere = ["cohere"]
|
cohere = ["cohere"]
|
||||||
|
|||||||
106
tests/vectordb/test_pinecone_db.py
Normal file
106
tests/vectordb/test_pinecone_db.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
from unittest import mock
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from embedchain import App
|
||||||
|
from embedchain.config import AppConfig
|
||||||
|
from embedchain.embedder.base import BaseEmbedder
|
||||||
|
from embedchain.vectordb.pineconedb import PineconeDb
|
||||||
|
|
||||||
|
|
||||||
|
class TestPineconeDb:
|
||||||
|
@patch("embedchain.vectordb.pineconedb.pinecone")
|
||||||
|
def test_init(self, pinecone_mock):
|
||||||
|
"""Test that the PineconeDb can be initialized."""
|
||||||
|
# Create a PineconeDb instance
|
||||||
|
PineconeDb()
|
||||||
|
|
||||||
|
# Assert that the Pinecone client was initialized
|
||||||
|
pinecone_mock.init.assert_called_once()
|
||||||
|
pinecone_mock.list_indexes.assert_called_once()
|
||||||
|
pinecone_mock.Index.assert_called_once()
|
||||||
|
|
||||||
|
@patch("embedchain.vectordb.pineconedb.pinecone")
|
||||||
|
def test_set_embedder(self, pinecone_mock):
|
||||||
|
"""Test that the embedder can be set."""
|
||||||
|
|
||||||
|
# Set the embedder
|
||||||
|
embedder = BaseEmbedder()
|
||||||
|
|
||||||
|
# Create a PineconeDb instance
|
||||||
|
db = PineconeDb()
|
||||||
|
app_config = AppConfig(collect_metrics=False)
|
||||||
|
App(config=app_config, db=db, embedder=embedder)
|
||||||
|
|
||||||
|
# Assert that the embedder was set
|
||||||
|
assert db.embedder == embedder
|
||||||
|
pinecone_mock.init.assert_called_once()
|
||||||
|
|
||||||
|
@patch("embedchain.vectordb.pineconedb.pinecone")
|
||||||
|
def test_add_documents(self, pinecone_mock):
|
||||||
|
"""Test that documents can be added to the database."""
|
||||||
|
pinecone_client_mock = pinecone_mock.Index.return_value
|
||||||
|
|
||||||
|
embedding_function = mock.Mock()
|
||||||
|
base_embedder = BaseEmbedder()
|
||||||
|
base_embedder.set_embedding_fn(embedding_function)
|
||||||
|
vectors = [[0, 0, 0], [1, 1, 1]]
|
||||||
|
embedding_function.return_value = vectors
|
||||||
|
# Create a PineconeDb instance
|
||||||
|
db = PineconeDb()
|
||||||
|
app_config = AppConfig(collect_metrics=False)
|
||||||
|
App(config=app_config, db=db, embedder=base_embedder)
|
||||||
|
|
||||||
|
# Add some documents to the database
|
||||||
|
documents = ["This is a document.", "This is another document."]
|
||||||
|
metadatas = [{}, {}]
|
||||||
|
ids = ["doc1", "doc2"]
|
||||||
|
db.add(vectors, documents, metadatas, ids, True)
|
||||||
|
|
||||||
|
expected_pinecone_upsert_args = [
|
||||||
|
{"id": "doc1", "metadata": {"text": "This is a document."}, "values": [0, 0, 0]},
|
||||||
|
{"id": "doc2", "metadata": {"text": "This is another document."}, "values": [1, 1, 1]},
|
||||||
|
]
|
||||||
|
# Assert that the Pinecone client was called to upsert the documents
|
||||||
|
pinecone_client_mock.upsert.assert_called_once_with(expected_pinecone_upsert_args)
|
||||||
|
|
||||||
|
@patch("embedchain.vectordb.pineconedb.pinecone")
|
||||||
|
def test_query_documents(self, pinecone_mock):
|
||||||
|
"""Test that documents can be queried from the database."""
|
||||||
|
pinecone_client_mock = pinecone_mock.Index.return_value
|
||||||
|
|
||||||
|
embedding_function = mock.Mock()
|
||||||
|
base_embedder = BaseEmbedder()
|
||||||
|
base_embedder.set_embedding_fn(embedding_function)
|
||||||
|
vectors = [[0, 0, 0]]
|
||||||
|
embedding_function.return_value = vectors
|
||||||
|
# Create a PineconeDb instance
|
||||||
|
db = PineconeDb()
|
||||||
|
app_config = AppConfig(collect_metrics=False)
|
||||||
|
App(config=app_config, db=db, embedder=base_embedder)
|
||||||
|
|
||||||
|
# Query the database for documents that are similar to "document"
|
||||||
|
input_query = ["document"]
|
||||||
|
n_results = 1
|
||||||
|
db.query(input_query, n_results, where={}, skip_embedding=False)
|
||||||
|
|
||||||
|
# Assert that the Pinecone client was called to query the database
|
||||||
|
pinecone_client_mock.query.assert_called_once_with(
|
||||||
|
vector=db.embedder.embedding_fn(input_query)[0], top_k=n_results, filter={}, include_metadata=True
|
||||||
|
)
|
||||||
|
|
||||||
|
@patch("embedchain.vectordb.pineconedb.pinecone")
|
||||||
|
def test_reset(self, pinecone_mock):
|
||||||
|
"""Test that the database can be reset."""
|
||||||
|
# Create a PineconeDb instance
|
||||||
|
db = PineconeDb()
|
||||||
|
app_config = AppConfig(collect_metrics=False)
|
||||||
|
App(config=app_config, db=db, embedder=BaseEmbedder())
|
||||||
|
|
||||||
|
# Reset the database
|
||||||
|
db.reset()
|
||||||
|
|
||||||
|
# Assert that the Pinecone client was called to delete the index
|
||||||
|
pinecone_mock.delete_index.assert_called_once_with(db.index_name)
|
||||||
|
|
||||||
|
# Assert that the index is recreated
|
||||||
|
pinecone_mock.Index.assert_called_with(db.index_name)
|
||||||
Reference in New Issue
Block a user