[Feature] Pinecone Vector DB support (#723)

This commit is contained in:
Rupesh Bansal
2023-10-15 14:24:07 +05:30
committed by GitHub
parent 5ec12212e4
commit a7a61fae1d
6 changed files with 308 additions and 6 deletions

View File

@@ -0,0 +1,18 @@
from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helper.json_serializable import register_deserializable
@register_deserializable
class PineconeDbConfig(BaseVectorDbConfig):
    """Configuration for the Pinecone vector database backend."""

    def __init__(
        self,
        collection_name: Optional[str] = None,
        dir: Optional[str] = None,
        dimension: Optional[int] = 1536,
        metric: Optional[str] = "cosine",
    ):
        """Store the Pinecone-specific settings, then defer to the base config.

        :param collection_name: forwarded unchanged to `BaseVectorDbConfig`, defaults to None
        :type collection_name: Optional[str], optional
        :param dir: forwarded unchanged to `BaseVectorDbConfig`, defaults to None
        :type dir: Optional[str], optional
        :param dimension: stored as `self.dimension`, defaults to 1536
        :type dimension: Optional[int], optional
        :param metric: stored as `self.metric`, defaults to "cosine"
        :type metric: Optional[str], optional
        """
        # Pinecone-specific fields are assigned before the base initializer runs.
        self.dimension = dimension
        self.metric = metric
        super().__init__(collection_name=collection_name, dir=dir)

View File

@@ -339,7 +339,6 @@ class EmbedChain(JSONSerializable):
metadatas = embeddings_data["metadatas"]
ids = embeddings_data["ids"]
new_doc_id = embeddings_data["doc_id"]
if existing_doc_id and existing_doc_id == new_doc_id:
print("Doc content has not changed. Skipping creating chunks and embeddings")
return [], [], [], 0
@@ -404,7 +403,6 @@ class EmbedChain(JSONSerializable):
skip_embedding=(chunker.data_type == DataType.IMAGES),
)
count_new_chunks = self.db.count() - chunks_before_addition
print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
return list(documents), metadatas, ids, count_new_chunks
def _format_result(self, results):

View File

@@ -1,5 +1,5 @@
import logging
from typing import Any, Dict, List, Optional, Set
from typing import Any, Dict, List, Optional
try:
from elasticsearch import Elasticsearch
@@ -74,9 +74,7 @@ class ElasticsearchDB(BaseVectorDB):
def _get_or_create_collection(self, name):
"""Note: nothing to return here. Discuss later"""
def get(
self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None
) -> Set[str]:
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
"""
Get existing doc ids present in vector database

View File

@@ -0,0 +1,180 @@
import copy
import os
from typing import Dict, List, Optional
try:
import pinecone
except ImportError:
raise ImportError(
"Pinecone requires extra dependencies. Install with `pip install --upgrade embedchain[pinecone]`"
) from None
from embedchain.config.vectordb.pinecone import PineconeDbConfig
from embedchain.helper.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB
@register_deserializable
class PineconeDb(BaseVectorDB):
    """
    Pinecone as vector database
    """

    # Number of vectors sent to Pinecone per `upsert` call in `add`.
    BATCH_SIZE = 100
    # Number of ids requested from Pinecone per `fetch` call in `get`.
    FETCH_BATCH_SIZE = 1000

    def __init__(
        self,
        config: Optional[PineconeDbConfig] = None,
    ):
        """Pinecone as vector database.

        :param config: Pinecone database config, defaults to None
        :type config: PineconeDbConfig, optional
        :raises TypeError: `config` is given but is not a `PineconeDbConfig` instance
        """
        if config is None:
            self.config = PineconeDbConfig()
        elif isinstance(config, PineconeDbConfig):
            self.config = config
        else:
            raise TypeError(
                "config is not a `PineconeDbConfig` instance. "
                "Please make sure the type is right and that you are passing an instance."
            )
        self.client = self._setup_pinecone_index()
        # Call parent init here because embedder is needed
        super().__init__(config=self.config)

    def _initialize(self):
        """
        This method is needed because `embedder` attribute needs to be set externally before it can be initialized.

        :raises ValueError: no embedder has been set
        """
        if not self.embedder:
            raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")

    def _setup_pinecone_index(self):
        """Load the Pinecone index, creating it first if it is not present.

        Reads the credentials from the `PINECONE_API_KEY` and `PINECONE_ENV`
        environment variables and stores the resolved index name on
        `self.index_name`.

        :return: handle for the index named `self.index_name`
        """
        pinecone.init(
            api_key=os.environ.get("PINECONE_API_KEY"),
            environment=os.environ.get("PINECONE_ENV"),
        )
        self.index_name = self._get_index_name()
        indexes = pinecone.list_indexes()
        if indexes is None or self.index_name not in indexes:
            pinecone.create_index(name=self.index_name, metric=self.config.metric, dimension=self.config.dimension)
        return pinecone.Index(self.index_name)

    def get(
        self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None
    ) -> Dict[str, List[str]]:
        """
        Get existing doc ids present in vector database

        :param ids: list of doc ids to check for existence
        :type ids: List[str]
        :param where: to filter data (currently not used by this implementation)
        :type where: Dict[str, any]
        :param limit: maximum number of results (currently not used by this implementation)
        :type limit: Optional[int]
        :return: mapping with a single key `"ids"` holding the ids that were found
        :rtype: Dict[str, List[str]]
        """
        existing_ids = []
        if ids is not None:
            # Fetch in fixed-size chunks so a long id list is never sent in one request.
            for start in range(0, len(ids), self.FETCH_BATCH_SIZE):
                result = self.client.fetch(ids=ids[start : start + self.FETCH_BATCH_SIZE])
                # Treat a response without a "vectors" entry as "nothing found".
                vectors = result.get("vectors") or {}
                existing_ids.extend(vectors.keys())
        return {"ids": existing_ids}

    def add(
        self,
        embeddings: List[List[float]],
        documents: List[str],
        metadatas: List[object],
        ids: List[str],
        skip_embedding: bool,
    ):
        """add data in vector database

        :param embeddings: precomputed vectors; when None they are computed from
            `documents` with the embedder's `embedding_fn`
        :type embeddings: List[List[float]]
        :param documents: list of texts to add
        :type documents: List[str]
        :param metadatas: list of metadata associated with docs
        :type metadatas: List[object]
        :param ids: ids of docs
        :type ids: List[str]
        :param skip_embedding: accepted for interface compatibility; not read here
        :type skip_embedding: bool
        """
        if embeddings is None:
            embeddings = self.embedder.embedding_fn(documents)

        docs = []
        for doc_id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
            # Copy first, then attach the text, so the caller's metadata is left untouched.
            vector_metadata = copy.deepcopy(metadata) if metadata is not None else {}
            vector_metadata["text"] = text
            docs.append(
                {
                    "id": doc_id,
                    "values": embedding,
                    "metadata": vector_metadata,
                }
            )

        for start in range(0, len(docs), self.BATCH_SIZE):
            self.client.upsert(docs[start : start + self.BATCH_SIZE])

    def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
        """
        query contents from vector database based on vector similarity

        :param input_query: the query; it is wrapped in a list and embedded unless
            `skip_embedding` is set, in which case it is used directly as the vector
        :type input_query: List[str]
        :param n_results: no of similar documents to fetch from database
        :type n_results: int
        :param where: Optional. to filter data
        :type where: Dict[str, any]
        :param skip_embedding: when True, `input_query` is passed to Pinecone as the query vector
        :type skip_embedding: bool
        :return: the `"text"` metadata of each match returned by Pinecone
        :rtype: List[str]
        """
        if not skip_embedding:
            query_vector = self.embedder.embedding_fn([input_query])[0]
        else:
            query_vector = input_query
        contents = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True)
        return [match["metadata"]["text"] for match in contents["matches"]]

    def set_collection_name(self, name: str):
        """
        Set the name of the collection. A collection is an isolated space for vectors.

        Only `self.config.collection_name` is updated; `self.index_name` and
        `self.client` keep the values computed by `_setup_pinecone_index`.

        :param name: Name of the collection.
        :type name: str
        :raises TypeError: `name` is not a string
        """
        if not isinstance(name, str):
            raise TypeError("Collection name must be a string")
        self.config.collection_name = name

    def count(self) -> int:
        """
        Count number of documents/chunks embedded in the database.

        :return: number of documents
        :rtype: int
        """
        return self.client.describe_index_stats()["total_vector_count"]

    def _get_or_create_db(self):
        """Called during initialization

        :return: the index handle stored on `self.client`
        """
        return self.client

    def reset(self):
        """
        Resets the database. Deletes all embeddings irreversibly.
        """
        # Delete all data from the database
        pinecone.delete_index(self.index_name)
        # Keep the handle returned for the recreated index instead of discarding it.
        self.client = self._setup_pinecone_index()

    def _get_index_name(self) -> str:
        """Get the Pinecone index for a collection

        Pinecone only allows alphanumeric characters and "-" in the index name,
        so the name is lower-cased and underscores are replaced with "-".

        :return: Pinecone index
        :rtype: str
        """
        return f"{self.config.collection_name}-{self.config.dimension}".lower().replace("_", "-")

View File

@@ -112,6 +112,7 @@ discord = { version = "^2.3.2", optional = true }
slack-sdk = { version = "3.21.3", optional = true }
cohere = { version = "^4.27", optional= true }
docx2txt = "^0.8"
# Optional: only needed for the Pinecone vector database; installed through the "pinecone" extra.
pinecone-client = { version = "^2.2.4", optional = true }
unstructured = {extras = ["local-inference"], version = "^0.10.18"}
pillow = { version = "10.0.1", optional = true }
torchvision = { version = ">=0.15.1, !=0.15.2", optional = true }
@@ -142,6 +143,7 @@ poe = ["fastapi-poe"]
discord = ["discord"]
slack = ["slack-sdk", "flask"]
whatsapp = ["twilio", "flask"]
pinecone = ["pinecone-client"]
images = ["torch", "ftfy", "regex", "pillow", "torchvision"]
huggingface_hub=["huggingface_hub"]
cohere = ["cohere"]

View File

@@ -0,0 +1,106 @@
from unittest import mock
from unittest.mock import patch
from embedchain import App
from embedchain.config import AppConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.vectordb.pineconedb import PineconeDb
class TestPineconeDb:
    """Tests for `PineconeDb`; every test patches the `pinecone` module used by it."""

    @staticmethod
    def _attach_to_app(db, embedder):
        """Build an `App` around `db` and `embedder` with metrics collection disabled."""
        App(config=AppConfig(collect_metrics=False), db=db, embedder=embedder)

    @patch("embedchain.vectordb.pineconedb.pinecone")
    def test_init(self, pinecone_mock):
        """Test that the PineconeDb can be initialized."""
        PineconeDb()

        # Construction must initialise the client, list the indexes and open one.
        pinecone_mock.init.assert_called_once()
        pinecone_mock.list_indexes.assert_called_once()
        pinecone_mock.Index.assert_called_once()

    @patch("embedchain.vectordb.pineconedb.pinecone")
    def test_set_embedder(self, pinecone_mock):
        """Test that the embedder can be set."""
        embedder = BaseEmbedder()
        db = PineconeDb()

        self._attach_to_app(db, embedder)

        # The App must have handed the embedder over to the database.
        assert db.embedder == embedder
        pinecone_mock.init.assert_called_once()

    @patch("embedchain.vectordb.pineconedb.pinecone")
    def test_add_documents(self, pinecone_mock):
        """Test that documents can be added to the database."""
        index_mock = pinecone_mock.Index.return_value

        vectors = [[0, 0, 0], [1, 1, 1]]
        embedding_function = mock.Mock()
        base_embedder = BaseEmbedder()
        base_embedder.set_embedding_fn(embedding_function)
        embedding_function.return_value = vectors

        db = PineconeDb()
        self._attach_to_app(db, base_embedder)

        # Add two documents with empty metadata and precomputed vectors.
        documents = ["This is a document.", "This is another document."]
        metadatas = [{}, {}]
        ids = ["doc1", "doc2"]
        db.add(vectors, documents, metadatas, ids, True)

        # Both documents fit in one batch, so exactly one upsert is expected.
        expected_pinecone_upsert_args = [
            {"id": "doc1", "metadata": {"text": "This is a document."}, "values": [0, 0, 0]},
            {"id": "doc2", "metadata": {"text": "This is another document."}, "values": [1, 1, 1]},
        ]
        index_mock.upsert.assert_called_once_with(expected_pinecone_upsert_args)

    @patch("embedchain.vectordb.pineconedb.pinecone")
    def test_query_documents(self, pinecone_mock):
        """Test that documents can be queried from the database."""
        index_mock = pinecone_mock.Index.return_value

        vectors = [[0, 0, 0]]
        embedding_function = mock.Mock()
        base_embedder = BaseEmbedder()
        base_embedder.set_embedding_fn(embedding_function)
        embedding_function.return_value = vectors

        db = PineconeDb()
        self._attach_to_app(db, base_embedder)

        # Ask for the single closest match to "document".
        input_query = ["document"]
        n_results = 1
        db.query(input_query, n_results, where={}, skip_embedding=False)

        # The embedded query must be forwarded to the Pinecone index.
        index_mock.query.assert_called_once_with(
            vector=db.embedder.embedding_fn(input_query)[0], top_k=n_results, filter={}, include_metadata=True
        )

    @patch("embedchain.vectordb.pineconedb.pinecone")
    def test_reset(self, pinecone_mock):
        """Test that the database can be reset."""
        db = PineconeDb()
        self._attach_to_app(db, BaseEmbedder())

        db.reset()

        # The index must be deleted and then opened again under the same name.
        pinecone_mock.delete_index.assert_called_once_with(db.index_name)
        pinecone_mock.Index.assert_called_with(db.index_name)