[Feature] Add support for metadata filtering on search API (#1245)

This commit is contained in:
Deshraj Yadav
2024-02-06 15:42:51 -08:00
committed by GitHub
parent 8fe2c3effc
commit 4afef04f26
10 changed files with 173 additions and 104 deletions

View File

@@ -12,6 +12,13 @@ title: '🔍 search'
<ParamField path="num_documents" type="int" optional> <ParamField path="num_documents" type="int" optional>
Number of relevant documents to fetch. Defaults to `3` Number of relevant documents to fetch. Defaults to `3`
</ParamField> </ParamField>
<ParamField path="where" type="dict" optional>
Key value pair for metadata filtering.
</ParamField>
<ParamField path="raw_filter" type="dict" optional>
Pass raw filter query based on your vector database.
Currently, `raw_filter` param is only supported for Pinecone vector database.
</ParamField>
### Returns ### Returns
@@ -21,37 +28,84 @@ title: '🔍 search'
## Usage ## Usage
### Basic
Refer to the following example on how to use the search api: Refer to the following example on how to use the search api:
```python Code example ```python Code example
from embedchain import App from embedchain import App
# Initialize app
app = App() app = App()
# Add data source
app.add("https://www.forbes.com/profile/elon-musk") app.add("https://www.forbes.com/profile/elon-musk")
# Get relevant context using semantic search
context = app.search("What is the net worth of Elon?", num_documents=2) context = app.search("What is the net worth of Elon?", num_documents=2)
print(context) print(context)
# Context: ```
# [
# { ### Advanced
# 'context': 'Elon Musk PROFILEElon MuskCEO, Tesla$221.9BReal Time Net Worth ...',
# 'metadata': { #### Metadata filtering using `where` params
# 'source': 'https://www.forbes.com/profile/elon-musk',
# 'document_id': 'some_document_id', Here is an advanced example of `search()` API with metadata filtering on pinecone database:
# 'score': 0.404,
# } ```python
# }, import os
# {
# 'context': 'company, which is now called X.Wealth HistoryHOVER TO REVEAL NET WORTH ...', from embedchain import App
# 'metadata': {
# 'source': 'https://www.forbes.com/profile/elon-musk', os.environ["PINECONE_API_KEY"] = "xxx"
# 'document_id': 'some_document_id',
# 'score': 0.435, config = {
# } "vectordb": {
# } "provider": "pinecone",
# ] "config": {
"metric": "dotproduct",
"vector_dimension": 1536,
"index_name": "ec-test",
"serverless_config": {"cloud": "aws", "region": "us-west-2"},
},
}
}
app = App.from_config(config=config)
app.add("https://www.forbes.com/profile/bill-gates", metadata={"type": "forbes", "person": "gates"})
app.add("https://en.wikipedia.org/wiki/Bill_Gates", metadata={"type": "wiki", "person": "gates"})
results = app.search("What is the net worth of Bill Gates?", where={"person": "gates"})
print("Num of search results: ", len(results))
```
#### Metadata filtering using `raw_filter` params
Following is an example of metadata filtering by passing the raw filter query that pinecone vector database follows:
```python
import os
from embedchain import App
os.environ["PINECONE_API_KEY"] = "xxx"
config = {
"vectordb": {
"provider": "pinecone",
"config": {
"metric": "dotproduct",
"vector_dimension": 1536,
"index_name": "ec-test",
"serverless_config": {"cloud": "aws", "region": "us-west-2"},
},
}
}
app = App.from_config(config=config)
app.add("https://www.forbes.com/profile/bill-gates", metadata={"year": 2022, "person": "gates"})
app.add("https://en.wikipedia.org/wiki/Bill_Gates", metadata={"year": 2024, "person": "gates"})
print("Filter with person: gates and year > 2023")
raw_filter = {"$and": [{"person": "gates"}, {"year": {"$gt": 2023}}]}
results = app.search("What is the net worth of Bill Gates?", raw_filter=raw_filter)
print("Num of search results: ", len(results))
``` ```

View File

@@ -186,7 +186,7 @@ vectordb:
config: config:
metric: cosine metric: cosine
vector_dimension: 1536 vector_dimension: 1536
collection_name: my-pinecone-index index_name: my-pinecone-index
pod_config: pod_config:
environment: gcp-starter environment: gcp-starter
metadata_config: metadata_config:
@@ -201,7 +201,7 @@ vectordb:
config: config:
metric: cosine metric: cosine
vector_dimension: 1536 vector_dimension: 1536
collection_name: my-pinecone-index index_name: my-pinecone-index
serverless_config: serverless_config:
cloud: aws cloud: aws
region: us-west-2 region: us-west-2

View File

@@ -11,14 +11,9 @@ import requests
import yaml import yaml
from tqdm import tqdm from tqdm import tqdm
from embedchain.cache import ( from embedchain.cache import (Config, ExactMatchEvaluation,
Config, SearchDistanceEvaluation, cache,
ExactMatchEvaluation, gptcache_data_manager, gptcache_pre_function)
SearchDistanceEvaluation,
cache,
gptcache_data_manager,
gptcache_pre_function,
)
from embedchain.client import Client from embedchain.client import Client
from embedchain.config import AppConfig, CacheConfig, ChunkerConfig from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
from embedchain.constants import SQLITE_PATH from embedchain.constants import SQLITE_PATH
@@ -26,7 +21,8 @@ from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.evaluation.base import BaseMetric from embedchain.evaluation.base import BaseMetric
from embedchain.evaluation.metrics import AnswerRelevance, ContextRelevance, Groundedness from embedchain.evaluation.metrics import (AnswerRelevance, ContextRelevance,
Groundedness)
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm from embedchain.llm.base import BaseLlm
@@ -254,30 +250,6 @@ class App(EmbedChain):
r.raise_for_status() r.raise_for_status()
return r.json() return r.json()
def search(self, query, num_documents=3):
"""
Search for similar documents related to the query in the vector database.
"""
# Send anonymous telemetry
self.telemetry.capture(event_name="search", properties=self._telemetry_props)
# TODO: Search will call the endpoint rather than fetching the data from the db itself when deploy=True.
if self.id is None:
where = {"app_id": self.local_id}
context = self.db.query(
query,
n_results=num_documents,
where=where,
citations=True,
)
result = []
for c in context:
result.append({"context": c[0], "metadata": c[1]})
return result
else:
# Make API call to the backend to get the results
NotImplementedError("Search is not implemented yet for the prod mode.")
def _upload_file_to_presigned_url(self, presigned_url, file_path): def _upload_file_to_presigned_url(self, presigned_url, file_path):
try: try:
with open(file_path, "rb") as file: with open(file_path, "rb") as file:

View File

@@ -9,10 +9,8 @@ from embedchain.helpers.json_serializable import register_deserializable
class PineconeDBConfig(BaseVectorDbConfig): class PineconeDBConfig(BaseVectorDbConfig):
def __init__( def __init__(
self, self,
collection_name: Optional[str] = None,
api_key: Optional[str] = None,
index_name: Optional[str] = None, index_name: Optional[str] = None,
dir: Optional[str] = None, api_key: Optional[str] = None,
vector_dimension: int = 1536, vector_dimension: int = 1536,
metric: Optional[str] = "cosine", metric: Optional[str] = "cosine",
pod_config: Optional[dict[str, any]] = None, pod_config: Optional[dict[str, any]] = None,
@@ -21,9 +19,9 @@ class PineconeDBConfig(BaseVectorDbConfig):
): ):
self.metric = metric self.metric = metric
self.api_key = api_key self.api_key = api_key
self.index_name = index_name
self.vector_dimension = vector_dimension self.vector_dimension = vector_dimension
self.extra_params = extra_params self.extra_params = extra_params
self.index_name = index_name or f"{collection_name}-{vector_dimension}".lower().replace("_", "-")
if pod_config is None and serverless_config is None: if pod_config is None and serverless_config is None:
# If no config is provided, use the default pod spec config # If no config is provided, use the default pod spec config
pod_environment = os.environ.get("PINECONE_ENV", "gcp-starter") pod_environment = os.environ.get("PINECONE_ENV", "gcp-starter")
@@ -35,4 +33,4 @@ class PineconeDBConfig(BaseVectorDbConfig):
if self.pod_config and self.serverless_config: if self.pod_config and self.serverless_config:
raise ValueError("Only one of pod_config or serverless_config can be provided.") raise ValueError("Only one of pod_config or serverless_config can be provided.")
super().__init__(collection_name=collection_name, dir=None) super().__init__(collection_name=self.index_name, dir=None)

View File

@@ -634,6 +634,41 @@ class EmbedChain(JSONSerializable):
else: else:
return answer return answer
def search(self, query, num_documents=3, where=None, raw_filter=None):
"""
Search for similar documents related to the query in the vector database.
Args:
query (str): The query to use.
num_documents (int, optional): Number of similar documents to fetch. Defaults to 3.
where (dict[str, any], optional): Filter criteria for the search.
raw_filter (dict[str, any], optional): Advanced raw filter criteria for the search.
Raises:
ValueError: If both `raw_filter` and `where` are used simultaneously.
Returns:
list[dict]: A list of dictionaries, each containing the 'context' and 'metadata' of a document.
"""
# Send anonymous telemetry
self.telemetry.capture(event_name="search", properties=self._telemetry_props)
if raw_filter and where:
raise ValueError("You can't use both `raw_filter` and `where` together.")
filter_type = "raw_filter" if raw_filter else "where"
filter_criteria = raw_filter if raw_filter else where
params = {
"input_query": query,
"n_results": num_documents,
"citations": True,
"app_id": self.config.id,
filter_type: filter_criteria,
}
return [{"context": c[0], "metadata": c[1]} for c in self.db.query(**params)]
def set_collection_name(self, name: str): def set_collection_name(self, name: str):
""" """
Set the name of the collection. A collection is an isolated space for vectors. Set the name of the collection. A collection is an isolated space for vectors.

View File

@@ -36,7 +36,10 @@ class JSONReader:
return ["\n".join(useful_lines)] return ["\n".join(useful_lines)]
VALID_URL_PATTERN = "^https?://(?:www\.)?(?:\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|[a-zA-Z0-9.-]+)(?::\d+)?/(?:[^/\s]+/)*[^/\s]+\.json$" VALID_URL_PATTERN = (
"^https?://(?:www\.)?(?:\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|[a-zA-Z0-9.-]+)(?::\d+)?/(?:[^/\s]+/)*[^/\s]+\.json$"
)
class JSONLoader(BaseLoader): class JSONLoader(BaseLoader):
@staticmethod @staticmethod

View File

@@ -79,6 +79,8 @@ class ChromaDB(BaseVectorDB):
def _generate_where_clause(where: dict[str, any]) -> dict[str, any]: def _generate_where_clause(where: dict[str, any]) -> dict[str, any]:
# If only one filter is supplied, return it as is # If only one filter is supplied, return it as is
# (no need to wrap in $and based on chroma docs) # (no need to wrap in $and based on chroma docs)
if where is None:
return {}
if len(where.keys()) <= 1: if len(where.keys()) <= 1:
return where return where
where_filters = [] where_filters = []
@@ -180,9 +182,10 @@ class ChromaDB(BaseVectorDB):
self, self,
input_query: list[str], input_query: list[str],
n_results: int, n_results: int,
where: dict[str, any], where: Optional[dict[str, any]] = None,
raw_filter: Optional[dict[str, any]] = None,
citations: bool = False, citations: bool = False,
**kwargs: Optional[dict[str, Any]], **kwargs: Optional[dict[str, any]],
) -> Union[list[tuple[str, dict]], list[str]]: ) -> Union[list[tuple[str, dict]], list[str]]:
""" """
Query contents from vector database based on vector similarity Query contents from vector database based on vector similarity
@@ -193,6 +196,8 @@ class ChromaDB(BaseVectorDB):
:type n_results: int :type n_results: int
:param where: to filter data :param where: to filter data
:type where: dict[str, Any] :type where: dict[str, Any]
:param raw_filter: Raw filter to apply
:type raw_filter: dict[str, Any]
:param citations: we use citations boolean param to return context along with the answer. :param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False. :type citations: bool, default is False.
:raises InvalidDimensionException: Dimensions do not match. :raises InvalidDimensionException: Dimensions do not match.
@@ -200,14 +205,21 @@ class ChromaDB(BaseVectorDB):
along with url of the source and doc_id (if citations flag is true) along with url of the source and doc_id (if citations flag is true)
:rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]] :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
""" """
if where and raw_filter:
raise ValueError("Both `where` and `raw_filter` cannot be used together.")
where_clause = {}
if raw_filter:
where_clause = raw_filter
if where:
where_clause = self._generate_where_clause(where)
try: try:
result = self.collection.query( result = self.collection.query(
query_texts=[ query_texts=[
input_query, input_query,
], ],
n_results=n_results, n_results=n_results,
where=self._generate_where_clause(where), where=where_clause,
**kwargs,
) )
except InvalidDimensionException as e: except InvalidDimensionException as e:
raise InvalidDimensionException( raise InvalidDimensionException(

View File

@@ -1,4 +1,3 @@
import logging
import os import os
from typing import Optional, Union from typing import Optional, Union
@@ -99,10 +98,6 @@ class PineconeDB(BaseVectorDB):
batch_existing_ids = list(vectors.keys()) batch_existing_ids = list(vectors.keys())
existing_ids.extend(batch_existing_ids) existing_ids.extend(batch_existing_ids)
metadatas.extend([vectors.get(ids).get("metadata") for ids in batch_existing_ids]) metadatas.extend([vectors.get(ids).get("metadata") for ids in batch_existing_ids])
if where is not None:
logging.warning("Filtering is not supported by Pinecone")
return {"ids": existing_ids, "metadatas": metadatas} return {"ids": existing_ids, "metadatas": metadatas}
def add( def add(
@@ -122,7 +117,6 @@ class PineconeDB(BaseVectorDB):
:type ids: list[str] :type ids: list[str]
""" """
docs = [] docs = []
print("Adding documents to Pinecone...")
embeddings = self.embedder.embedding_fn(documents) embeddings = self.embedder.embedding_fn(documents)
for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings): for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
docs.append( docs.append(
@@ -140,26 +134,31 @@ class PineconeDB(BaseVectorDB):
self, self,
input_query: list[str], input_query: list[str],
n_results: int, n_results: int,
where: dict[str, any], where: Optional[dict[str, any]] = None,
raw_filter: Optional[dict[str, any]] = None,
citations: bool = False, citations: bool = False,
app_id: Optional[str] = None,
**kwargs: Optional[dict[str, any]], **kwargs: Optional[dict[str, any]],
) -> Union[list[tuple[str, dict]], list[str]]: ) -> Union[list[tuple[str, dict]], list[str]]:
""" """
query contents from vector database based on vector similarity Query contents from vector database based on vector similarity.
:param input_query: list of query string
:type input_query: list[str] Args:
:param n_results: no of similar documents to fetch from database input_query (list[str]): List of query strings.
:type n_results: int n_results (int): Number of similar documents to fetch from the database.
:param where: Optional. to filter data where (dict[str, any], optional): Filter criteria for the search.
:type where: dict[str, any] raw_filter (dict[str, any], optional): Advanced raw filter criteria for the search.
:param citations: we use citations boolean param to return context along with the answer. citations (bool, optional): Flag to return context along with metadata. Defaults to False.
:type citations: bool, default is False. app_id (str, optional): Application ID to be passed to Pinecone.
:return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true) Returns:
:rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]] Union[list[tuple[str, dict]], list[str]]: List of document contexts, optionally with metadata.
""" """
query_filter = raw_filter if raw_filter is not None else self._generate_filter(where)
if app_id:
query_filter["app_id"] = {"$eq": app_id}
query_vector = self.embedder.embedding_fn([input_query])[0] query_vector = self.embedder.embedding_fn([input_query])[0]
query_filter = self._generate_filter(where)
data = self.pinecone_index.query( data = self.pinecone_index.query(
vector=query_vector, vector=query_vector,
filter=query_filter, filter=query_filter,
@@ -167,16 +166,12 @@ class PineconeDB(BaseVectorDB):
include_metadata=True, include_metadata=True,
**kwargs, **kwargs,
) )
contexts = []
for doc in data.get("matches", []): return [
metadata = doc.get("metadata", {}) (metadata.get("text"), {**metadata, "score": doc.get("score")}) if citations else metadata.get("text")
context = metadata.get("text") for doc in data.get("matches", [])
if citations: for metadata in [doc.get("metadata", {})]
metadata["score"] = doc.get("score") ]
contexts.append(tuple((context, metadata)))
else:
contexts.append(context)
return contexts
def set_collection_name(self, name: str): def set_collection_name(self, name: str):
""" """

View File

@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "embedchain" name = "embedchain"
version = "0.1.73" version = "0.1.74"
description = "Simplest open source retrieval(RAG) framework" description = "Simplest open source retrieval(RAG) framework"
authors = [ authors = [
"Taranjeet Singh <taranjeet@embedchain.ai>", "Taranjeet Singh <taranjeet@embedchain.ai>",

View File

@@ -7,7 +7,7 @@ from embedchain.vectordb.pinecone import PineconeDB
@pytest.fixture @pytest.fixture
def pinecone_pod_config(): def pinecone_pod_config():
return PineconeDBConfig( return PineconeDBConfig(
collection_name="test_collection", index_name="test_collection",
api_key="test_api_key", api_key="test_api_key",
vector_dimension=3, vector_dimension=3,
pod_config={"environment": "test_environment", "metadata_config": {"indexed": ["*"]}}, pod_config={"environment": "test_environment", "metadata_config": {"indexed": ["*"]}},
@@ -17,7 +17,7 @@ def pinecone_pod_config():
@pytest.fixture @pytest.fixture
def pinecone_serverless_config(): def pinecone_serverless_config():
return PineconeDBConfig( return PineconeDBConfig(
collection_name="test_collection", index_name="test_collection",
api_key="test_api_key", api_key="test_api_key",
vector_dimension=3, vector_dimension=3,
serverless_config={ serverless_config={
@@ -39,7 +39,7 @@ def test_pinecone_init_without_config(monkeypatch):
monkeypatch.delenv("PINECONE_API_KEY") monkeypatch.delenv("PINECONE_API_KEY")
def test_pinecone_init_with_config(pinecone_pod_config, pinecone_serverless_config, monkeypatch): def test_pinecone_init_with_config(pinecone_pod_config, monkeypatch):
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x) monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x) monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
pinecone_db = PineconeDB(config=pinecone_pod_config) pinecone_db = PineconeDB(config=pinecone_pod_config)
@@ -158,7 +158,7 @@ def test_setup_pinecone_index(pinecone_pod_config, pinecone_serverless_config, m
pinecone_db._setup_pinecone_index() pinecone_db._setup_pinecone_index()
assert pinecone_db.client is not None assert pinecone_db.client is not None
assert pinecone_db.config.index_name == "test-collection-3" assert pinecone_db.config.index_name == "test_collection"
assert pinecone_db.client.list_indexes().names() == ["test_collection"] assert pinecone_db.client.list_indexes().names() == ["test_collection"]
assert pinecone_db.pinecone_index is not None assert pinecone_db.pinecone_index is not None
@@ -166,7 +166,7 @@ def test_setup_pinecone_index(pinecone_pod_config, pinecone_serverless_config, m
pinecone_db._setup_pinecone_index() pinecone_db._setup_pinecone_index()
assert pinecone_db.client is not None assert pinecone_db.client is not None
assert pinecone_db.config.index_name == "test-collection-3" assert pinecone_db.config.index_name == "test_collection"
assert pinecone_db.client.list_indexes().names() == ["test_collection"] assert pinecone_db.client.list_indexes().names() == ["test_collection"]
assert pinecone_db.pinecone_index is not None assert pinecone_db.pinecone_index is not None