[Feature] Add support for metadata filtering on search API (#1245)

Author: Deshraj Yadav
Date: 2024-02-06 15:42:51 -08:00
Committed by: GitHub
Parent: 8fe2c3effc
Commit: 4afef04f26
10 changed files with 173 additions and 104 deletions

View File

@@ -11,14 +11,9 @@ import requests
import yaml
from tqdm import tqdm
from embedchain.cache import (
Config,
ExactMatchEvaluation,
SearchDistanceEvaluation,
cache,
gptcache_data_manager,
gptcache_pre_function,
)
from embedchain.cache import (Config, ExactMatchEvaluation,
SearchDistanceEvaluation, cache,
gptcache_data_manager, gptcache_pre_function)
from embedchain.client import Client
from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
from embedchain.constants import SQLITE_PATH
@@ -26,7 +21,8 @@ from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.evaluation.base import BaseMetric
from embedchain.evaluation.metrics import AnswerRelevance, ContextRelevance, Groundedness
from embedchain.evaluation.metrics import (AnswerRelevance, ContextRelevance,
Groundedness)
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
@@ -254,30 +250,6 @@ class App(EmbedChain):
r.raise_for_status()
return r.json()
def search(self, query, num_documents=3):
"""
Search for similar documents related to the query in the vector database.
"""
# Send anonymous telemetry
self.telemetry.capture(event_name="search", properties=self._telemetry_props)
# TODO: Search will call the endpoint rather than fetching the data from the db itself when deploy=True.
if self.id is None:
where = {"app_id": self.local_id}
context = self.db.query(
query,
n_results=num_documents,
where=where,
citations=True,
)
result = []
for c in context:
result.append({"context": c[0], "metadata": c[1]})
return result
else:
# Make API call to the backend to get the results
NotImplementedError("Search is not implemented yet for the prod mode.")
def _upload_file_to_presigned_url(self, presigned_url, file_path):
try:
with open(file_path, "rb") as file:
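Note on the removal above: the old App.search always pinned the filter to the local app_id and only worked in local mode (the prod branch never raised its NotImplementedError). After this commit, App picks up search from the EmbedChain base class, which accepts caller-supplied metadata filters. A minimal sketch, assuming documents were added with a "category" metadata key (the key/value pair is illustrative, not part of the commit):

from embedchain import App

app = App()
# `where` is the new metadata filter; the key and value here are illustrative.
results = app.search("how do I deploy?", num_documents=3, where={"category": "docs"})
for item in results:
    print(item["context"], item["metadata"])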

View File

@@ -9,10 +9,8 @@ from embedchain.helpers.json_serializable import register_deserializable
class PineconeDBConfig(BaseVectorDbConfig):
def __init__(
self,
collection_name: Optional[str] = None,
api_key: Optional[str] = None,
index_name: Optional[str] = None,
dir: Optional[str] = None,
api_key: Optional[str] = None,
vector_dimension: int = 1536,
metric: Optional[str] = "cosine",
pod_config: Optional[dict[str, any]] = None,
@@ -21,9 +19,9 @@ class PineconeDBConfig(BaseVectorDbConfig):
):
self.metric = metric
self.api_key = api_key
self.index_name = index_name
self.vector_dimension = vector_dimension
self.extra_params = extra_params
self.index_name = index_name or f"{collection_name}-{vector_dimension}".lower().replace("_", "-")
if pod_config is None and serverless_config is None:
# If no config is provided, use the default pod spec config
pod_environment = os.environ.get("PINECONE_ENV", "gcp-starter")
@@ -35,4 +33,4 @@ class PineconeDBConfig(BaseVectorDbConfig):
if self.pod_config and self.serverless_config:
raise ValueError("Only one of pod_config or serverless_config can be provided.")
super().__init__(collection_name=collection_name, dir=None)
super().__init__(collection_name=self.index_name, dir=None)
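With this change the Pinecone index name no longer has to be passed explicitly: when index_name is omitted it is derived from the collection name and vector dimension, and that derived name is what gets handed to the base config as the collection name. A small sketch (the import path is assumed from the repo layout):

from embedchain.config.vectordb.pinecone import PineconeDBConfig  # import path assumed

# Derived as f"{collection_name}-{vector_dimension}".lower().replace("_", "-")
config = PineconeDBConfig(collection_name="My_App", vector_dimension=1536)
print(config.index_name)  # -> "my-app-1536"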

View File

@@ -634,6 +634,41 @@ class EmbedChain(JSONSerializable):
else:
return answer
def search(self, query, num_documents=3, where=None, raw_filter=None):
"""
Search for similar documents related to the query in the vector database.
Args:
query (str): The query to use.
num_documents (int, optional): Number of similar documents to fetch. Defaults to 3.
where (dict[str, any], optional): Filter criteria for the search.
raw_filter (dict[str, any], optional): Advanced raw filter criteria for the search.
Raises:
ValueError: If both `raw_filter` and `where` are used simultaneously.
Returns:
list[dict]: A list of dictionaries, each containing the 'context' and 'metadata' of a document.
"""
# Send anonymous telemetry
self.telemetry.capture(event_name="search", properties=self._telemetry_props)
if raw_filter and where:
raise ValueError("You can't use both `raw_filter` and `where` together.")
filter_type = "raw_filter" if raw_filter else "where"
filter_criteria = raw_filter if raw_filter else where
params = {
"input_query": query,
"n_results": num_documents,
"citations": True,
"app_id": self.config.id,
filter_type: filter_criteria,
}
return [{"context": c[0], "metadata": c[1]} for c in self.db.query(**params)]
def set_collection_name(self, name: str):
"""
Set the name of the collection. A collection is an isolated space for vectors.
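In the new search signature, `where` and `raw_filter` are mutually exclusive: `where` is a plain key/value mapping the vector DB translates into its own where clause, while `raw_filter` is forwarded to the backend untouched. A hedged sketch of both call styles (filter keys and values are illustrative):

from embedchain import App

app = App()

# Simple metadata filter: translated by the vector DB into its where clause.
results = app.search("pricing", num_documents=5, where={"source": "faq"})

# Backend-native filter, forwarded as-is (Chroma/Pinecone operator syntax shown).
results = app.search(
    "pricing",
    raw_filter={"$and": [{"source": {"$eq": "faq"}}, {"year": {"$gte": 2023}}]},
)

# Each result is {"context": ..., "metadata": ...}; passing both filters raises ValueError.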

View File

@@ -36,7 +36,10 @@ class JSONReader:
return ["\n".join(useful_lines)]
VALID_URL_PATTERN = "^https?://(?:www\.)?(?:\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|[a-zA-Z0-9.-]+)(?::\d+)?/(?:[^/\s]+/)*[^/\s]+\.json$"
VALID_URL_PATTERN = (
"^https?://(?:www\.)?(?:\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|[a-zA-Z0-9.-]+)(?::\d+)?/(?:[^/\s]+/)*[^/\s]+\.json$"
)
class JSONLoader(BaseLoader):
@staticmethod
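The pattern itself is unchanged by this reformatting; it accepts http(s) URLs that end in a .json file, addressed by hostname or IPv4 address, with an optional port. A quick check with illustrative URLs:

import re

VALID_URL_PATTERN = r"^https?://(?:www\.)?(?:\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|[a-zA-Z0-9.-]+)(?::\d+)?/(?:[^/\s]+/)*[^/\s]+\.json$"

print(bool(re.match(VALID_URL_PATTERN, "https://example.com/data/items.json")))  # True
print(bool(re.match(VALID_URL_PATTERN, "http://127.0.0.1:8000/dump.json")))      # True
print(bool(re.match(VALID_URL_PATTERN, "https://example.com/page.html")))        # False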

View File

@@ -79,6 +79,8 @@ class ChromaDB(BaseVectorDB):
def _generate_where_clause(where: dict[str, any]) -> dict[str, any]:
# If only one filter is supplied, return it as is
# (no need to wrap in $and based on chroma docs)
if where is None:
return {}
if len(where.keys()) <= 1:
return where
where_filters = []
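The new guard means a missing filter now falls through to an empty clause instead of failing on None. For reference, a sketch of the resulting shapes, assuming the rest of the method (outside this hunk) wraps multiple key/value pairs in Chroma's $and operator as the existing comment describes:

from embedchain.vectordb.chroma import ChromaDB  # import path assumed

ChromaDB._generate_where_clause(None)                 # -> {}
ChromaDB._generate_where_clause({"app_id": "local"})  # -> {"app_id": "local"} (single filter, as-is)
ChromaDB._generate_where_clause({"app_id": "local", "year": 2024})
# -> {"$and": [{"app_id": "local"}, {"year": 2024}]}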
@@ -180,9 +182,10 @@ class ChromaDB(BaseVectorDB):
self,
input_query: list[str],
n_results: int,
where: dict[str, any],
where: Optional[dict[str, any]] = None,
raw_filter: Optional[dict[str, any]] = None,
citations: bool = False,
**kwargs: Optional[dict[str, Any]],
**kwargs: Optional[dict[str, any]],
) -> Union[list[tuple[str, dict]], list[str]]:
"""
Query contents from vector database based on vector similarity
@@ -193,6 +196,8 @@ class ChromaDB(BaseVectorDB):
:type n_results: int
:param where: to filter data
:type where: dict[str, Any]
:param raw_filter: Raw filter to apply
:type raw_filter: dict[str, Any]
:param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False.
:raises InvalidDimensionException: Dimensions do not match.
@@ -200,14 +205,21 @@ class ChromaDB(BaseVectorDB):
along with url of the source and doc_id (if citations flag is true)
:rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
"""
if where and raw_filter:
raise ValueError("Both `where` and `raw_filter` cannot be used together.")
where_clause = {}
if raw_filter:
where_clause = raw_filter
if where:
where_clause = self._generate_where_clause(where)
try:
result = self.collection.query(
query_texts=[
input_query,
],
n_results=n_results,
where=self._generate_where_clause(where),
**kwargs,
where=where_clause,
)
except InvalidDimensionException as e:
raise InvalidDimensionException(
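Called directly on the vector DB, `where` still goes through _generate_where_clause while `raw_filter` bypasses it and is handed to Chroma untouched. A minimal sketch, assuming `chroma_db` is an initialized ChromaDB instance with documents already added (the filter is illustrative):

results = chroma_db.query(
    input_query="refund policy",
    n_results=3,
    raw_filter={"year": {"$gte": 2023}},  # Chroma-native operator filter, illustrative
    citations=True,
)
# With citations=True each result is a (context, metadata) tuple.
for context, metadata in results:
    print(context, metadata)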

View File

@@ -1,4 +1,3 @@
import logging
import os
from typing import Optional, Union
@@ -99,10 +98,6 @@ class PineconeDB(BaseVectorDB):
batch_existing_ids = list(vectors.keys())
existing_ids.extend(batch_existing_ids)
metadatas.extend([vectors.get(ids).get("metadata") for ids in batch_existing_ids])
if where is not None:
logging.warning("Filtering is not supported by Pinecone")
return {"ids": existing_ids, "metadatas": metadatas}
def add(
@@ -122,7 +117,6 @@ class PineconeDB(BaseVectorDB):
:type ids: list[str]
"""
docs = []
print("Adding documents to Pinecone...")
embeddings = self.embedder.embedding_fn(documents)
for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
docs.append(
@@ -140,26 +134,31 @@ class PineconeDB(BaseVectorDB):
self,
input_query: list[str],
n_results: int,
where: dict[str, any],
where: Optional[dict[str, any]] = None,
raw_filter: Optional[dict[str, any]] = None,
citations: bool = False,
app_id: Optional[str] = None,
**kwargs: Optional[dict[str, any]],
) -> Union[list[tuple[str, dict]], list[str]]:
"""
query contents from vector database based on vector similarity
:param input_query: list of query string
:type input_query: list[str]
:param n_results: no of similar documents to fetch from database
:type n_results: int
:param where: Optional. to filter data
:type where: dict[str, any]
:param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False.
:return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true)
:rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
Query contents from vector database based on vector similarity.
Args:
input_query (list[str]): List of query strings.
n_results (int): Number of similar documents to fetch from the database.
where (dict[str, any], optional): Filter criteria for the search.
raw_filter (dict[str, any], optional): Advanced raw filter criteria for the search.
citations (bool, optional): Flag to return context along with metadata. Defaults to False.
app_id (str, optional): Application ID to be passed to Pinecone.
Returns:
Union[list[tuple[str, dict]], list[str]]: List of document contexts, optionally with metadata.
"""
query_filter = raw_filter if raw_filter is not None else self._generate_filter(where)
if app_id:
query_filter["app_id"] = {"$eq": app_id}
query_vector = self.embedder.embedding_fn([input_query])[0]
query_filter = self._generate_filter(where)
data = self.pinecone_index.query(
vector=query_vector,
filter=query_filter,
@@ -167,16 +166,12 @@ class PineconeDB(BaseVectorDB):
include_metadata=True,
**kwargs,
)
contexts = []
for doc in data.get("matches", []):
metadata = doc.get("metadata", {})
context = metadata.get("text")
if citations:
metadata["score"] = doc.get("score")
contexts.append(tuple((context, metadata)))
else:
contexts.append(context)
return contexts
return [
(metadata.get("text"), {**metadata, "score": doc.get("score")}) if citations else metadata.get("text")
for doc in data.get("matches", [])
for metadata in [doc.get("metadata", {})]
]
def set_collection_name(self, name: str):
"""