[Feature] Add support for hybrid search for pinecone vector database (#1259)

This commit is contained in:
Deshraj Yadav
2024-02-15 13:20:14 -08:00
committed by GitHub
parent 0766a44ccf
commit 38b4e06963
18 changed files with 470 additions and 326 deletions

View File

@@ -1,3 +1,4 @@
import logging
import os
from typing import Optional, Union
@@ -8,6 +9,8 @@ except ImportError:
"Pinecone requires extra dependencies. Install with `pip install --upgrade 'embedchain[pinecone]'`"
) from None
from pinecone_text.sparse import BM25Encoder
from embedchain.config.vectordb.pinecone import PineconeDBConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.utils.misc import chunks
@@ -42,6 +45,14 @@ class PineconeDB(BaseVectorDB):
)
self.config = config
self._setup_pinecone_index()
# Set up BM25Encoder if sparse vectors are to be used
self.bm25_encoder = None
if self.config.hybrid_search:
# TODO: Add support for fitting BM25Encoder on any corpus
logging.info("Initializing BM25Encoder for sparse vectors..")
self.bm25_encoder = BM25Encoder.default()
# Call parent init here because embedder is needed
super().__init__(config=self.config)
@@ -119,12 +130,17 @@ class PineconeDB(BaseVectorDB):
docs = []
embeddings = self.embedder.embedding_fn(documents)
for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
# Also insert sparse vectors when hybrid search is enabled
sparse_vector_dict = (
{"sparse_values": self.bm25_encoder.encode_documents(text)} if self.bm25_encoder else {}
)
docs.append(
{
"id": id,
"values": embedding,
"metadata": {**metadata, "text": text},
}
**sparse_vector_dict,
},
)
for chunk in chunks(docs, self.BATCH_SIZE, desc="Adding chunks in batches"):
@@ -159,14 +175,19 @@ class PineconeDB(BaseVectorDB):
query_filter["app_id"] = {"$eq": app_id}
query_vector = self.embedder.embedding_fn([input_query])[0]
data = self.pinecone_index.query(
vector=query_vector,
filter=query_filter,
top_k=n_results,
include_metadata=True,
params = {
"vector": query_vector,
"filter": query_filter,
"top_k": n_results,
"include_metadata": True,
**kwargs,
)
}
if self.bm25_encoder:
sparse_query_vector = self.bm25_encoder.encode_queries(input_query)
params["sparse_vector"] = sparse_query_vector
data = self.pinecone_index.query(**params)
return [
(metadata.get("text"), {**metadata, "score": doc.get("score")}) if citations else metadata.get("text")
for doc in data.get("matches", [])