[Feature] Add support for metadata filtering on search API (#1245)

This commit is contained in:
Deshraj Yadav
2024-02-06 15:42:51 -08:00
committed by GitHub
parent 8fe2c3effc
commit 4afef04f26
10 changed files with 173 additions and 104 deletions

View File

@@ -79,6 +79,8 @@ class ChromaDB(BaseVectorDB):
def _generate_where_clause(where: dict[str, any]) -> dict[str, any]:
# If only one filter is supplied, return it as is
# (no need to wrap in $and based on chroma docs)
if where is None:
return {}
if len(where.keys()) <= 1:
return where
where_filters = []
@@ -180,9 +182,10 @@ class ChromaDB(BaseVectorDB):
self,
input_query: list[str],
n_results: int,
where: dict[str, any],
where: Optional[dict[str, any]] = None,
raw_filter: Optional[dict[str, any]] = None,
citations: bool = False,
**kwargs: Optional[dict[str, Any]],
**kwargs: Optional[dict[str, any]],
) -> Union[list[tuple[str, dict]], list[str]]:
"""
Query contents from vector database based on vector similarity
@@ -193,6 +196,8 @@ class ChromaDB(BaseVectorDB):
:type n_results: int
:param where: to filter data
:type where: dict[str, Any]
:param raw_filter: Raw filter to apply
:type raw_filter: dict[str, Any]
:param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False.
:raises InvalidDimensionException: Dimensions do not match.
@@ -200,14 +205,21 @@ class ChromaDB(BaseVectorDB):
along with url of the source and doc_id (if citations flag is true)
:rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
"""
if where and raw_filter:
raise ValueError("Both `where` and `raw_filter` cannot be used together.")
where_clause = {}
if raw_filter:
where_clause = raw_filter
if where:
where_clause = self._generate_where_clause(where)
try:
result = self.collection.query(
query_texts=[
input_query,
],
n_results=n_results,
where=self._generate_where_clause(where),
**kwargs,
where=where_clause,
)
except InvalidDimensionException as e:
raise InvalidDimensionException(

View File

@@ -1,4 +1,3 @@
import logging
import os
from typing import Optional, Union
@@ -99,10 +98,6 @@ class PineconeDB(BaseVectorDB):
batch_existing_ids = list(vectors.keys())
existing_ids.extend(batch_existing_ids)
metadatas.extend([vectors.get(ids).get("metadata") for ids in batch_existing_ids])
if where is not None:
logging.warning("Filtering is not supported by Pinecone")
return {"ids": existing_ids, "metadatas": metadatas}
def add(
@@ -122,7 +117,6 @@ class PineconeDB(BaseVectorDB):
:type ids: list[str]
"""
docs = []
print("Adding documents to Pinecone...")
embeddings = self.embedder.embedding_fn(documents)
for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
docs.append(
@@ -140,26 +134,31 @@ class PineconeDB(BaseVectorDB):
self,
input_query: list[str],
n_results: int,
where: dict[str, any],
where: Optional[dict[str, any]] = None,
raw_filter: Optional[dict[str, any]] = None,
citations: bool = False,
app_id: Optional[str] = None,
**kwargs: Optional[dict[str, any]],
) -> Union[list[tuple[str, dict]], list[str]]:
"""
query contents from vector database based on vector similarity
:param input_query: list of query string
:type input_query: list[str]
:param n_results: no of similar documents to fetch from database
:type n_results: int
:param where: Optional. to filter data
:type where: dict[str, any]
:param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False.
:return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true)
:rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
Query contents from vector database based on vector similarity.
Args:
input_query (list[str]): List of query strings.
n_results (int): Number of similar documents to fetch from the database.
where (dict[str, any], optional): Filter criteria for the search.
raw_filter (dict[str, any], optional): Advanced raw filter criteria for the search.
citations (bool, optional): Flag to return context along with metadata. Defaults to False.
app_id (str, optional): Application ID to be passed to Pinecone.
Returns:
Union[list[tuple[str, dict]], list[str]]: List of document contexts, optionally with metadata.
"""
query_filter = raw_filter if raw_filter is not None else self._generate_filter(where)
if app_id:
query_filter["app_id"] = {"$eq": app_id}
query_vector = self.embedder.embedding_fn([input_query])[0]
query_filter = self._generate_filter(where)
data = self.pinecone_index.query(
vector=query_vector,
filter=query_filter,
@@ -167,16 +166,12 @@ class PineconeDB(BaseVectorDB):
include_metadata=True,
**kwargs,
)
contexts = []
for doc in data.get("matches", []):
metadata = doc.get("metadata", {})
context = metadata.get("text")
if citations:
metadata["score"] = doc.get("score")
contexts.append(tuple((context, metadata)))
else:
contexts.append(context)
return contexts
return [
(metadata.get("text"), {**metadata, "score": doc.get("score")}) if citations else metadata.get("text")
for doc in data.get("matches", [])
for metadata in [doc.get("metadata", {})]
]
def set_collection_name(self, name: str):
"""