[Feature] Pinecone Vector DB support (#723)
This commit is contained in:
18
embedchain/config/vectordb/pinecone.py
Normal file
18
embedchain/config/vectordb/pinecone.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class PineconeDbConfig(BaseVectorDbConfig):
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: Optional[str] = None,
|
||||
dir: Optional[str] = None,
|
||||
dimension: Optional[int] = 1536,
|
||||
metric: Optional[str] = "cosine",
|
||||
):
|
||||
self.dimension = dimension
|
||||
self.metric = metric
|
||||
super().__init__(collection_name=collection_name, dir=dir)
|
||||
@@ -339,7 +339,6 @@ class EmbedChain(JSONSerializable):
|
||||
metadatas = embeddings_data["metadatas"]
|
||||
ids = embeddings_data["ids"]
|
||||
new_doc_id = embeddings_data["doc_id"]
|
||||
|
||||
if existing_doc_id and existing_doc_id == new_doc_id:
|
||||
print("Doc content has not changed. Skipping creating chunks and embeddings")
|
||||
return [], [], [], 0
|
||||
@@ -404,7 +403,6 @@ class EmbedChain(JSONSerializable):
|
||||
skip_embedding=(chunker.data_type == DataType.IMAGES),
|
||||
)
|
||||
count_new_chunks = self.db.count() - chunks_before_addition
|
||||
print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
|
||||
return list(documents), metadatas, ids, count_new_chunks
|
||||
|
||||
def _format_result(self, results):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
try:
|
||||
from elasticsearch import Elasticsearch
|
||||
@@ -74,9 +74,7 @@ class ElasticsearchDB(BaseVectorDB):
|
||||
def _get_or_create_collection(self, name):
|
||||
"""Note: nothing to return here. Discuss later"""
|
||||
|
||||
def get(
|
||||
self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None
|
||||
) -> Set[str]:
|
||||
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
|
||||
"""
|
||||
Get existing doc ids present in vector database
|
||||
|
||||
|
||||
180
embedchain/vectordb/pineconedb.py
Normal file
180
embedchain/vectordb/pineconedb.py
Normal file
@@ -0,0 +1,180 @@
|
||||
import copy
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
try:
|
||||
import pinecone
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Pinecone requires extra dependencies. Install with `pip install --upgrade embedchain[pinecone]`"
|
||||
) from None
|
||||
|
||||
from embedchain.config.vectordb.pinecone import PineconeDbConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class PineconeDb(BaseVectorDB):
|
||||
BATCH_SIZE = 100
|
||||
|
||||
"""
|
||||
Pinecone as vector database
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[PineconeDbConfig] = None,
|
||||
):
|
||||
"""Pinecone as vector database.
|
||||
|
||||
:param config: Pinecone database config, defaults to None
|
||||
:type config: PineconeDbConfig, optional
|
||||
:raises ValueError: No config provided
|
||||
"""
|
||||
if config is None:
|
||||
self.config = PineconeDbConfig()
|
||||
else:
|
||||
if not isinstance(config, PineconeDbConfig):
|
||||
raise TypeError(
|
||||
"config is not a `PineconeDbConfig` instance. "
|
||||
"Please make sure the type is right and that you are passing an instance."
|
||||
)
|
||||
self.config = config
|
||||
self.client = self._setup_pinecone_index()
|
||||
# Call parent init here because embedder is needed
|
||||
super().__init__(config=self.config)
|
||||
|
||||
def _initialize(self):
|
||||
"""
|
||||
This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
|
||||
"""
|
||||
if not self.embedder:
|
||||
raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
|
||||
|
||||
# Loads the Pinecone index or creates it if not present.
|
||||
def _setup_pinecone_index(self):
|
||||
pinecone.init(
|
||||
api_key=os.environ.get("PINECONE_API_KEY"),
|
||||
environment=os.environ.get("PINECONE_ENV"),
|
||||
)
|
||||
self.index_name = self._get_index_name()
|
||||
indexes = pinecone.list_indexes()
|
||||
if indexes is None or self.index_name not in indexes:
|
||||
pinecone.create_index(name=self.index_name, metric=self.config.metric, dimension=self.config.dimension)
|
||||
return pinecone.Index(self.index_name)
|
||||
|
||||
def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
|
||||
"""
|
||||
Get existing doc ids present in vector database
|
||||
|
||||
:param ids: _list of doc ids to check for existence
|
||||
:type ids: List[str]
|
||||
:param where: to filter data
|
||||
:type where: Dict[str, any]
|
||||
:return: ids
|
||||
:rtype: Set[str]
|
||||
"""
|
||||
existing_ids = list()
|
||||
if ids is not None:
|
||||
for i in range(0, len(ids), 1000):
|
||||
result = self.client.fetch(ids=ids[i : i + 1000])
|
||||
batch_existing_ids = list(result.get("vectors").keys())
|
||||
existing_ids.extend(batch_existing_ids)
|
||||
|
||||
return {"ids": existing_ids}
|
||||
|
||||
def add(
|
||||
self,
|
||||
embeddings: List[List[float]],
|
||||
documents: List[str],
|
||||
metadatas: List[object],
|
||||
ids: List[str],
|
||||
skip_embedding: bool,
|
||||
):
|
||||
"""add data in vector database
|
||||
|
||||
:param documents: list of texts to add
|
||||
:type documents: List[str]
|
||||
:param metadatas: list of metadata associated with docs
|
||||
:type metadatas: List[object]
|
||||
:param ids: ids of docs
|
||||
:type ids: List[str]
|
||||
"""
|
||||
docs = []
|
||||
if embeddings is None:
|
||||
embeddings = self.embedder.embedding_fn(documents)
|
||||
for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
|
||||
metadata["text"] = text
|
||||
docs.append(
|
||||
{
|
||||
"id": id,
|
||||
"values": embedding,
|
||||
"metadata": copy.deepcopy(metadata),
|
||||
}
|
||||
)
|
||||
|
||||
for i in range(0, len(docs), self.BATCH_SIZE):
|
||||
self.client.upsert(docs[i : i + self.BATCH_SIZE])
|
||||
|
||||
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
|
||||
"""
|
||||
query contents from vector database based on vector similarity
|
||||
|
||||
:param input_query: list of query string
|
||||
:type input_query: List[str]
|
||||
:param n_results: no of similar documents to fetch from database
|
||||
:type n_results: int
|
||||
:param where: Optional. to filter data
|
||||
:type where: Dict[str, any]
|
||||
:return: Database contents that are the result of the query
|
||||
:rtype: List[str]
|
||||
"""
|
||||
if not skip_embedding:
|
||||
query_vector = self.embedder.embedding_fn([input_query])[0]
|
||||
else:
|
||||
query_vector = input_query
|
||||
contents = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True)
|
||||
embeddings = list(map(lambda content: content["metadata"]["text"], contents["matches"]))
|
||||
return embeddings
|
||||
|
||||
def set_collection_name(self, name: str):
|
||||
"""
|
||||
Set the name of the collection. A collection is an isolated space for vectors.
|
||||
|
||||
:param name: Name of the collection.
|
||||
:type name: str
|
||||
"""
|
||||
if not isinstance(name, str):
|
||||
raise TypeError("Collection name must be a string")
|
||||
self.config.collection_name = name
|
||||
|
||||
def count(self) -> int:
|
||||
"""
|
||||
Count number of documents/chunks embedded in the database.
|
||||
|
||||
:return: number of documents
|
||||
:rtype: int
|
||||
"""
|
||||
return self.client.describe_index_stats()["total_vector_count"]
|
||||
|
||||
def _get_or_create_db(self):
|
||||
"""Called during initialization"""
|
||||
return self.client
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Resets the database. Deletes all embeddings irreversibly.
|
||||
"""
|
||||
# Delete all data from the database
|
||||
pinecone.delete_index(self.index_name)
|
||||
self._setup_pinecone_index()
|
||||
|
||||
# Pinecone only allows alphanumeric characters and "-" in the index name
|
||||
def _get_index_name(self) -> str:
|
||||
"""Get the Pinecone index for a collection
|
||||
|
||||
:return: Pinecone index
|
||||
:rtype: str
|
||||
"""
|
||||
return f"{self.config.collection_name}-{self.config.dimension}".lower().replace("_", "-")
|
||||
@@ -112,6 +112,7 @@ discord = { version = "^2.3.2", optional = true }
|
||||
slack-sdk = { version = "3.21.3", optional = true }
|
||||
cohere = { version = "^4.27", optional= true }
|
||||
docx2txt = "^0.8"
|
||||
pinecone-client = "^2.2.4"
|
||||
unstructured = {extras = ["local-inference"], version = "^0.10.18"}
|
||||
pillow = { version = "10.0.1", optional = true }
|
||||
torchvision = { version = ">=0.15.1, !=0.15.2", optional = true }
|
||||
@@ -142,6 +143,7 @@ poe = ["fastapi-poe"]
|
||||
discord = ["discord"]
|
||||
slack = ["slack-sdk", "flask"]
|
||||
whatsapp = ["twilio", "flask"]
|
||||
pinecone = ["pinecone-client"]
|
||||
images = ["torch", "ftfy", "regex", "pillow", "torchvision"]
|
||||
huggingface_hub=["huggingface_hub"]
|
||||
cohere = ["cohere"]
|
||||
|
||||
106
tests/vectordb/test_pinecone_db.py
Normal file
106
tests/vectordb/test_pinecone_db.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from unittest import mock
|
||||
from unittest.mock import patch
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.vectordb.pineconedb import PineconeDb
|
||||
|
||||
|
||||
class TestPineconeDb:
|
||||
@patch("embedchain.vectordb.pineconedb.pinecone")
|
||||
def test_init(self, pinecone_mock):
|
||||
"""Test that the PineconeDb can be initialized."""
|
||||
# Create a PineconeDb instance
|
||||
PineconeDb()
|
||||
|
||||
# Assert that the Pinecone client was initialized
|
||||
pinecone_mock.init.assert_called_once()
|
||||
pinecone_mock.list_indexes.assert_called_once()
|
||||
pinecone_mock.Index.assert_called_once()
|
||||
|
||||
@patch("embedchain.vectordb.pineconedb.pinecone")
|
||||
def test_set_embedder(self, pinecone_mock):
|
||||
"""Test that the embedder can be set."""
|
||||
|
||||
# Set the embedder
|
||||
embedder = BaseEmbedder()
|
||||
|
||||
# Create a PineconeDb instance
|
||||
db = PineconeDb()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=embedder)
|
||||
|
||||
# Assert that the embedder was set
|
||||
assert db.embedder == embedder
|
||||
pinecone_mock.init.assert_called_once()
|
||||
|
||||
@patch("embedchain.vectordb.pineconedb.pinecone")
|
||||
def test_add_documents(self, pinecone_mock):
|
||||
"""Test that documents can be added to the database."""
|
||||
pinecone_client_mock = pinecone_mock.Index.return_value
|
||||
|
||||
embedding_function = mock.Mock()
|
||||
base_embedder = BaseEmbedder()
|
||||
base_embedder.set_embedding_fn(embedding_function)
|
||||
vectors = [[0, 0, 0], [1, 1, 1]]
|
||||
embedding_function.return_value = vectors
|
||||
# Create a PineconeDb instance
|
||||
db = PineconeDb()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=base_embedder)
|
||||
|
||||
# Add some documents to the database
|
||||
documents = ["This is a document.", "This is another document."]
|
||||
metadatas = [{}, {}]
|
||||
ids = ["doc1", "doc2"]
|
||||
db.add(vectors, documents, metadatas, ids, True)
|
||||
|
||||
expected_pinecone_upsert_args = [
|
||||
{"id": "doc1", "metadata": {"text": "This is a document."}, "values": [0, 0, 0]},
|
||||
{"id": "doc2", "metadata": {"text": "This is another document."}, "values": [1, 1, 1]},
|
||||
]
|
||||
# Assert that the Pinecone client was called to upsert the documents
|
||||
pinecone_client_mock.upsert.assert_called_once_with(expected_pinecone_upsert_args)
|
||||
|
||||
@patch("embedchain.vectordb.pineconedb.pinecone")
|
||||
def test_query_documents(self, pinecone_mock):
|
||||
"""Test that documents can be queried from the database."""
|
||||
pinecone_client_mock = pinecone_mock.Index.return_value
|
||||
|
||||
embedding_function = mock.Mock()
|
||||
base_embedder = BaseEmbedder()
|
||||
base_embedder.set_embedding_fn(embedding_function)
|
||||
vectors = [[0, 0, 0]]
|
||||
embedding_function.return_value = vectors
|
||||
# Create a PineconeDb instance
|
||||
db = PineconeDb()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=base_embedder)
|
||||
|
||||
# Query the database for documents that are similar to "document"
|
||||
input_query = ["document"]
|
||||
n_results = 1
|
||||
db.query(input_query, n_results, where={}, skip_embedding=False)
|
||||
|
||||
# Assert that the Pinecone client was called to query the database
|
||||
pinecone_client_mock.query.assert_called_once_with(
|
||||
vector=db.embedder.embedding_fn(input_query)[0], top_k=n_results, filter={}, include_metadata=True
|
||||
)
|
||||
|
||||
@patch("embedchain.vectordb.pineconedb.pinecone")
|
||||
def test_reset(self, pinecone_mock):
|
||||
"""Test that the database can be reset."""
|
||||
# Create a PineconeDb instance
|
||||
db = PineconeDb()
|
||||
app_config = AppConfig(collect_metrics=False)
|
||||
App(config=app_config, db=db, embedder=BaseEmbedder())
|
||||
|
||||
# Reset the database
|
||||
db.reset()
|
||||
|
||||
# Assert that the Pinecone client was called to delete the index
|
||||
pinecone_mock.delete_index.assert_called_once_with(db.index_name)
|
||||
|
||||
# Assert that the index is recreated
|
||||
pinecone_mock.Index.assert_called_with(db.index_name)
|
||||
Reference in New Issue
Block a user