From 4afef04f261f725333cd0cc0c2e4aa178c8081be Mon Sep 17 00:00:00 2001 From: Deshraj Yadav Date: Tue, 6 Feb 2024 15:42:51 -0800 Subject: [PATCH] [Feature] Add support for metadata filtering on search API (#1245) --- docs/api-reference/app/search.mdx | 100 +++++++++++++++++++------ docs/components/vector-databases.mdx | 4 +- embedchain/app.py | 38 ++-------- embedchain/config/vectordb/pinecone.py | 8 +- embedchain/embedchain.py | 35 +++++++++ embedchain/loaders/json.py | 5 +- embedchain/vectordb/chroma.py | 20 ++++- embedchain/vectordb/pinecone.py | 55 +++++++------- pyproject.toml | 2 +- tests/vectordb/test_pinecone.py | 10 +-- 10 files changed, 173 insertions(+), 104 deletions(-) diff --git a/docs/api-reference/app/search.mdx b/docs/api-reference/app/search.mdx index dd4b0f4e..db4eee1b 100644 --- a/docs/api-reference/app/search.mdx +++ b/docs/api-reference/app/search.mdx @@ -12,6 +12,13 @@ title: '🔍 search' Number of relevant documents to fetch. Defaults to `3` + + Key value pair for metadata filtering. + + + Pass raw filter query based on your vector database. + Currently, `raw_filter` param is only supported for Pinecone vector database. 
+ ### Returns @@ -21,37 +28,84 @@ title: '🔍 search' ## Usage +### Basic + Refer to the following example on how to use the search api: ```python Code example from embedchain import App -# Initialize app app = App() - -# Add data source app.add("https://www.forbes.com/profile/elon-musk") -# Get relevant context using semantic search context = app.search("What is the net worth of Elon?", num_documents=2) print(context) -# Context: -# [ -# { -# 'context': 'Elon Musk PROFILEElon MuskCEO, Tesla$221.9BReal Time Net Worth ...', -# 'metadata': { -# 'source': 'https://www.forbes.com/profile/elon-musk', -# 'document_id': 'some_document_id', -# 'score': 0.404, -# } -# }, -# { -# 'context': 'company, which is now called X.Wealth HistoryHOVER TO REVEAL NET WORTH ...', -# 'metadata': { -# 'source': 'https://www.forbes.com/profile/elon-musk', -# 'document_id': 'some_document_id', -# 'score': 0.435, -# } -# } -# ] +``` + +### Advanced + +#### Metadata filtering using `where` params + +Here is an advanced example of `search()` API with metadata filtering on pinecone database: + +```python +import os + +from embedchain import App + +os.environ["PINECONE_API_KEY"] = "xxx" + +config = { + "vectordb": { + "provider": "pinecone", + "config": { + "metric": "dotproduct", + "vector_dimension": 1536, + "index_name": "ec-test", + "serverless_config": {"cloud": "aws", "region": "us-west-2"}, + }, + } +} + +app = App.from_config(config=config) + +app.add("https://www.forbes.com/profile/bill-gates", metadata={"type": "forbes", "person": "gates"}) +app.add("https://en.wikipedia.org/wiki/Bill_Gates", metadata={"type": "wiki", "person": "gates"}) + +results = app.search("What is the net worth of Bill Gates?", where={"person": "gates"}) +print("Num of search results: ", len(results)) +``` + +#### Metadata filtering using `raw_filter` params + +Following is an example of metadata filtering by passing the raw filter query that pinecone vector database follows: + +```python +import os + +from 
embedchain import App + +os.environ["PINECONE_API_KEY"] = "xxx" + +config = { + "vectordb": { + "provider": "pinecone", + "config": { + "metric": "dotproduct", + "vector_dimension": 1536, + "index_name": "ec-test", + "serverless_config": {"cloud": "aws", "region": "us-west-2"}, + }, + } +} + +app = App.from_config(config=config) + +app.add("https://www.forbes.com/profile/bill-gates", metadata={"year": 2022, "person": "gates"}) +app.add("https://en.wikipedia.org/wiki/Bill_Gates", metadata={"year": 2024, "person": "gates"}) + +print("Filter with person: gates and year > 2023") +raw_filter = {"$and": [{"person": "gates"}, {"year": {"$gt": 2023}}]} +results = app.search("What is the net worth of Bill Gates?", raw_filter=raw_filter) +print("Num of search results: ", len(results)) ``` diff --git a/docs/components/vector-databases.mdx b/docs/components/vector-databases.mdx index 80c77ab7..8aa48049 100644 --- a/docs/components/vector-databases.mdx +++ b/docs/components/vector-databases.mdx @@ -186,7 +186,7 @@ vectordb: config: metric: cosine vector_dimension: 1536 - collection_name: my-pinecone-index + index_name: my-pinecone-index pod_config: environment: gcp-starter metadata_config: @@ -201,7 +201,7 @@ vectordb: config: metric: cosine vector_dimension: 1536 - collection_name: my-pinecone-index + index_name: my-pinecone-index serverless_config: cloud: aws region: us-west-2 diff --git a/embedchain/app.py b/embedchain/app.py index 966f5b6e..1a46c777 100644 --- a/embedchain/app.py +++ b/embedchain/app.py @@ -11,14 +11,9 @@ import requests import yaml from tqdm import tqdm -from embedchain.cache import ( - Config, - ExactMatchEvaluation, - SearchDistanceEvaluation, - cache, - gptcache_data_manager, - gptcache_pre_function, -) +from embedchain.cache import (Config, ExactMatchEvaluation, + SearchDistanceEvaluation, cache, + gptcache_data_manager, gptcache_pre_function) from embedchain.client import Client from embedchain.config import AppConfig, CacheConfig, ChunkerConfig from 
embedchain.constants import SQLITE_PATH @@ -26,7 +21,8 @@ from embedchain.embedchain import EmbedChain from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.openai import OpenAIEmbedder from embedchain.evaluation.base import BaseMetric -from embedchain.evaluation.metrics import AnswerRelevance, ContextRelevance, Groundedness +from embedchain.evaluation.metrics import (AnswerRelevance, ContextRelevance, + Groundedness) from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory from embedchain.helpers.json_serializable import register_deserializable from embedchain.llm.base import BaseLlm @@ -254,30 +250,6 @@ class App(EmbedChain): r.raise_for_status() return r.json() - def search(self, query, num_documents=3): - """ - Search for similar documents related to the query in the vector database. - """ - # Send anonymous telemetry - self.telemetry.capture(event_name="search", properties=self._telemetry_props) - - # TODO: Search will call the endpoint rather than fetching the data from the db itself when deploy=True. 
- if self.id is None: - where = {"app_id": self.local_id} - context = self.db.query( - query, - n_results=num_documents, - where=where, - citations=True, - ) - result = [] - for c in context: - result.append({"context": c[0], "metadata": c[1]}) - return result - else: - # Make API call to the backend to get the results - NotImplementedError("Search is not implemented yet for the prod mode.") - def _upload_file_to_presigned_url(self, presigned_url, file_path): try: with open(file_path, "rb") as file: diff --git a/embedchain/config/vectordb/pinecone.py b/embedchain/config/vectordb/pinecone.py index f377fcba..dbf0f6d1 100644 --- a/embedchain/config/vectordb/pinecone.py +++ b/embedchain/config/vectordb/pinecone.py @@ -9,10 +9,8 @@ from embedchain.helpers.json_serializable import register_deserializable class PineconeDBConfig(BaseVectorDbConfig): def __init__( self, - collection_name: Optional[str] = None, - api_key: Optional[str] = None, index_name: Optional[str] = None, - dir: Optional[str] = None, + api_key: Optional[str] = None, vector_dimension: int = 1536, metric: Optional[str] = "cosine", pod_config: Optional[dict[str, any]] = None, @@ -21,9 +19,9 @@ class PineconeDBConfig(BaseVectorDbConfig): ): self.metric = metric self.api_key = api_key + self.index_name = index_name self.vector_dimension = vector_dimension self.extra_params = extra_params - self.index_name = index_name or f"{collection_name}-{vector_dimension}".lower().replace("_", "-") if pod_config is None and serverless_config is None: # If no config is provided, use the default pod spec config pod_environment = os.environ.get("PINECONE_ENV", "gcp-starter") @@ -35,4 +33,4 @@ class PineconeDBConfig(BaseVectorDbConfig): if self.pod_config and self.serverless_config: raise ValueError("Only one of pod_config or serverless_config can be provided.") - super().__init__(collection_name=collection_name, dir=None) + super().__init__(collection_name=self.index_name, dir=None) diff --git a/embedchain/embedchain.py 
def search(self, query, num_documents=3, where=None, raw_filter=None):
    """
    Look up the documents most similar to `query` in the configured vector database.

    Args:
        query (str): Text to search for.
        num_documents (int, optional): How many matches to request. Defaults to 3.
        where (dict[str, any], optional): Key/value metadata filter forwarded to the database.
        raw_filter (dict[str, any], optional): Database-specific filter forwarded untouched.

    Raises:
        ValueError: If `where` and `raw_filter` are both supplied (both truthy).

    Returns:
        list[dict]: One dict per match with the keys 'context' and 'metadata'.
    """
    # Send anonymous telemetry
    self.telemetry.capture(event_name="search", properties=self._telemetry_props)

    if raw_filter and where:
        raise ValueError("You can't use both `raw_filter` and `where` together.")

    query_kwargs = {
        "input_query": query,
        "n_results": num_documents,
        "citations": True,
        "app_id": self.config.id,
    }
    # Exactly one filter keyword is forwarded: `raw_filter` when it is truthy,
    # otherwise `where` (which may itself be None).
    if raw_filter:
        query_kwargs["raw_filter"] = raw_filter
    else:
        query_kwargs["where"] = where

    results = []
    for match in self.db.query(**query_kwargs):
        results.append({"context": match[0], "metadata": match[1]})
    return results
# Regex for http(s) URLs whose path ends in a `.json` file name: an optional
# `www.` prefix, then a dotted-quad address or a host name, an optional port,
# any number of path segments, and a final segment ending in `.json`.
#
# Written as a raw string so the regex escapes (`\.`, `\d`, `\s`) are passed to
# `re` verbatim. In a normal string literal they are invalid escape sequences,
# which Python reports as a DeprecationWarning (SyntaxWarning from 3.12).
# The resulting pattern text is unchanged.
VALID_URL_PATTERN = (
    r"^https?://(?:www\.)?(?:\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|[a-zA-Z0-9.-]+)(?::\d+)?/(?:[^/\s]+/)*[^/\s]+\.json$"
)
def query(
    self,
    input_query: list[str],
    n_results: int,
    where: Optional[dict[str, any]] = None,
    raw_filter: Optional[dict[str, any]] = None,
    citations: bool = False,
    app_id: Optional[str] = None,
    **kwargs: Optional[dict[str, any]],
) -> Union[list[tuple[str, dict]], list[str]]:
    """
    Query contents from the Pinecone index based on vector similarity.

    Args:
        input_query (list[str]): Query text; it is embedded with `self.embedder.embedding_fn`.
        n_results (int): Number of similar documents to fetch from the index.
        where (dict[str, any], optional): Filter criteria, converted with `self._generate_filter`.
        raw_filter (dict[str, any], optional): Filter passed to Pinecone as-is. It is never
            modified in place.
        citations (bool, optional): When True, return `(text, metadata)` tuples in which the
            metadata also carries the match `score`. Defaults to False.
        app_id (str, optional): When set, an `app_id` equality condition is added to the filter.
        **kwargs: Extra keyword arguments forwarded to `pinecone_index.query`.

    Raises:
        ValueError: If both `where` and `raw_filter` are supplied.

    Returns:
        Union[list[tuple[str, dict]], list[str]]: Matched document texts, with metadata when
        `citations` is True.
    """
    # Same contract as the Chroma backend: the two filter styles are mutually exclusive,
    # so `where` is never dropped silently when `raw_filter` is also given.
    if where and raw_filter:
        raise ValueError("Both `where` and `raw_filter` cannot be used together.")

    query_filter = raw_filter if raw_filter is not None else self._generate_filter(where)
    if app_id:
        # Build a new dict instead of assigning into `query_filter`: it may be the caller's
        # own `raw_filter` object, which must not be mutated. `or {}` also covers a filter
        # helper that returns None.
        query_filter = {**(query_filter or {}), "app_id": {"$eq": app_id}}

    query_vector = self.embedder.embedding_fn([input_query])[0]
    data = self.pinecone_index.query(
        vector=query_vector,
        filter=query_filter,
        # NOTE(review): this argument sits in a gap between two diff hunks and was
        # reconstructed from the otherwise unused `n_results` parameter - confirm.
        top_k=n_results,
        include_metadata=True,
        **kwargs,
    )

    contexts = []
    for match in data.get("matches", []):
        metadata = match.get("metadata", {})
        text = metadata.get("text")
        if citations:
            # Copy the metadata so the score is attached without touching the response object.
            contexts.append((text, {**metadata, "score": match.get("score")}))
        else:
            contexts.append(text)
    return contexts
vector_dimension=3, serverless_config={ @@ -39,7 +39,7 @@ def test_pinecone_init_without_config(monkeypatch): monkeypatch.delenv("PINECONE_API_KEY") -def test_pinecone_init_with_config(pinecone_pod_config, pinecone_serverless_config, monkeypatch): +def test_pinecone_init_with_config(pinecone_pod_config, monkeypatch): monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._setup_pinecone_index", lambda x: x) monkeypatch.setattr("embedchain.vectordb.pinecone.PineconeDB._get_or_create_db", lambda x: x) pinecone_db = PineconeDB(config=pinecone_pod_config) @@ -158,7 +158,7 @@ def test_setup_pinecone_index(pinecone_pod_config, pinecone_serverless_config, m pinecone_db._setup_pinecone_index() assert pinecone_db.client is not None - assert pinecone_db.config.index_name == "test-collection-3" + assert pinecone_db.config.index_name == "test_collection" assert pinecone_db.client.list_indexes().names() == ["test_collection"] assert pinecone_db.pinecone_index is not None @@ -166,7 +166,7 @@ def test_setup_pinecone_index(pinecone_pod_config, pinecone_serverless_config, m pinecone_db._setup_pinecone_index() assert pinecone_db.client is not None - assert pinecone_db.config.index_name == "test-collection-3" + assert pinecone_db.config.index_name == "test_collection" assert pinecone_db.client.list_indexes().names() == ["test_collection"] assert pinecone_db.pinecone_index is not None