[Feature] Add citations flag in query and chat functions of App to return context along with the answer (#859)
This commit is contained in:
@@ -163,10 +163,12 @@ def test_chroma_db_collection_add_with_skip_embedding(app_with_settings):
|
||||
|
||||
assert data == expected_value
|
||||
|
||||
data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
|
||||
expected_value = [("document", "url_1", "doc_id_1")]
|
||||
data_without_citations = app_with_settings.db.query(
|
||||
input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True
|
||||
)
|
||||
expected_value_without_citations = ["document"]
|
||||
assert data_without_citations == expected_value_without_citations
|
||||
|
||||
assert data == expected_value
|
||||
app_with_settings.db.reset()
|
||||
|
||||
|
||||
@@ -326,8 +328,16 @@ def test_chroma_db_collection_query(app_with_settings):
|
||||
|
||||
assert app_with_settings.db.count() == 2
|
||||
|
||||
data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True)
|
||||
expected_value = [("document", "url_1", "doc_id_1"), ("document2", "url_2", "doc_id_2")]
|
||||
data_without_citations = app_with_settings.db.query(
|
||||
input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True
|
||||
)
|
||||
expected_value_without_citations = ["document", "document2"]
|
||||
assert data_without_citations == expected_value_without_citations
|
||||
|
||||
data_with_citations = app_with_settings.db.query(
|
||||
input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True, citations=True
|
||||
)
|
||||
expected_value_with_citations = [("document", "url_1", "doc_id_1"), ("document2", "url_2", "doc_id_2")]
|
||||
assert data_with_citations == expected_value_with_citations
|
||||
|
||||
assert data == expected_value
|
||||
app_with_settings.db.reset()
|
||||
|
||||
@@ -60,12 +60,16 @@ class TestEsDB(unittest.TestCase):
|
||||
|
||||
# Query the database for the documents that are most similar to the query "This is a document".
|
||||
query = ["This is a document"]
|
||||
results = self.db.query(query, n_results=2, where={}, skip_embedding=False)
|
||||
results_without_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False)
|
||||
expected_results_without_citations = ["This is a document.", "This is another document."]
|
||||
self.assertEqual(results_without_citations, expected_results_without_citations)
|
||||
|
||||
# Assert that the results are correct.
|
||||
self.assertEqual(
|
||||
results, [("This is a document.", "url_1", "doc_id_1"), ("This is another document.", "url_2", "doc_id_2")]
|
||||
)
|
||||
results_with_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False, citations=True)
|
||||
expected_results_with_citations = [
|
||||
("This is a document.", "url_1", "doc_id_1"),
|
||||
("This is another document.", "url_2", "doc_id_2"),
|
||||
]
|
||||
self.assertEqual(results_with_citations, expected_results_with_citations)
|
||||
|
||||
@patch("embedchain.vectordb.elasticsearch.Elasticsearch")
|
||||
def test_query_with_skip_embedding(self, mock_client):
|
||||
@@ -111,9 +115,7 @@ class TestEsDB(unittest.TestCase):
|
||||
results = self.db.query(query, n_results=2, where={}, skip_embedding=True)
|
||||
|
||||
# Assert that the results are correct.
|
||||
self.assertEqual(
|
||||
results, [("This is a document.", "url_1", "doc_id_1"), ("This is another document.", "url_2", "doc_id_2")]
|
||||
)
|
||||
self.assertEqual(results, ["This is a document.", "This is another document."])
|
||||
|
||||
def test_init_without_url(self):
|
||||
# Make sure it's not loaded from env
|
||||
|
||||
@@ -75,10 +75,6 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
"name": "app_id",
|
||||
"dataType": ["text"],
|
||||
},
|
||||
{
|
||||
"name": "text",
|
||||
"dataType": ["text"],
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
@@ -129,7 +129,7 @@ class TestZillizDBCollection:
|
||||
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True)
|
||||
|
||||
# Assert that MilvusClient.search was called with the correct parameters
|
||||
mock_search.assert_called_once_with(
|
||||
mock_search.assert_called_with(
|
||||
collection_name=mock_config.collection_name,
|
||||
data=["query_text"],
|
||||
limit=1,
|
||||
@@ -137,7 +137,20 @@ class TestZillizDBCollection:
|
||||
)
|
||||
|
||||
# Assert that the query result matches the expected result
|
||||
assert query_result == [("result_doc", "url_1", "doc_id_1")]
|
||||
assert query_result == ["result_doc"]
|
||||
|
||||
query_result_with_citations = zilliz_db.query(
|
||||
input_query=["query_text"], n_results=1, where={}, skip_embedding=True, citations=True
|
||||
)
|
||||
|
||||
mock_search.assert_called_with(
|
||||
collection_name=mock_config.collection_name,
|
||||
data=["query_text"],
|
||||
limit=1,
|
||||
output_fields=["text", "url", "doc_id"],
|
||||
)
|
||||
|
||||
assert query_result_with_citations == [("result_doc", "url_1", "doc_id_1")]
|
||||
|
||||
@patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
|
||||
@patch("embedchain.vectordb.zilliz.connections", autospec=True)
|
||||
@@ -168,7 +181,7 @@ class TestZillizDBCollection:
|
||||
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False)
|
||||
|
||||
# Assert that MilvusClient.search was called with the correct parameters
|
||||
mock_search.assert_called_once_with(
|
||||
mock_search.assert_called_with(
|
||||
collection_name=mock_config.collection_name,
|
||||
data=["query_vector"],
|
||||
limit=1,
|
||||
@@ -176,4 +189,17 @@ class TestZillizDBCollection:
|
||||
)
|
||||
|
||||
# Assert that the query result matches the expected result
|
||||
assert query_result == [("result_doc", "url_1", "doc_id_1")]
|
||||
assert query_result == ["result_doc"]
|
||||
|
||||
query_result_with_citations = zilliz_db.query(
|
||||
input_query=["query_text"], n_results=1, where={}, skip_embedding=False, citations=True
|
||||
)
|
||||
|
||||
mock_search.assert_called_with(
|
||||
collection_name=mock_config.collection_name,
|
||||
data=["query_vector"],
|
||||
limit=1,
|
||||
output_fields=["text", "url", "doc_id"],
|
||||
)
|
||||
|
||||
assert query_result_with_citations == [("result_doc", "url_1", "doc_id_1")]
|
||||
|
||||
Reference in New Issue
Block a user