From f50f8a444a888585402bd01cd5fd39fa3ab52712 Mon Sep 17 00:00:00 2001 From: Deven Patel Date: Tue, 23 Jan 2024 14:22:58 +0530 Subject: [PATCH] [Bugfix] fix opensearch db (#1184) Co-authored-by: Deven Patel --- embedchain/vectordb/opensearch.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/embedchain/vectordb/opensearch.py b/embedchain/vectordb/opensearch.py index fe798c0e..e9a4baf1 100644 --- a/embedchain/vectordb/opensearch.py +++ b/embedchain/vectordb/opensearch.py @@ -96,9 +96,9 @@ class OpenSearchDB(BaseVectorDB): else: query["query"] = {"bool": {"must": []}} - if "app_id" in where: - app_id = where["app_id"] - query["query"]["bool"]["must"].append({"term": {"metadata.app_id.keyword": app_id}}) + if where: + for key, value in where.items(): + query["bool"]["must"].append({"term": {f"metadata.{key}.keyword": value}}) # OpenSearch syntax is different from Elasticsearch response = self.client.search(index=self._get_index(), body=query, _source=True, size=limit) @@ -176,9 +176,11 @@ class OpenSearchDB(BaseVectorDB): ) pre_filter = {"match_all": {}} # default - if "app_id" in where: - app_id = where["app_id"] - pre_filter = {"bool": {"must": [{"term": {"metadata.app_id.keyword": app_id}}]}} + if len(where) > 0: + pre_filter = {"bool": {"must": []}} + for key, value in where.items(): + pre_filter["bool"]["must"].append({"term": {f"metadata.{key}.keyword": value}}) + docs = docsearch.similarity_search_with_score( input_query, search_type="script_scoring", @@ -236,10 +238,9 @@ class OpenSearchDB(BaseVectorDB): def delete(self, where): """Deletes a document from the OpenSearch index""" - if "doc_id" not in where: - raise ValueError("doc_id is required to delete a document") - - query = {"query": {"bool": {"must": [{"term": {"metadata.doc_id": where["doc_id"]}}]}}} + query = {"query": {"bool": {"must": []}}} + for key, value in where.items(): + query["query"]["bool"]["must"].append({"term": {f"metadata.{key}.keyword": value}}) self.client.delete_by_query(index=self._get_index(), body=query) def _get_index(self) -> str: