diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index 34933203..b1e2a454 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -4,7 +4,7 @@ import logging import os import sqlite3 from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple, Union from dotenv import load_dotenv from langchain.docstore.document import Document @@ -438,7 +438,9 @@ class EmbedChain(JSONSerializable): ) ] - def retrieve_from_database(self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None) -> List[str]: + def retrieve_from_database( + self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None, citations: bool = False + ) -> Union[List[Tuple[str, str, str]], List[str]]: """ Queries the vector database based on the given input query. Gets relevant doc based on the query @@ -449,6 +451,8 @@ class EmbedChain(JSONSerializable): :type config: Optional[BaseLlmConfig], optional :param where: A dictionary of key-value pairs to filter the database results, defaults to None :type where: _type_, optional + :param citations: A boolean to indicate if db should fetch citation source + :type citations: bool :return: List of contents of the document that matched your query :rtype: List[str] """ @@ -478,14 +482,19 @@ class EmbedChain(JSONSerializable): n_results=query_config.number_documents, where=where, skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"), + citations=citations, ) - if len(contexts) > 0 and isinstance(contexts[0], tuple): - contexts = list(map(lambda x: x[0], contexts)) - return contexts - def query(self, input_query: str, config: BaseLlmConfig = None, dry_run=False, where: Optional[Dict] = None) -> str: + def query( + self, + input_query: str, + config: BaseLlmConfig = None, + dry_run=False, + where: Optional[Dict] = None, + **kwargs: Dict[str, Any], + ) -> Union[Tuple[str, List[Tuple[str, str, str]]], str]: """ Queries the vector database based on the given input query. Gets relevant doc based on the query and then passes it to an @@ -501,15 +510,31 @@ class EmbedChain(JSONSerializable): :type dry_run: bool, optional :param where: A dictionary of key-value pairs to filter the database results., defaults to None :type where: Optional[Dict[str, str]], optional - :return: The answer to the query or the dry run result - :rtype: str + :param kwargs: To read more params for the query function. Ex. we use citations boolean + param to return context along with the answer + :type kwargs: Dict[str, Any] + :return: The answer to the query, with citations if the citation flag is True + or the dry run result + :rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]] """ - contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where) - answer = self.llm.query(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run) + citations = kwargs.get("citations", False) + contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where, citations=citations) + if citations and len(contexts) > 0 and isinstance(contexts[0], tuple): + contexts_data_for_llm_query = list(map(lambda x: x[0], contexts)) + else: + contexts_data_for_llm_query = contexts + + answer = self.llm.query( + input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run + ) # Send anonymous telemetry self.telemetry.capture(event_name="query", properties=self._telemetry_props) - return answer + + if citations: + return answer, contexts + else: + return answer def chat( self, @@ -517,6 +542,7 @@ class EmbedChain(JSONSerializable): config: Optional[BaseLlmConfig] = None, dry_run=False, where: Optional[Dict[str, str]] = None, + **kwargs: Dict[str, Any], ) -> str: """ Queries the vector database on the given input query. @@ -535,15 +561,31 @@ class EmbedChain(JSONSerializable): :type dry_run: bool, optional :param where: A dictionary of key-value pairs to filter the database results., defaults to None :type where: Optional[Dict[str, str]], optional - :return: The answer to the query or the dry run result - :rtype: str + :param kwargs: To read more params for the query function. Ex. we use citations boolean + param to return context along with the answer + :type kwargs: Dict[str, Any] + :return: The answer to the query, with citations if the citation flag is True + or the dry run result + :rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]] """ - contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where) - answer = self.llm.chat(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run) + citations = kwargs.get("citations", False) + contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where, citations=citations) + if citations and len(contexts) > 0 and isinstance(contexts[0], tuple): + contexts_data_for_llm_query = list(map(lambda x: x[0], contexts)) + else: + contexts_data_for_llm_query = contexts + + answer = self.llm.chat( + input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run + ) + # Send anonymous telemetry self.telemetry.capture(event_name="chat", properties=self._telemetry_props) - return answer + if citations: + return answer, contexts + else: + return answer def set_collection_name(self, name: str): """ diff --git a/embedchain/pipeline.py b/embedchain/pipeline.py index 8fad79a5..c0a35c3a 100644 --- a/embedchain/pipeline.py +++ b/embedchain/pipeline.py @@ -234,6 +234,7 @@ class Pipeline(EmbedChain): n_results=num_documents, where=where, skip_embedding=False, + citations=True, ) result = [] for c in context: diff --git a/embedchain/vectordb/chroma.py b/embedchain/vectordb/chroma.py index 90195459..e86651f3 100644 --- a/embedchain/vectordb/chroma.py +++ b/embedchain/vectordb/chroma.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union from chromadb import Collection, QueryResult from langchain.docstore.document import Document @@ -192,8 +192,13 @@ class ChromaDB(BaseVectorDB): ] def query( - self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool - ) -> List[Tuple[str, str, str]]: + self, + input_query: List[str], + n_results: int, + where: Dict[str, any], + skip_embedding: bool, + citations: bool = False, + ) -> Union[List[Tuple[str, str, str]], List[str]]: """ Query contents from vector database based on vector similarity @@ -205,9 +210,12 @@ class ChromaDB(BaseVectorDB): :type where: Dict[str, Any] :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded. :type skip_embedding: bool + :param citations: we use citations boolean param to return context along with the answer. + :type citations: bool, default is False. :raises InvalidDimensionException: Dimensions do not match. - :return: The content of the document that matched your query, url of the source, doc_id - :rtype: List[Tuple[str,str,str]] + :return: The content of the document that matched your query, + along with url of the source and doc_id (if citations flag is true) + :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]] """ try: if skip_embedding: @@ -236,10 +244,13 @@ class ChromaDB(BaseVectorDB): contexts = [] for result in results_formatted: context = result[0].page_content - metadata = result[0].metadata - source = metadata["url"] - doc_id = metadata["doc_id"] - contexts.append((context, source, doc_id)) + if citations: + metadata = result[0].metadata + source = metadata["url"] + doc_id = metadata["doc_id"] + contexts.append((context, source, doc_id)) + else: + contexts.append(context) return contexts def set_collection_name(self, name: str): diff --git a/embedchain/vectordb/elasticsearch.py b/embedchain/vectordb/elasticsearch.py index aeb627d5..b2737080 100644 --- a/embedchain/vectordb/elasticsearch.py +++ b/embedchain/vectordb/elasticsearch.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union try: from elasticsearch import Elasticsearch @@ -136,8 +136,13 @@ class ElasticsearchDB(BaseVectorDB): self.client.indices.refresh(index=self._get_index()) def query( - self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool - ) -> List[Tuple[str, str, str]]: + self, + input_query: List[str], + n_results: int, + where: Dict[str, any], + skip_embedding: bool, + citations: bool = False, + ) -> Union[List[Tuple[str, str, str]], List[str]]: """ query contents from vector data base based on vector similarity @@ -150,8 +155,11 @@ class ElasticsearchDB(BaseVectorDB): :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded. :type skip_embedding: bool :return: The context of the document that matched your query, url of the source, doc_id - - :rtype: List[Tuple[str,str,str]] + :param citations: we use citations boolean param to return context along with the answer. + :type citations: bool, default is False. + :return: The content of the document that matched your query, + along with url of the source and doc_id (if citations flag is true) + :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]] """ if skip_embedding: query_vector = input_query @@ -175,14 +183,17 @@ class ElasticsearchDB(BaseVectorDB): _source = ["text", "metadata.url", "metadata.doc_id"] response = self.client.search(index=self._get_index(), query=query, _source=_source, size=n_results) docs = response["hits"]["hits"] - contents = [] + contexts = [] for doc in docs: context = doc["_source"]["text"] - metadata = doc["_source"]["metadata"] - source = metadata["url"] - doc_id = metadata["doc_id"] - contents.append(tuple((context, source, doc_id))) - return contents + if citations: + metadata = doc["_source"]["metadata"] + source = metadata["url"] + doc_id = metadata["doc_id"] + contexts.append(tuple((context, source, doc_id))) + else: + contexts.append(context) + return contexts def set_collection_name(self, name: str): """ diff --git a/embedchain/vectordb/opensearch.py b/embedchain/vectordb/opensearch.py index 1ba29e82..cb548166 100644 --- a/embedchain/vectordb/opensearch.py +++ b/embedchain/vectordb/opensearch.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, List, Optional, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple, Union try: from opensearchpy import OpenSearch @@ -146,8 +146,13 @@ class OpenSearchDB(BaseVectorDB): self.client.indices.refresh(index=self._get_index()) def query( - self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool - ) -> List[Tuple[str, str, str]]: + self, + input_query: List[str], + n_results: int, + where: Dict[str, any], + skip_embedding: bool, + citations: bool = False, + ) -> Union[List[Tuple[str, str, str]], List[str]]: """ query contents from vector data base based on vector similarity @@ -159,8 +164,11 @@ class OpenSearchDB(BaseVectorDB): :type where: Dict[str, any] :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded. :type skip_embedding: bool - :return: The content of the document that matched your query, url of the source, doc_id - :rtype: List[Tuple[str,str,str]] + :param citations: we use citations boolean param to return context along with the answer. + :type citations: bool, default is False. + :return: The content of the document that matched your query, + along with url of the source and doc_id (if citations flag is true) + :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]] """ # TODO(rupeshbansal, deshraj): Add support for skip embeddings here if already exists embeddings = OpenAIEmbeddings() @@ -188,13 +196,16 @@ class OpenSearchDB(BaseVectorDB): k=n_results, ) - contents = [] + contexts = [] for doc in docs: context = doc.page_content - source = doc.metadata["url"] - doc_id = doc.metadata["doc_id"] - contents.append(tuple((context, source, doc_id))) - return contents + if citations: + source = doc.metadata["url"] + doc_id = doc.metadata["doc_id"] + contexts.append(tuple((context, source, doc_id))) + else: + contexts.append(context) + return contexts def set_collection_name(self, name: str): """ diff --git a/embedchain/vectordb/pinecone.py b/embedchain/vectordb/pinecone.py index 3309eb24..86a817ac 100644 --- a/embedchain/vectordb/pinecone.py +++ b/embedchain/vectordb/pinecone.py @@ -1,5 +1,5 @@ import os -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union try: import pinecone @@ -119,8 +119,13 @@ class PineconeDB(BaseVectorDB): self.client.upsert(docs[i : i + self.BATCH_SIZE]) def query( - self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool - ) -> List[Tuple[str, str, str]]: + self, + input_query: List[str], + n_results: int, + where: Dict[str, any], + skip_embedding: bool, + citations: bool = False, + ) -> Union[List[Tuple[str, str, str]], List[str]]: """ query contents from vector database based on vector similarity :param input_query: list of query string @@ -131,22 +136,28 @@ class PineconeDB(BaseVectorDB): :type where: Dict[str, any] :param skip_embedding: Optional. if True, input_query is already embedded :type skip_embedding: bool - :return: The content of the document that matched your query, url of the source, doc_id - :rtype: List[Tuple[str,str,str]] + :param citations: we use citations boolean param to return context along with the answer. + :type citations: bool, default is False. + :return: The content of the document that matched your query, + along with url of the source and doc_id (if citations flag is true) + :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]] """ if not skip_embedding: query_vector = self.embedder.embedding_fn([input_query])[0] else: query_vector = input_query data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True) - contents = [] + contexts = [] for doc in data["matches"]: metadata = doc["metadata"] context = metadata["text"] - source = metadata["url"] - doc_id = metadata["doc_id"] - contents.append(tuple((context, source, doc_id))) - return contents + if citations: + source = metadata["url"] + doc_id = metadata["doc_id"] + contexts.append(tuple((context, source, doc_id))) + else: + contexts.append(context) + return contexts def set_collection_name(self, name: str): """ diff --git a/embedchain/vectordb/qdrant.py b/embedchain/vectordb/qdrant.py index 88617d5b..3fe3888b 100644 --- a/embedchain/vectordb/qdrant.py +++ b/embedchain/vectordb/qdrant.py @@ -1,7 +1,7 @@ import copy import os import uuid -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union try: from qdrant_client import QdrantClient @@ -161,8 +161,13 @@ class QdrantDB(BaseVectorDB): ) def query( - self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool - ) -> List[Tuple[str, str, str]]: + self, + input_query: List[str], + n_results: int, + where: Dict[str, any], + skip_embedding: bool, + citations: bool = False, + ) -> Union[List[Tuple[str, str, str]], List[str]]: """ query contents from vector database based on vector similarity :param input_query: list of query string @@ -174,8 +179,11 @@ class QdrantDB(BaseVectorDB): :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be generated or not :type skip_embedding: bool - :return: The context of the document that matched your query, url of the source, doc_id - :rtype: List[Tuple[str,str,str]] + :param citations: we use citations boolean param to return context along with the answer. + :type citations: bool, default is False. + :return: The content of the document that matched your query, + along with url of the source and doc_id (if citations flag is true) + :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]] """ if not skip_embedding: query_vector = self.embedder.embedding_fn([input_query])[0] @@ -202,14 +210,17 @@ class QdrantDB(BaseVectorDB): limit=n_results, ) - response = [] + contexts = [] for result in results: context = result.payload["text"] - metadata = result.payload["metadata"] - source = metadata["url"] - doc_id = metadata["doc_id"] - response.append(tuple((context, source, doc_id))) - return response + if citations: + metadata = result.payload["metadata"] + source = metadata["url"] + doc_id = metadata["doc_id"] + contexts.append(tuple((context, source, doc_id))) + else: + contexts.append(context) + return contexts def count(self) -> int: response = self.client.get_collection(collection_name=self.collection_name) diff --git a/embedchain/vectordb/weaviate.py b/embedchain/vectordb/weaviate.py index 74c2bc90..fde91caf 100644 --- a/embedchain/vectordb/weaviate.py +++ b/embedchain/vectordb/weaviate.py @@ -1,6 +1,6 @@ import copy import os -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union try: import weaviate @@ -58,10 +58,14 @@ class WeaviateDB(BaseVectorDB): raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.") self.index_name = self._get_index_name() - self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id", "text"} + self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id"} if not self.client.schema.exists(self.index_name): # id is a reserved field in Weaviate, hence we had to change the name of the id field to identifier # The none vectorizer is crucial as we have our own custom embedding function + """ + TODO: wait for weaviate to add indexing on `object[]` data-type so that we can add filter while querying. + Once that is done, change `dataType` of "metadata" field to `object[]` and update the query below. + """ class_obj = { "classes": [ { @@ -106,10 +110,6 @@ class WeaviateDB(BaseVectorDB): "name": "app_id", "dataType": ["text"], }, - { - "name": "text", - "dataType": ["text"], - }, ], }, ] @@ -195,8 +195,13 @@ class WeaviateDB(BaseVectorDB): batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata") def query( - self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool - ) -> List[Tuple[str, str, str]]: + self, + input_query: List[str], + n_results: int, + where: Dict[str, any], + skip_embedding: bool, + citations: bool = False, + ) -> Union[List[Tuple[str, str, str]], List[str]]: """ query contents from vector database based on vector similarity :param input_query: list of query string @@ -208,15 +213,23 @@ class WeaviateDB(BaseVectorDB): :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be generated or not :type skip_embedding: bool - :return: The context of the document that matched your query, url of the source, doc_id - :rtype: List[Tuple[str,str,str]] + :param citations: we use citations boolean param to return context along with the answer. + :type citations: bool, default is False. + :return: The content of the document that matched your query, + along with url of the source and doc_id (if citations flag is true) + :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]] """ if not skip_embedding: query_vector = self.embedder.embedding_fn([input_query])[0] else: query_vector = input_query + keys = set(where.keys() if where is not None else set()) data_fields = ["text"] + + if citations: + data_fields.append(weaviate.LinkTo("metadata", self.index_name + "_metadata", list(self.metadata_keys))) + if len(keys.intersection(self.metadata_keys)) != 0: weaviate_where_operands = [] for key in keys: @@ -247,7 +260,18 @@ class WeaviateDB(BaseVectorDB): .with_limit(n_results) .do() ) - contexts = results["data"]["Get"].get(self.index_name) + + docs = results["data"]["Get"].get(self.index_name) + contexts = [] + for doc in docs: + context = doc["text"] + if citations: + metadata = doc["metadata"][0] + source = metadata["url"] + doc_id = metadata["doc_id"] + contexts.append((context, source, doc_id)) + else: + contexts.append(context) return contexts def set_collection_name(self, name: str): diff --git a/embedchain/vectordb/zilliz.py b/embedchain/vectordb/zilliz.py index 1037c420..f0f25eb1 100644 --- a/embedchain/vectordb/zilliz.py +++ b/embedchain/vectordb/zilliz.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union from embedchain.config import ZillizDBConfig from embedchain.helper.json_serializable import register_deserializable @@ -127,8 +127,13 @@ class ZillizVectorDB(BaseVectorDB): self.client.flush(self.config.collection_name) def query( - self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool - ) -> List[Tuple[str, str, str]]: + self, + input_query: List[str], + n_results: int, + where: Dict[str, any], + skip_embedding: bool, + citations: bool = False, + ) -> Union[List[Tuple[str, str, str]], List[str]]: """ Query contents from vector data base based on vector similarity @@ -139,8 +144,11 @@ class ZillizVectorDB(BaseVectorDB): :param where: to filter data :type where: str :raises InvalidDimensionException: Dimensions do not match. - :return: The context of the document that matched your query, url of the source, doc_id - :rtype: List[Tuple[str,str,str]] + :param citations: we use citations boolean param to return context along with the answer. + :type citations: bool, default is False. + :return: The content of the document that matched your query, + along with url of the source and doc_id (if citations flag is true) + :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]] """ if self.collection.is_empty: @@ -170,14 +178,17 @@ class ZillizVectorDB(BaseVectorDB): output_fields=output_fields, ) - doc_list = [] + contexts = [] for query in query_result: data = query[0]["entity"] context = data["text"] - source = data["url"] - doc_id = data["doc_id"] - doc_list.append(tuple((context, source, doc_id))) - return doc_list + if citations: + source = data["url"] + doc_id = data["doc_id"] + contexts.append(tuple((context, source, doc_id))) + else: + contexts.append(context) + return contexts def count(self) -> int: """ diff --git a/poetry.lock b/poetry.lock index 4ec628f3..0bc4c0e4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -7141,4 +7141,4 @@ whatsapp = ["flask", "twilio"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "0b83ba3fd2485b3b4aa3c6a7534b214378d349538f7eb63c65768aafecdfad60" +content-hash = "0b83ba3fd2485b3b4aa3c6a7534b214378d349538f7eb63c65768aafecdfad60" \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index f1c1af06..8ec7e2ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "embedchain" -version = "0.0.88" +version = "0.0.89" description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data" authors = [ "Taranjeet Singh ", diff --git a/tests/vectordb/test_chroma_db.py b/tests/vectordb/test_chroma_db.py index a30a972c..8b854176 100644 --- a/tests/vectordb/test_chroma_db.py +++ b/tests/vectordb/test_chroma_db.py @@ -163,10 +163,12 @@ def test_chroma_db_collection_add_with_skip_embedding(app_with_settings): assert data == expected_value - data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True) - expected_value = [("document", "url_1", "doc_id_1")] + data_without_citations = app_with_settings.db.query( + input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True + ) + expected_value_without_citations = ["document"] + assert data_without_citations == expected_value_without_citations - assert data == expected_value app_with_settings.db.reset() @@ -326,8 +328,16 @@ def test_chroma_db_collection_query(app_with_settings): assert app_with_settings.db.count() == 2 - data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True) - expected_value = [("document", "url_1", "doc_id_1"), ("document2", "url_2", "doc_id_2")] + data_without_citations = app_with_settings.db.query( + input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True + ) + expected_value_without_citations = ["document", "document2"] + assert data_without_citations == expected_value_without_citations + + data_with_citations = app_with_settings.db.query( + input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True, citations=True + ) + expected_value_with_citations = [("document", "url_1", "doc_id_1"), ("document2", "url_2", "doc_id_2")] + assert data_with_citations == expected_value_with_citations - assert data == expected_value app_with_settings.db.reset() diff --git a/tests/vectordb/test_elasticsearch_db.py b/tests/vectordb/test_elasticsearch_db.py index 3fdecea9..75c54c57 100644 --- a/tests/vectordb/test_elasticsearch_db.py +++ b/tests/vectordb/test_elasticsearch_db.py @@ -60,12 +60,16 @@ class TestEsDB(unittest.TestCase): # Query the database for the documents that are most similar to the query "This is a document". query = ["This is a document"] - results = self.db.query(query, n_results=2, where={}, skip_embedding=False) + results_without_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False) + expected_results_without_citations = ["This is a document.", "This is another document."] + self.assertEqual(results_without_citations, expected_results_without_citations) - # Assert that the results are correct. - self.assertEqual( - results, [("This is a document.", "url_1", "doc_id_1"), ("This is another document.", "url_2", "doc_id_2")] - ) + results_with_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False, citations=True) + expected_results_with_citations = [ + ("This is a document.", "url_1", "doc_id_1"), + ("This is another document.", "url_2", "doc_id_2"), + ] + self.assertEqual(results_with_citations, expected_results_with_citations) @patch("embedchain.vectordb.elasticsearch.Elasticsearch") def test_query_with_skip_embedding(self, mock_client): @@ -111,9 +115,7 @@ class TestEsDB(unittest.TestCase): results = self.db.query(query, n_results=2, where={}, skip_embedding=True) # Assert that the results are correct. - self.assertEqual( - results, [("This is a document.", "url_1", "doc_id_1"), ("This is another document.", "url_2", "doc_id_2")] - ) + self.assertEqual(results, ["This is a document.", "This is another document."]) def test_init_without_url(self): # Make sure it's not loaded from env diff --git a/tests/vectordb/test_weaviate.py b/tests/vectordb/test_weaviate.py index 5ced280a..43dfe989 100644 --- a/tests/vectordb/test_weaviate.py +++ b/tests/vectordb/test_weaviate.py @@ -75,10 +75,6 @@ class TestWeaviateDb(unittest.TestCase): "name": "app_id", "dataType": ["text"], }, - { - "name": "text", - "dataType": ["text"], - }, ], }, ] diff --git a/tests/vectordb/test_zilliz_db.py b/tests/vectordb/test_zilliz_db.py index d6f8d3d5..80cbf205 100644 --- a/tests/vectordb/test_zilliz_db.py +++ b/tests/vectordb/test_zilliz_db.py @@ -129,7 +129,7 @@ class TestZillizDBCollection: query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True) # Assert that MilvusClient.search was called with the correct parameters - mock_search.assert_called_once_with( + mock_search.assert_called_with( collection_name=mock_config.collection_name, data=["query_text"], limit=1, @@ -137,7 +137,20 @@ class TestZillizDBCollection: ) # Assert that the query result matches the expected result - assert query_result == [("result_doc", "url_1", "doc_id_1")] + assert query_result == ["result_doc"] + + query_result_with_citations = zilliz_db.query( + input_query=["query_text"], n_results=1, where={}, skip_embedding=True, citations=True + ) + + mock_search.assert_called_with( + collection_name=mock_config.collection_name, + data=["query_text"], + limit=1, + output_fields=["text", "url", "doc_id"], + ) + + assert query_result_with_citations == [("result_doc", "url_1", "doc_id_1")] @patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True) @patch("embedchain.vectordb.zilliz.connections", autospec=True) @@ -168,7 +181,7 @@ class TestZillizDBCollection: query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False) # Assert that MilvusClient.search was called with the correct parameters - mock_search.assert_called_once_with( + mock_search.assert_called_with( collection_name=mock_config.collection_name, data=["query_vector"], limit=1, @@ -176,4 +189,17 @@ class TestZillizDBCollection: ) # Assert that the query result matches the expected result - assert query_result == [("result_doc", "url_1", "doc_id_1")] + assert query_result == ["result_doc"] + + query_result_with_citations = zilliz_db.query( + input_query=["query_text"], n_results=1, where={}, skip_embedding=False, citations=True + ) + + mock_search.assert_called_with( + collection_name=mock_config.collection_name, + data=["query_vector"], + limit=1, + output_fields=["text", "url", "doc_id"], + ) + + assert query_result_with_citations == [("result_doc", "url_1", "doc_id_1")]