From 22e14b5e65f1dee8dd0c4fc3ca946d0e2d06758d Mon Sep 17 00:00:00 2001 From: Deven Patel Date: Tue, 23 Jan 2024 14:23:57 +0530 Subject: [PATCH] [Bugfix] update zilliz db (#1186) Co-authored-by: Deven Patel --- embedchain/vectordb/zilliz.py | 63 +++++++++++++++++++------------- tests/vectordb/test_zilliz_db.py | 12 ++++-- 2 files changed, 46 insertions(+), 29 deletions(-) diff --git a/embedchain/vectordb/zilliz.py b/embedchain/vectordb/zilliz.py index e957cd4d..cb932fa4 100644 --- a/embedchain/vectordb/zilliz.py +++ b/embedchain/vectordb/zilliz.py @@ -69,6 +69,7 @@ class ZillizVectorDB(BaseVectorDB): FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=512), FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=2048), FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=self.embedder.vector_dimension), + FieldSchema(name="metadata", dtype=DataType.JSON), ] schema = CollectionSchema(fields, enable_dynamic_field=True) @@ -94,17 +95,26 @@ class ZillizVectorDB(BaseVectorDB): :return: Existing documents. :rtype: Set[str] """ - if ids is None or len(ids) == 0 or self.collection.num_entities == 0: - return {"ids": []} + data_ids = [] + metadatas = [] + if self.collection.num_entities == 0 or self.collection.is_empty: + return {"ids": data_ids, "metadatas": metadatas} - if not self.collection.is_empty: - filter_ = f"id in {ids}" - results = self.client.query( - collection_name=self.config.collection_name, filter=filter_, output_fields=["id"] - ) - results = [res["id"] for res in results] + filter_ = "" + if ids: + filter_ = f'id in "{ids}"' - return {"ids": set(results)} + if where: + if filter_: + filter_ += " and " + filter_ = f"{self._generate_zilliz_filter(where)}" + + results = self.client.query(collection_name=self.config.collection_name, filter=filter_, output_fields=["*"]) + for res in results: + data_ids.append(res.get("id")) + metadatas.append(res.get("metadata", {})) + + return {"ids": data_ids, "metadatas": metadatas} def add( self, @@ -117,7 +127,7 @@ class ZillizVectorDB(BaseVectorDB): embeddings = self.embedder.embedding_fn(documents) for id, doc, metadata, embedding in zip(ids, documents, metadatas, embeddings): - data = {**metadata, "id": id, "text": doc, "embeddings": embedding} + data = {"id": id, "text": doc, "embeddings": embedding, "metadata": metadata} self.client.insert(collection_name=self.config.collection_name, data=data, **kwargs) self.collection.load() @@ -128,7 +138,7 @@ class ZillizVectorDB(BaseVectorDB): self, input_query: list[str], n_results: int, - where: dict[str, any], + where: dict[str, Any], citations: bool = False, **kwargs: Optional[dict[str, Any]], ) -> Union[list[tuple[str, dict]], list[str]]: @@ -140,7 +150,7 @@ class ZillizVectorDB(BaseVectorDB): :param n_results: no of similar documents to fetch from database :type n_results: int :param where: to filter data - :type where: str + :type where: dict[str, Any] :raises InvalidDimensionException: Dimensions do not match. :param citations: we use citations boolean param to return context along with the answer. :type citations: bool, default is False. @@ -152,16 +162,15 @@ class ZillizVectorDB(BaseVectorDB): if self.collection.is_empty: return [] - if not isinstance(where, str): - where = None - output_fields = ["*"] input_query_vector = self.embedder.embedding_fn([input_query]) query_vector = input_query_vector[0] + query_filter = self._generate_zilliz_filter(where) query_result = self.client.search( collection_name=self.config.collection_name, data=[query_vector], + filter=query_filter, limit=n_results, output_fields=output_fields, **kwargs, @@ -173,12 +182,10 @@ class ZillizVectorDB(BaseVectorDB): score = query["distance"] context = data["text"] - if "embeddings" in data: - data.pop("embeddings") - if citations: - data["score"] = score - contexts.append(tuple((context, data))) + metadata = data.get("metadata", {}) + metadata["score"] = score + contexts.append(tuple((context, metadata))) else: contexts.append(context) return contexts @@ -216,7 +223,13 @@ class ZillizVectorDB(BaseVectorDB): raise TypeError("Collection name must be a string") self.config.collection_name = name - def delete(self, keys: Union[list, str, int]): + def _generate_zilliz_filter(self, where: dict[str, str]): + operands = [] + for key, value in where.items(): + operands.append(f'(metadata["{key}"] == "{value}")') + return " and ".join(operands) + + def delete(self, where: dict[str, Any]): """ Delete the embeddings from DB. Zilliz only support deleting with keys. @@ -224,7 +237,7 @@ class ZillizVectorDB(BaseVectorDB): :param keys: Primary keys of the table entries to delete. :type keys: Union[list, str, int] """ - self.client.delete( - collection_name=self.config.collection_name, - pks=keys, - ) + data = self.get(where=where) + keys = data.get("ids", []) + if keys: + self.client.delete(collection_name=self.config.collection_name, pks=keys) diff --git a/tests/vectordb/test_zilliz_db.py b/tests/vectordb/test_zilliz_db.py index d4ec4675..e529695f 100644 --- a/tests/vectordb/test_zilliz_db.py +++ b/tests/vectordb/test_zilliz_db.py @@ -130,7 +130,11 @@ class TestZillizDBCollection: [ { "distance": 0.0, - "entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "embeddings": [1, 2, 3]}, + "entity": { + "text": "result_doc", + "embeddings": [1, 2, 3], + "metadata": {"url": "url_1", "doc_id": "doc_id_1"}, + }, } ] ] @@ -141,6 +145,7 @@ class TestZillizDBCollection: mock_search.assert_called_with( collection_name=mock_config.collection_name, data=["query_vector"], + filter="", limit=1, output_fields=["*"], ) @@ -155,10 +160,9 @@ class TestZillizDBCollection: mock_search.assert_called_with( collection_name=mock_config.collection_name, data=["query_vector"], + filter="", limit=1, output_fields=["*"], ) - assert query_result_with_citations == [ - ("result_doc", {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1", "score": 0.0}) - ] + assert query_result_with_citations == [("result_doc", {"url": "url_1", "doc_id": "doc_id_1", "score": 0.0})]