[Feature] Add support for metadata filtering on search API (#1245)

This commit is contained in:
Deshraj Yadav
2024-02-06 15:42:51 -08:00
committed by GitHub
parent 8fe2c3effc
commit 4afef04f26
10 changed files with 173 additions and 104 deletions

View File

@@ -12,6 +12,13 @@ title: '🔍 search'
<ParamField path="num_documents" type="int" optional>
Number of relevant documents to fetch. Defaults to `3`.
</ParamField>
<ParamField path="where" type="dict" optional>
Key-value pairs for metadata filtering.
</ParamField>
<ParamField path="raw_filter" type="dict" optional>
Pass a raw filter query using your vector database's native syntax.
Currently, the `raw_filter` param is only supported for the Pinecone vector database.
</ParamField>
### Returns
@@ -21,37 +28,84 @@ title: '🔍 search'
## Usage
### Basic
Refer to the following example to see how to use the search API:
```python Code example
from embedchain import App
# Initialize app
app = App()
# Add data source
app.add("https://www.forbes.com/profile/elon-musk")
# Get relevant context using semantic search
context = app.search("What is the net worth of Elon?", num_documents=2)
print(context)
# Context:
# [
# {
# 'context': 'Elon Musk PROFILEElon MuskCEO, Tesla$221.9BReal Time Net Worth ...',
# 'metadata': {
# 'source': 'https://www.forbes.com/profile/elon-musk',
# 'document_id': 'some_document_id',
# 'score': 0.404,
# }
# },
# {
# 'context': 'company, which is now called X.Wealth HistoryHOVER TO REVEAL NET WORTH ...',
# 'metadata': {
# 'source': 'https://www.forbes.com/profile/elon-musk',
# 'document_id': 'some_document_id',
# 'score': 0.435,
# }
# }
# ]
```
### Advanced
#### Metadata filtering using the `where` param
Here is an advanced example of the `search()` API with metadata filtering on the Pinecone database:
```python
import os
from embedchain import App
os.environ["PINECONE_API_KEY"] = "xxx"
config = {
"vectordb": {
"provider": "pinecone",
"config": {
"metric": "dotproduct",
"vector_dimension": 1536,
"index_name": "ec-test",
"serverless_config": {"cloud": "aws", "region": "us-west-2"},
},
}
}
app = App.from_config(config=config)
app.add("https://www.forbes.com/profile/bill-gates", metadata={"type": "forbes", "person": "gates"})
app.add("https://en.wikipedia.org/wiki/Bill_Gates", metadata={"type": "wiki", "person": "gates"})
results = app.search("What is the net worth of Bill Gates?", where={"person": "gates"})
print("Num of search results: ", len(results))
```
#### Metadata filtering using the `raw_filter` param
The following is an example of metadata filtering by passing a raw filter query in the syntax that the Pinecone vector database supports:
```python
import os
from embedchain import App
os.environ["PINECONE_API_KEY"] = "xxx"
config = {
"vectordb": {
"provider": "pinecone",
"config": {
"metric": "dotproduct",
"vector_dimension": 1536,
"index_name": "ec-test",
"serverless_config": {"cloud": "aws", "region": "us-west-2"},
},
}
}
app = App.from_config(config=config)
app.add("https://www.forbes.com/profile/bill-gates", metadata={"year": 2022, "person": "gates"})
app.add("https://en.wikipedia.org/wiki/Bill_Gates", metadata={"year": 2024, "person": "gates"})
print("Filter with person: gates and year > 2023")
raw_filter = {"$and": [{"person": "gates"}, {"year": {"$gt": 2023}}]}
results = app.search("What is the net worth of Bill Gates?", raw_filter=raw_filter)
print("Num of search results: ", len(results))
```
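Note that `where` and `raw_filter` cannot be combined in a single call; the `search()` implementation in this commit raises a `ValueError` if both are passed. A minimal sketch of that behavior, reusing the app configured above:
```python
try:
    app.search(
        "What is the net worth of Bill Gates?",
        where={"person": "gates"},
        raw_filter={"person": {"$eq": "gates"}},
    )
except ValueError as e:
    print(e)  # You can't use both `raw_filter` and `where` together.
```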

View File

@@ -186,7 +186,7 @@ vectordb:
config:
metric: cosine
vector_dimension: 1536
collection_name: my-pinecone-index
index_name: my-pinecone-index
pod_config:
environment: gcp-starter
metadata_config:
@@ -201,7 +201,7 @@ vectordb:
config:
metric: cosine
vector_dimension: 1536
collection_name: my-pinecone-index
index_name: my-pinecone-index
serverless_config:
cloud: aws
region: us-west-2
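The YAML key changes from `collection_name` to `index_name`, matching the `PineconeDBConfig` update later in this commit. As a rough sketch, the serverless block above corresponds to a config object like the following (the import path and the placeholder `api_key` are assumptions):
```python
from embedchain.config.vectordb.pinecone import PineconeDBConfig

# Mirrors the serverless YAML block above; "xxx" is a placeholder API key
db_config = PineconeDBConfig(
    metric="cosine",
    vector_dimension=1536,
    index_name="my-pinecone-index",
    api_key="xxx",
    serverless_config={"cloud": "aws", "region": "us-west-2"},
)
print(db_config.index_name)  # my-pinecone-index
```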

View File

@@ -11,14 +11,9 @@ import requests
import yaml
from tqdm import tqdm
from embedchain.cache import (
Config,
ExactMatchEvaluation,
SearchDistanceEvaluation,
cache,
gptcache_data_manager,
gptcache_pre_function,
)
from embedchain.cache import (Config, ExactMatchEvaluation,
SearchDistanceEvaluation, cache,
gptcache_data_manager, gptcache_pre_function)
from embedchain.client import Client
from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
from embedchain.constants import SQLITE_PATH
@@ -26,7 +21,8 @@ from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.evaluation.base import BaseMetric
from embedchain.evaluation.metrics import AnswerRelevance, ContextRelevance, Groundedness
from embedchain.evaluation.metrics import (AnswerRelevance, ContextRelevance,
Groundedness)
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
@@ -254,30 +250,6 @@ class App(EmbedChain):
r.raise_for_status()
return r.json()
def search(self, query, num_documents=3):
"""
Search for similar documents related to the query in the vector database.
"""
# Send anonymous telemetry
self.telemetry.capture(event_name="search", properties=self._telemetry_props)
# TODO: Search will call the endpoint rather than fetching the data from the db itself when deploy=True.
if self.id is None:
where = {"app_id": self.local_id}
context = self.db.query(
query,
n_results=num_documents,
where=where,
citations=True,
)
result = []
for c in context:
result.append({"context": c[0], "metadata": c[1]})
return result
else:
# Make API call to the backend to get the results
NotImplementedError("Search is not implemented yet for the prod mode.")
def _upload_file_to_presigned_url(self, presigned_url, file_path):
try:
with open(file_path, "rb") as file:

View File

@@ -9,10 +9,8 @@ from embedchain.helpers.json_serializable import register_deserializable
class PineconeDBConfig(BaseVectorDbConfig):
def __init__(
self,
collection_name: Optional[str] = None,
api_key: Optional[str] = None,
index_name: Optional[str] = None,
dir: Optional[str] = None,
api_key: Optional[str] = None,
vector_dimension: int = 1536,
metric: Optional[str] = "cosine",
pod_config: Optional[dict[str, any]] = None,
@@ -21,9 +19,9 @@ class PineconeDBConfig(BaseVectorDbConfig):
):
self.metric = metric
self.api_key = api_key
self.index_name = index_name
self.vector_dimension = vector_dimension
self.extra_params = extra_params
self.index_name = index_name or f"{collection_name}-{vector_dimension}".lower().replace("_", "-")
if pod_config is None and serverless_config is None:
# If no config is provided, use the default pod spec config
pod_environment = os.environ.get("PINECONE_ENV", "gcp-starter")
@@ -35,4 +33,4 @@ class PineconeDBConfig(BaseVectorDbConfig):
if self.pod_config and self.serverless_config:
raise ValueError("Only one of pod_config or serverless_config can be provided.")
super().__init__(collection_name=collection_name, dir=None)
super().__init__(collection_name=self.index_name, dir=None)
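When `index_name` is not supplied, it is derived from `collection_name` and the vector dimension. A quick illustration of the fallback expression above:
```python
# Fallback naming when only collection_name is provided (mirrors the expression above)
collection_name, vector_dimension = "test_collection", 1536
index_name = f"{collection_name}-{vector_dimension}".lower().replace("_", "-")
print(index_name)  # test-collection-1536
```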

View File

@@ -634,6 +634,41 @@ class EmbedChain(JSONSerializable):
else:
return answer
def search(self, query, num_documents=3, where=None, raw_filter=None):
"""
Search for similar documents related to the query in the vector database.
Args:
query (str): The query to use.
num_documents (int, optional): Number of similar documents to fetch. Defaults to 3.
where (dict[str, any], optional): Filter criteria for the search.
raw_filter (dict[str, any], optional): Advanced raw filter criteria for the search.
Raises:
ValueError: If both `raw_filter` and `where` are used simultaneously.
Returns:
list[dict]: A list of dictionaries, each containing the 'context' and 'metadata' of a document.
"""
# Send anonymous telemetry
self.telemetry.capture(event_name="search", properties=self._telemetry_props)
if raw_filter and where:
raise ValueError("You can't use both `raw_filter` and `where` together.")
filter_type = "raw_filter" if raw_filter else "where"
filter_criteria = raw_filter if raw_filter else where
params = {
"input_query": query,
"n_results": num_documents,
"citations": True,
"app_id": self.config.id,
filter_type: filter_criteria,
}
return [{"context": c[0], "metadata": c[1]} for c in self.db.query(**params)]
def set_collection_name(self, name: str):
"""
Set the name of the collection. A collection is an isolated space for vectors.

View File

@@ -36,7 +36,10 @@ class JSONReader:
return ["\n".join(useful_lines)]
VALID_URL_PATTERN = "^https?://(?:www\.)?(?:\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|[a-zA-Z0-9.-]+)(?::\d+)?/(?:[^/\s]+/)*[^/\s]+\.json$"
VALID_URL_PATTERN = (
"^https?://(?:www\.)?(?:\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|[a-zA-Z0-9.-]+)(?::\d+)?/(?:[^/\s]+/)*[^/\s]+\.json$"
)
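The reflowed pattern above only accepts URLs that end in `.json`; a small standalone check (the example URLs are hypothetical):
```python
import re

# Same pattern as above, written as a raw string
VALID_URL_PATTERN = (
    r"^https?://(?:www\.)?(?:\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|[a-zA-Z0-9.-]+)(?::\d+)?/(?:[^/\s]+/)*[^/\s]+\.json$"
)

print(bool(re.match(VALID_URL_PATTERN, "https://example.com/data/records.json")))  # True
print(bool(re.match(VALID_URL_PATTERN, "https://example.com/records.txt")))        # False
```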
class JSONLoader(BaseLoader):
@staticmethod

View File

@@ -79,6 +79,8 @@ class ChromaDB(BaseVectorDB):
def _generate_where_clause(where: dict[str, any]) -> dict[str, any]:
# If only one filter is supplied, return it as is
# (no need to wrap in $and based on chroma docs)
if where is None:
return {}
if len(where.keys()) <= 1:
return where
where_filters = []
@@ -180,9 +182,10 @@ class ChromaDB(BaseVectorDB):
self,
input_query: list[str],
n_results: int,
where: dict[str, any],
where: Optional[dict[str, any]] = None,
raw_filter: Optional[dict[str, any]] = None,
citations: bool = False,
**kwargs: Optional[dict[str, Any]],
**kwargs: Optional[dict[str, any]],
) -> Union[list[tuple[str, dict]], list[str]]:
"""
Query contents from vector database based on vector similarity
@@ -193,6 +196,8 @@ class ChromaDB(BaseVectorDB):
:type n_results: int
:param where: to filter data
:type where: dict[str, Any]
:param raw_filter: Raw filter to apply
:type raw_filter: dict[str, Any]
:param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False.
:raises InvalidDimensionException: Dimensions do not match.
@@ -200,14 +205,21 @@ class ChromaDB(BaseVectorDB):
along with url of the source and doc_id (if citations flag is true)
:rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
"""
if where and raw_filter:
raise ValueError("Both `where` and `raw_filter` cannot be used together.")
where_clause = {}
if raw_filter:
where_clause = raw_filter
if where:
where_clause = self._generate_where_clause(where)
try:
result = self.collection.query(
query_texts=[
input_query,
],
n_results=n_results,
where=self._generate_where_clause(where),
**kwargs,
where=where_clause,
)
except InvalidDimensionException as e:
raise InvalidDimensionException(
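To recap the branch above: `raw_filter` is forwarded to Chroma verbatim, `where` is expanded by `_generate_where_clause`, and with neither the clause stays empty. A standalone sketch of that selection (the multi-key `$and` shape is an assumption based on the comment in `_generate_where_clause`, whose full body is not shown in this hunk):
```python
from typing import Optional


def resolve_where_clause(where: Optional[dict] = None, raw_filter: Optional[dict] = None) -> dict:
    """Standalone sketch of the clause selection in ChromaDB.query above."""
    if where and raw_filter:
        raise ValueError("Both `where` and `raw_filter` cannot be used together.")
    if raw_filter:
        return raw_filter  # forwarded verbatim
    if where:
        if len(where) <= 1:  # single filter: no $and wrapping needed
            return where
        return {"$and": [{k: v} for k, v in where.items()]}  # assumed multi-key shape
    return {}


print(resolve_where_clause(where={"person": "gates", "year": 2024}))
# {'$and': [{'person': 'gates'}, {'year': 2024}]}
```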

View File

@@ -1,4 +1,3 @@
import logging
import os
from typing import Optional, Union
@@ -99,10 +98,6 @@ class PineconeDB(BaseVectorDB):
batch_existing_ids = list(vectors.keys())
existing_ids.extend(batch_existing_ids)
metadatas.extend([vectors.get(ids).get("metadata") for ids in batch_existing_ids])
if where is not None:
logging.warning("Filtering is not supported by Pinecone")
return {"ids": existing_ids, "metadatas": metadatas}
def add(
@@ -122,7 +117,6 @@ class PineconeDB(BaseVectorDB):
:type ids: list[str]
"""
docs = []
print("Adding documents to Pinecone...")
embeddings = self.embedder.embedding_fn(documents)
for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
docs.append(
@@ -140,26 +134,31 @@ class PineconeDB(BaseVectorDB):
self,
input_query: list[str],
n_results: int,
where: dict[str, any],
where: Optional[dict[str, any]] = None,
raw_filter: Optional[dict[str, any]] = None,
citations: bool = False,
app_id: Optional[str] = None,
**kwargs: Optional[dict[str, any]],
) -> Union[list[tuple[str, dict]], list[str]]:
"""
query contents from vector database based on vector similarity
:param input_query: list of query string
:type input_query: list[str]
:param n_results: no of similar documents to fetch from database
:type n_results: int
:param where: Optional. to filter data
:type where: dict[str, any]
:param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False.
:return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true)
:rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
Query contents from vector database based on vector similarity.
Args:
input_query (list[str]): List of query strings.
n_results (int): Number of similar documents to fetch from the database.
where (dict[str, any], optional): Filter criteria for the search.
raw_filter (dict[str, any], optional): Advanced raw filter criteria for the search.
citations (bool, optional): Flag to return context along with metadata. Defaults to False.
app_id (str, optional): Application ID to be passed to Pinecone.
Returns:
Union[list[tuple[str, dict]], list[str]]: List of document contexts, optionally with metadata.
"""
query_filter = raw_filter if raw_filter is not None else self._generate_filter(where)
if app_id:
query_filter["app_id"] = {"$eq": app_id}
query_vector = self.embedder.embedding_fn([input_query])[0]
query_filter = self._generate_filter(where)
data = self.pinecone_index.query(
vector=query_vector,
filter=query_filter,
@@ -167,16 +166,12 @@ class PineconeDB(BaseVectorDB):
include_metadata=True,
**kwargs,
)
contexts = []
for doc in data.get("matches", []):
metadata = doc.get("metadata", {})
context = metadata.get("text")
if citations:
metadata["score"] = doc.get("score")
contexts.append(tuple((context, metadata)))
else:
contexts.append(context)
return contexts
return [
(metadata.get("text"), {**metadata, "score": doc.get("score")}) if citations else metadata.get("text")
for doc in data.get("matches", [])
for metadata in [doc.get("metadata", {})]
]
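For illustration, this is roughly the filter that reaches `pinecone_index.query` when a raw filter and an `app_id` are both in play (a sketch of the two lines that build `query_filter` above; the app id value is a placeholder):
```python
# Sketch: raw_filter is used as-is; app scoping is merged in afterwards
raw_filter = {"$and": [{"person": "gates"}, {"year": {"$gt": 2023}}]}
app_id = "my-app-id"  # placeholder

query_filter = raw_filter if raw_filter is not None else {}  # else branch: self._generate_filter(where)
if app_id:
    query_filter["app_id"] = {"$eq": app_id}

print(query_filter)
# {'$and': [{'person': 'gates'}, {'year': {'$gt': 2023}}], 'app_id': {'$eq': 'my-app-id'}}
```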
def set_collection_name(self, name: str):
"""

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "embedchain"
version = "0.1.73"
version = "0.1.74"
description = "Simplest open source retrieval(RAG) framework"
authors = [
"Taranjeet Singh <taranjeet@embedchain.ai>",

View File

@@ -7,7 +7,7 @@ from embedchain.vectordb.pinecone import PineconeDB
@pytest.fixture
def pinecone_pod_config():
return PineconeDBConfig(
collection_name="test_collection",
index_name="test_collection",
api_key="test_api_key",
vector_dimension=3,
pod_config={"environment": "test_environment", "metadata_config": {"indexed": ["*"]}},
@@ -17,7 +17,7 @@ def pinecone_pod_config():
@pytest.fixture
def pinecone_serverless_config():
return PineconeDBConfig(
collection_name="test_collection",
index_name="test_collection",
api_key="test_api_key",
vector_dimension=3,
serverless_config={
@@ -39,7 +39,7 @@ def test_pinecone_init_without_config(monkeypatch):
monkeypatch.delenv("PINECONE_API_KEY")
def test_pinecone_init_with_config(pinecone_pod_config, pinecone_serverless_config, monkeypatch):
def test_pinecone_init_with_config(pinecone_pod_config, monkeypatch):
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x)
monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x)
pinecone_db = PineconeDB(config=pinecone_pod_config)
@@ -158,7 +158,7 @@ def test_setup_pinecone_index(pinecone_pod_config, pinecone_serverless_config, m
pinecone_db._setup_pinecone_index()
assert pinecone_db.client is not None
assert pinecone_db.config.index_name == "test-collection-3"
assert pinecone_db.config.index_name == "test_collection"
assert pinecone_db.client.list_indexes().names() == ["test_collection"]
assert pinecone_db.pinecone_index is not None
@@ -166,7 +166,7 @@ def test_setup_pinecone_index(pinecone_pod_config, pinecone_serverless_config, m
pinecone_db._setup_pinecone_index()
assert pinecone_db.client is not None
assert pinecone_db.config.index_name == "test-collection-3"
assert pinecone_db.config.index_name == "test_collection"
assert pinecone_db.client.list_indexes().names() == ["test_collection"]
assert pinecone_db.pinecone_index is not None