From d77e8da3f3f7bf3490274afd3c3999e60395926d Mon Sep 17 00:00:00 2001 From: Deven Patel Date: Wed, 25 Oct 2023 22:20:32 -0700 Subject: [PATCH] [Feature] Update `db.query` to return source of context (#831) --- .gitignore | 1 + embedchain/embedchain.py | 8 ++++-- embedchain/factory.py | 12 ++++---- embedchain/vectordb/chroma.py | 20 +++++++++---- embedchain/vectordb/elasticsearch.py | 24 +++++++++++----- embedchain/vectordb/opensearch.py | 18 ++++++++---- embedchain/vectordb/pinecone.py | 22 +++++++++----- embedchain/vectordb/qdrant.py | 29 +++++++++++++++---- embedchain/vectordb/weaviate.py | 22 +++++++------- embedchain/vectordb/zilliz.py | 24 ++++++++++------ tests/vectordb/test_chroma_db.py | 38 +++++++++++++++++++++++-- tests/vectordb/test_elasticsearch_db.py | 38 +++++++++++++++++++------ tests/vectordb/test_zilliz_db.py | 12 ++++---- 13 files changed, 195 insertions(+), 73 deletions(-) diff --git a/.gitignore b/.gitignore index bde609f0..b6df6ac0 100644 --- a/.gitignore +++ b/.gitignore @@ -174,4 +174,5 @@ test-db notebooks/*.yaml .ipynb_checkpoints/ + !configs/*.yaml diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index 35c53139..106366df 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -500,13 +500,17 @@ class EmbedChain(JSONSerializable): db_query = ClipProcessor.get_text_features(query=input_query) - contents = self.db.query( + contexts = self.db.query( input_query=db_query, n_results=query_config.number_documents, where=where, skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"), ) - return contents + + if len(contexts) > 0 and isinstance(contexts[0], tuple): + contexts = list(map(lambda x: x[0], contexts)) + + return contexts def query(self, input_query: str, config: BaseLlmConfig = None, dry_run=False, where: Optional[Dict] = None) -> str: """ diff --git a/embedchain/factory.py b/embedchain/factory.py index 97453144..66268855 100644 --- a/embedchain/factory.py +++ b/embedchain/factory.py @@ -41,15 +41,15 @@ class LlmFactory: class EmbedderFactory: provider_to_class = { + "azure_openai": "embedchain.embedder.openai.OpenAIEmbedder", "gpt4all": "embedchain.embedder.gpt4all.GPT4AllEmbedder", "huggingface": "embedchain.embedder.huggingface.HuggingFaceEmbedder", - "vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder", - "azure_openai": "embedchain.embedder.openai.OpenAIEmbedder", "openai": "embedchain.embedder.openai.OpenAIEmbedder", + "vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder", } provider_to_config_class = { - "openai": "embedchain.config.embedder.base.BaseEmbedderConfig", "azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig", + "openai": "embedchain.config.embedder.base.BaseEmbedderConfig", } @classmethod @@ -72,16 +72,18 @@ class VectorDBFactory: "elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB", "opensearch": "embedchain.vectordb.opensearch.OpenSearchDB", "pinecone": "embedchain.vectordb.pinecone.PineconeDB", - "weaviate": "embedchain.vectordb.weaviate.WeaviateDB", "qdrant": "embedchain.vectordb.qdrant.QdrantDB", + "weaviate": "embedchain.vectordb.weaviate.WeaviateDB", + "zilliz": "embedchain.vectordb.zilliz.ZillizVectorDB", } provider_to_config_class = { "chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig", "elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig", "opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig", "pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig", - "weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig", "qdrant": "embedchain.config.vectordb.qdrant.QdrantDBConfig", + "weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig", + "zilliz": "embedchain.config.vectordb.zilliz.ZillizDBConfig", } @classmethod diff --git a/embedchain/vectordb/chroma.py b/embedchain/vectordb/chroma.py index b5e65c88..c8d2194d 100644 --- a/embedchain/vectordb/chroma.py +++ b/embedchain/vectordb/chroma.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from chromadb import Collection, QueryResult from langchain.docstore.document import Document @@ -191,7 +191,9 @@ class ChromaDB(BaseVectorDB): ) ] - def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]: + def query( + self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool + ) -> List[Tuple[str, str, str]]: """ Query contents from vector database based on vector similarity @@ -204,8 +206,8 @@ class ChromaDB(BaseVectorDB): :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded. :type skip_embedding: bool :raises InvalidDimensionException: Dimensions do not match. - :return: The content of the document that matched your query. - :rtype: List[str] + :return: The content of the document that matched your query, url of the source, doc_id + :rtype: List[Tuple[str,str,str]] """ try: if skip_embedding: @@ -231,8 +233,14 @@ class ChromaDB(BaseVectorDB): " embeddings, is used to retrieve an embedding from the database." ) from None results_formatted = self._format_result(result) - contents = [result[0].page_content for result in results_formatted] - return contents + contexts = [] + for result in results_formatted: + context = result[0].page_content + metadata = result[0].metadata + source = metadata["url"] + doc_id = metadata["doc_id"] + contexts.append((context, source, doc_id)) + return contexts def set_collection_name(self, name: str): """ diff --git a/embedchain/vectordb/elasticsearch.py b/embedchain/vectordb/elasticsearch.py index f3b2d293..aeb627d5 100644 --- a/embedchain/vectordb/elasticsearch.py +++ b/embedchain/vectordb/elasticsearch.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple try: from elasticsearch import Elasticsearch @@ -135,7 +135,9 @@ class ElasticsearchDB(BaseVectorDB): bulk(self.client, docs) self.client.indices.refresh(index=self._get_index()) - def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]: + def query( + self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool + ) -> List[Tuple[str, str, str]]: """ query contents from vector data base based on vector similarity @@ -147,8 +149,9 @@ class ElasticsearchDB(BaseVectorDB): :type where: Dict[str, any] :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded. :type skip_embedding: bool - :return: Database contents that are the result of the query - :rtype: List[str] + :return: The context of the document that matched your query, url of the source, doc_id + + :rtype: List[Tuple[str,str,str]] """ if skip_embedding: query_vector = input_query @@ -156,6 +159,7 @@ class ElasticsearchDB(BaseVectorDB): input_query_vector = self.embedder.embedding_fn(input_query) query_vector = input_query_vector[0] + # `https://www.elastic.co/guide/en/elasticsearch/reference/7.17/query-dsl-script-score-query.html` query = { "script_score": { "query": {"bool": {"must": [{"exists": {"field": "text"}}]}}, @@ -167,11 +171,17 @@ class ElasticsearchDB(BaseVectorDB): } if "app_id" in where: app_id = where["app_id"] - query["script_score"]["query"]["bool"]["must"] = [{"term": {"metadata.app_id": app_id}}] - _source = ["text"] + query["script_score"]["query"] = {"match": {"metadata.app_id": app_id}} + _source = ["text", "metadata.url", "metadata.doc_id"] response = self.client.search(index=self._get_index(), query=query, _source=_source, size=n_results) docs = response["hits"]["hits"] - contents = [doc["_source"]["text"] for doc in docs] + contents = [] + for doc in docs: + context = doc["_source"]["text"] + metadata = doc["_source"]["metadata"] + source = metadata["url"] + doc_id = metadata["doc_id"] + contents.append(tuple((context, source, doc_id))) return contents def set_collection_name(self, name: str): diff --git a/embedchain/vectordb/opensearch.py b/embedchain/vectordb/opensearch.py index 3857300b..1ba29e82 100644 --- a/embedchain/vectordb/opensearch.py +++ b/embedchain/vectordb/opensearch.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional, Set, Tuple try: from opensearchpy import OpenSearch @@ -145,7 +145,9 @@ class OpenSearchDB(BaseVectorDB): bulk(self.client, docs) self.client.indices.refresh(index=self._get_index()) - def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]: + def query( + self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool + ) -> List[Tuple[str, str, str]]: """ query contents from vector data base based on vector similarity @@ -157,8 +159,8 @@ class OpenSearchDB(BaseVectorDB): :type where: Dict[str, any] :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded. :type skip_embedding: bool - :return: Database contents that are the result of the query - :rtype: List[str] + :return: The content of the document that matched your query, url of the source, doc_id + :rtype: List[Tuple[str,str,str]] """ # TODO(rupeshbansal, deshraj): Add support for skip embeddings here if already exists embeddings = OpenAIEmbeddings() @@ -185,7 +187,13 @@ class OpenSearchDB(BaseVectorDB): pre_filter=pre_filter, k=n_results, ) - contents = [doc.page_content for doc in docs] + + contents = [] + for doc in docs: + context = doc.page_content + source = doc.metadata["url"] + doc_id = doc.metadata["doc_id"] + contents.append(tuple((context, source, doc_id))) return contents def set_collection_name(self, name: str): diff --git a/embedchain/vectordb/pinecone.py b/embedchain/vectordb/pinecone.py index df4a82f8..3309eb24 100644 --- a/embedchain/vectordb/pinecone.py +++ b/embedchain/vectordb/pinecone.py @@ -1,5 +1,5 @@ import os -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple try: import pinecone @@ -118,7 +118,9 @@ class PineconeDB(BaseVectorDB): for i in range(0, len(docs), self.BATCH_SIZE): self.client.upsert(docs[i : i + self.BATCH_SIZE]) - def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]: + def query( + self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool + ) -> List[Tuple[str, str, str]]: """ query contents from vector database based on vector similarity :param input_query: list of query string @@ -129,16 +131,22 @@ class PineconeDB(BaseVectorDB): :type where: Dict[str, any] :param skip_embedding: Optional. if True, input_query is already embedded :type skip_embedding: bool - :return: Database contents that are the result of the query - :rtype: List[str] + :return: The content of the document that matched your query, url of the source, doc_id + :rtype: List[Tuple[str,str,str]] """ if not skip_embedding: query_vector = self.embedder.embedding_fn([input_query])[0] else: query_vector = input_query - contents = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True) - embeddings = list(map(lambda content: content["metadata"]["text"], contents["matches"])) - return embeddings + data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True) + contents = [] + for doc in data["matches"]: + metadata = doc["metadata"] + context = metadata["text"] + source = metadata["url"] + doc_id = metadata["doc_id"] + contents.append(tuple((context, source, doc_id))) + return contents def set_collection_name(self, name: str): """ diff --git a/embedchain/vectordb/qdrant.py b/embedchain/vectordb/qdrant.py index 477fa58c..88617d5b 100644 --- a/embedchain/vectordb/qdrant.py +++ b/embedchain/vectordb/qdrant.py @@ -1,7 +1,7 @@ import copy import os import uuid -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple try: from qdrant_client import QdrantClient @@ -160,7 +160,9 @@ class QdrantDB(BaseVectorDB): ), ) - def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]: + def query( + self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool + ) -> List[Tuple[str, str, str]]: """ query contents from vector database based on vector similarity :param input_query: list of query string @@ -172,8 +174,8 @@ class QdrantDB(BaseVectorDB): :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be generated or not :type skip_embedding: bool - :return: Database contents that are the result of the query - :rtype: List[str] + :return: The context of the document that matched your query, url of the source, doc_id + :rtype: List[Tuple[str,str,str]] """ if not skip_embedding: query_vector = self.embedder.embedding_fn([input_query])[0] @@ -199,9 +201,14 @@ class QdrantDB(BaseVectorDB): query_vector=query_vector, limit=n_results, ) + response = [] for result in results: - response.append(result.payload.get("text", "")) + context = result.payload["text"] + metadata = result.payload["metadata"] + source = metadata["url"] + doc_id = metadata["doc_id"] + response.append(tuple((context, source, doc_id))) return response def count(self) -> int: @@ -211,3 +218,15 @@ class QdrantDB(BaseVectorDB): def reset(self): self.client.delete_collection(collection_name=self.collection_name) self._initialize() + + def set_collection_name(self, name: str): + """ + Set the name of the collection. A collection is an isolated space for vectors. + + :param name: Name of the collection. + :type name: str + """ + if not isinstance(name, str): + raise TypeError("Collection name must be a string") + self.config.collection_name = name + self.collection_name = self._get_or_create_collection() diff --git a/embedchain/vectordb/weaviate.py b/embedchain/vectordb/weaviate.py index 6416cb35..74c2bc90 100644 --- a/embedchain/vectordb/weaviate.py +++ b/embedchain/vectordb/weaviate.py @@ -1,6 +1,6 @@ import copy import os -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple try: import weaviate @@ -194,7 +194,9 @@ class WeaviateDB(BaseVectorDB): ) batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata") - def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]: + def query( + self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool + ) -> List[Tuple[str, str, str]]: """ query contents from vector database based on vector similarity :param input_query: list of query string @@ -206,14 +208,15 @@ class WeaviateDB(BaseVectorDB): :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be generated or not :type skip_embedding: bool - :return: Database contents that are the result of the query - :rtype: List[str] + :return: The context of the document that matched your query, url of the source, doc_id + :rtype: List[Tuple[str,str,str]] """ if not skip_embedding: query_vector = self.embedder.embedding_fn([input_query])[0] else: query_vector = input_query keys = set(where.keys() if where is not None else set()) + data_fields = ["text"] if len(keys.intersection(self.metadata_keys)) != 0: weaviate_where_operands = [] for key in keys: @@ -231,7 +234,7 @@ class WeaviateDB(BaseVectorDB): weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands} results = ( - self.client.query.get(self.index_name, ["text"]) + self.client.query.get(self.index_name, data_fields) .with_where(weaviate_where_clause) .with_near_vector({"vector": query_vector}) .with_limit(n_results) @@ -239,16 +242,13 @@ class WeaviateDB(BaseVectorDB): ) else: results = ( - self.client.query.get(self.index_name, ["text"]) + self.client.query.get(self.index_name, data_fields) .with_near_vector({"vector": query_vector}) .with_limit(n_results) .do() ) - matched_tokens = [] - for result in results["data"]["Get"].get(self.index_name): - matched_tokens.append(result["text"]) - - return matched_tokens + contexts = results["data"]["Get"].get(self.index_name) + return contexts def set_collection_name(self, name: str): """ diff --git a/embedchain/vectordb/zilliz.py b/embedchain/vectordb/zilliz.py index eb99ff2a..1037c420 100644 --- a/embedchain/vectordb/zilliz.py +++ b/embedchain/vectordb/zilliz.py @@ -1,4 +1,5 @@ -from typing import Dict, List, Optional +import logging +from typing import Dict, List, Optional, Tuple from embedchain.config import ZillizDBConfig from embedchain.helper.json_serializable import register_deserializable @@ -61,6 +62,7 @@ class ZillizVectorDB(BaseVectorDB): :type name: str """ if utility.has_collection(name): + logging.info(f"[ZillizDB]: found an existing collection {name}, make sure the auto-id is disabled.") self.collection = Collection(name) else: fields = [ @@ -124,7 +126,9 @@ class ZillizVectorDB(BaseVectorDB): self.collection.flush() self.client.flush(self.config.collection_name) - def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]: + def query( + self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool + ) -> List[Tuple[str, str, str]]: """ Query contents from vector data base based on vector similarity @@ -135,8 +139,8 @@ class ZillizVectorDB(BaseVectorDB): :param where: to filter data :type where: str :raises InvalidDimensionException: Dimensions do not match. - :return: The content of the document that matched your query. - :rtype: List[str] + :return: The context of the document that matched your query, url of the source, doc_id + :rtype: List[Tuple[str,str,str]] """ if self.collection.is_empty: @@ -145,13 +149,14 @@ class ZillizVectorDB(BaseVectorDB): if not isinstance(where, str): where = None + output_fields = ["text", "url", "doc_id"] if skip_embedding: query_vector = input_query query_result = self.client.search( collection_name=self.config.collection_name, data=query_vector, limit=n_results, - output_fields=["text"], + output_fields=output_fields, ) else: @@ -162,13 +167,16 @@ class ZillizVectorDB(BaseVectorDB): collection_name=self.config.collection_name, data=[query_vector], limit=n_results, - output_fields=["text"], + output_fields=output_fields, ) doc_list = [] for query in query_result: - doc_list.append(query[0]["entity"]["text"]) - + data = query[0]["entity"] + context = data["text"] + source = data["url"] + doc_id = data["doc_id"] + doc_list.append(tuple((context, source, doc_id))) return doc_list def count(self) -> int: diff --git a/tests/vectordb/test_chroma_db.py b/tests/vectordb/test_chroma_db.py index b6861410..a30a972c 100644 --- a/tests/vectordb/test_chroma_db.py +++ b/tests/vectordb/test_chroma_db.py @@ -146,7 +146,7 @@ def test_chroma_db_collection_add_with_skip_embedding(app_with_settings): app_with_settings.db.add( embeddings=[[0, 0, 0]], documents=["document"], - metadatas=[{"value": "somevalue"}], + metadatas=[{"url": "url_1", "doc_id": "doc_id_1"}], ids=["id"], skip_embedding=True, ) @@ -158,13 +158,13 @@ def test_chroma_db_collection_add_with_skip_embedding(app_with_settings): "documents": ["document"], "embeddings": None, "ids": ["id"], - "metadatas": [{"value": "somevalue"}], + "metadatas": [{"url": "url_1", "doc_id": "doc_id_1"}], } assert data == expected_value data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True) - expected_value = ["document"] + expected_value = [("document", "url_1", "doc_id_1")] assert data == expected_value app_with_settings.db.reset() @@ -299,3 +299,35 @@ def test_chroma_db_collection_reset(): app2.db.reset() app3.db.reset() app4.db.reset() + + +def test_chroma_db_collection_query(app_with_settings): + app_with_settings.db.reset() + + assert app_with_settings.db.count() == 0 + + app_with_settings.db.add( + embeddings=[[0, 0, 0]], + documents=["document"], + metadatas=[{"url": "url_1", "doc_id": "doc_id_1"}], + ids=["id"], + skip_embedding=True, + ) + + assert app_with_settings.db.count() == 1 + + app_with_settings.db.add( + embeddings=[[0, 1, 0]], + documents=["document2"], + metadatas=[{"url": "url_2", "doc_id": "doc_id_2"}], + ids=["id2"], + skip_embedding=True, + ) + + assert app_with_settings.db.count() == 2 + + data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True) + expected_value = [("document", "url_1", "doc_id_1"), ("document2", "url_2", "doc_id_2")] + + assert data == expected_value + app_with_settings.db.reset() diff --git a/tests/vectordb/test_elasticsearch_db.py b/tests/vectordb/test_elasticsearch_db.py index ed7036c9..3fdecea9 100644 --- a/tests/vectordb/test_elasticsearch_db.py +++ b/tests/vectordb/test_elasticsearch_db.py @@ -31,7 +31,7 @@ class TestEsDB(unittest.TestCase): # Create some dummy data. embeddings = [[1, 2, 3], [4, 5, 6]] documents = ["This is a document.", "This is another document."] - metadatas = [{}, {}] + metadatas = [{"url": "url_1", "doc_id": "doc_id_1"}, {"url": "url_2", "doc_id": "doc_id_2"}] ids = ["doc_1", "doc_2"] # Add the data to the database. @@ -40,8 +40,17 @@ class TestEsDB(unittest.TestCase): search_response = { "hits": { "hits": [ - {"_source": {"text": "This is a document."}, "_score": 0.9}, - {"_source": {"text": "This is another document."}, "_score": 0.8}, + { + "_source": {"text": "This is a document.", "metadata": {"url": "url_1", "doc_id": "doc_id_1"}}, + "_score": 0.9, + }, + { + "_source": { + "text": "This is another document.", + "metadata": {"url": "url_2", "doc_id": "doc_id_2"}, + }, + "_score": 0.8, + }, ] } } @@ -54,7 +63,9 @@ class TestEsDB(unittest.TestCase): results = self.db.query(query, n_results=2, where={}, skip_embedding=False) # Assert that the results are correct. - self.assertEqual(results, ["This is a document.", "This is another document."]) + self.assertEqual( + results, [("This is a document.", "url_1", "doc_id_1"), ("This is another document.", "url_2", "doc_id_2")] + ) @patch("embedchain.vectordb.elasticsearch.Elasticsearch") def test_query_with_skip_embedding(self, mock_client): @@ -68,7 +79,7 @@ class TestEsDB(unittest.TestCase): # Create some dummy data. embeddings = [[1, 2, 3], [4, 5, 6]] documents = ["This is a document.", "This is another document."] - metadatas = [{}, {}] + metadatas = [{"url": "url_1", "doc_id": "doc_id_1"}, {"url": "url_2", "doc_id": "doc_id_2"}] ids = ["doc_1", "doc_2"] # Add the data to the database. @@ -77,8 +88,17 @@ class TestEsDB(unittest.TestCase): search_response = { "hits": { "hits": [ - {"_source": {"text": "This is a document."}, "_score": 0.9}, - {"_source": {"text": "This is another document."}, "_score": 0.8}, + { + "_source": {"text": "This is a document.", "metadata": {"url": "url_1", "doc_id": "doc_id_1"}}, + "_score": 0.9, + }, + { + "_source": { + "text": "This is another document.", + "metadata": {"url": "url_2", "doc_id": "doc_id_2"}, + }, + "_score": 0.8, + }, ] } } @@ -91,7 +111,9 @@ class TestEsDB(unittest.TestCase): results = self.db.query(query, n_results=2, where={}, skip_embedding=True) # Assert that the results are correct. - self.assertEqual(results, ["This is a document.", "This is another document."]) + self.assertEqual( + results, [("This is a document.", "url_1", "doc_id_1"), ("This is another document.", "url_2", "doc_id_2")] + ) def test_init_without_url(self): # Make sure it's not loaded from env diff --git a/tests/vectordb/test_zilliz_db.py b/tests/vectordb/test_zilliz_db.py index 6dca78f7..d6f8d3d5 100644 --- a/tests/vectordb/test_zilliz_db.py +++ b/tests/vectordb/test_zilliz_db.py @@ -123,7 +123,7 @@ class TestZillizDBCollection: # Mock the MilvusClient search method with patch.object(zilliz_db.client, "search") as mock_search: # Mock the search result - mock_search.return_value = [[{"entity": {"text": "result_doc"}}]] + mock_search.return_value = [[{"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"}}]] # Call the query method with skip_embedding=True query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True) @@ -133,11 +133,11 @@ class TestZillizDBCollection: collection_name=mock_config.collection_name, data=["query_text"], limit=1, - output_fields=["text"], + output_fields=["text", "url", "doc_id"], ) # Assert that the query result matches the expected result - assert query_result == ["result_doc"] + assert query_result == [("result_doc", "url_1", "doc_id_1")] @patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True) @patch("embedchain.vectordb.zilliz.connections", autospec=True) @@ -162,7 +162,7 @@ class TestZillizDBCollection: mock_embedder.embedding_fn.return_value = ["query_vector"] # Mock the search result - mock_search.return_value = [[{"entity": {"text": "result_doc"}}]] + mock_search.return_value = [[{"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"}}]] # Call the query method with skip_embedding=False query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False) @@ -172,8 +172,8 @@ class TestZillizDBCollection: collection_name=mock_config.collection_name, data=["query_vector"], limit=1, - output_fields=["text"], + output_fields=["text", "url", "doc_id"], ) # Assert that the query result matches the expected result - assert query_result == ["result_doc"] + assert query_result == [("result_doc", "url_1", "doc_id_1")]