[Feature] Update db.query to return source of context (#831)
This commit is contained in:
@@ -146,7 +146,7 @@ def test_chroma_db_collection_add_with_skip_embedding(app_with_settings):
|
||||
app_with_settings.db.add(
|
||||
embeddings=[[0, 0, 0]],
|
||||
documents=["document"],
|
||||
metadatas=[{"value": "somevalue"}],
|
||||
metadatas=[{"url": "url_1", "doc_id": "doc_id_1"}],
|
||||
ids=["id"],
|
||||
skip_embedding=True,
|
||||
)
|
||||
@@ -158,13 +158,13 @@ def test_chroma_db_collection_add_with_skip_embedding(app_with_settings):
|
||||
"documents": ["document"],
|
||||
"embeddings": None,
|
||||
"ids": ["id"],
|
||||
"metadatas": [{"value": "somevalue"}],
|
||||
"metadatas": [{"url": "url_1", "doc_id": "doc_id_1"}],
|
||||
}
|
||||
|
||||
assert data == expected_value
|
||||
|
||||
data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True)
|
||||
expected_value = ["document"]
|
||||
expected_value = [("document", "url_1", "doc_id_1")]
|
||||
|
||||
assert data == expected_value
|
||||
app_with_settings.db.reset()
|
||||
@@ -299,3 +299,35 @@ def test_chroma_db_collection_reset():
|
||||
app2.db.reset()
|
||||
app3.db.reset()
|
||||
app4.db.reset()
|
||||
|
||||
|
||||
def test_chroma_db_collection_query(app_with_settings):
    """Check that ``db.query`` returns each document together with its source.

    Two documents are stored with pre-computed embeddings (``skip_embedding=True``)
    and the query result is compared against ``(document, url, doc_id)`` tuples.
    """
    # Begin from an empty collection.
    app_with_settings.db.reset()
    assert app_with_settings.db.count() == 0

    # (embedding, document, metadata, id) for every record to insert, in order.
    records = (
        ([0, 0, 0], "document", {"url": "url_1", "doc_id": "doc_id_1"}, "id"),
        ([0, 1, 0], "document2", {"url": "url_2", "doc_id": "doc_id_2"}, "id2"),
    )

    for stored_so_far, (vector, text, metadata, record_id) in enumerate(records, start=1):
        app_with_settings.db.add(
            embeddings=[vector],
            documents=[text],
            metadatas=[metadata],
            ids=[record_id],
            skip_embedding=True,
        )
        # The collection must grow by exactly one entry per insertion.
        assert app_with_settings.db.count() == stored_so_far

    data = app_with_settings.db.query(
        input_query=[0, 0, 0],
        where={},
        n_results=2,
        skip_embedding=True,
    )
    expected_value = [
        ("document", "url_1", "doc_id_1"),
        ("document2", "url_2", "doc_id_2"),
    ]
    assert data == expected_value

    # Leave the collection empty once the checks are done.
    app_with_settings.db.reset()
|
||||
|
||||
@@ -31,7 +31,7 @@ class TestEsDB(unittest.TestCase):
|
||||
# Create some dummy data.
|
||||
embeddings = [[1, 2, 3], [4, 5, 6]]
|
||||
documents = ["This is a document.", "This is another document."]
|
||||
metadatas = [{}, {}]
|
||||
metadatas = [{"url": "url_1", "doc_id": "doc_id_1"}, {"url": "url_2", "doc_id": "doc_id_2"}]
|
||||
ids = ["doc_1", "doc_2"]
|
||||
|
||||
# Add the data to the database.
|
||||
@@ -40,8 +40,17 @@ class TestEsDB(unittest.TestCase):
|
||||
search_response = {
|
||||
"hits": {
|
||||
"hits": [
|
||||
{"_source": {"text": "This is a document."}, "_score": 0.9},
|
||||
{"_source": {"text": "This is another document."}, "_score": 0.8},
|
||||
{
|
||||
"_source": {"text": "This is a document.", "metadata": {"url": "url_1", "doc_id": "doc_id_1"}},
|
||||
"_score": 0.9,
|
||||
},
|
||||
{
|
||||
"_source": {
|
||||
"text": "This is another document.",
|
||||
"metadata": {"url": "url_2", "doc_id": "doc_id_2"},
|
||||
},
|
||||
"_score": 0.8,
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -54,7 +63,9 @@ class TestEsDB(unittest.TestCase):
|
||||
results = self.db.query(query, n_results=2, where={}, skip_embedding=False)
|
||||
|
||||
# Assert that the results are correct.
|
||||
self.assertEqual(results, ["This is a document.", "This is another document."])
|
||||
self.assertEqual(
|
||||
results, [("This is a document.", "url_1", "doc_id_1"), ("This is another document.", "url_2", "doc_id_2")]
|
||||
)
|
||||
|
||||
@patch("embedchain.vectordb.elasticsearch.Elasticsearch")
|
||||
def test_query_with_skip_embedding(self, mock_client):
|
||||
@@ -68,7 +79,7 @@ class TestEsDB(unittest.TestCase):
|
||||
# Create some dummy data.
|
||||
embeddings = [[1, 2, 3], [4, 5, 6]]
|
||||
documents = ["This is a document.", "This is another document."]
|
||||
metadatas = [{}, {}]
|
||||
metadatas = [{"url": "url_1", "doc_id": "doc_id_1"}, {"url": "url_2", "doc_id": "doc_id_2"}]
|
||||
ids = ["doc_1", "doc_2"]
|
||||
|
||||
# Add the data to the database.
|
||||
@@ -77,8 +88,17 @@ class TestEsDB(unittest.TestCase):
|
||||
search_response = {
|
||||
"hits": {
|
||||
"hits": [
|
||||
{"_source": {"text": "This is a document."}, "_score": 0.9},
|
||||
{"_source": {"text": "This is another document."}, "_score": 0.8},
|
||||
{
|
||||
"_source": {"text": "This is a document.", "metadata": {"url": "url_1", "doc_id": "doc_id_1"}},
|
||||
"_score": 0.9,
|
||||
},
|
||||
{
|
||||
"_source": {
|
||||
"text": "This is another document.",
|
||||
"metadata": {"url": "url_2", "doc_id": "doc_id_2"},
|
||||
},
|
||||
"_score": 0.8,
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -91,7 +111,9 @@ class TestEsDB(unittest.TestCase):
|
||||
results = self.db.query(query, n_results=2, where={}, skip_embedding=True)
|
||||
|
||||
# Assert that the results are correct.
|
||||
self.assertEqual(results, ["This is a document.", "This is another document."])
|
||||
self.assertEqual(
|
||||
results, [("This is a document.", "url_1", "doc_id_1"), ("This is another document.", "url_2", "doc_id_2")]
|
||||
)
|
||||
|
||||
def test_init_without_url(self):
|
||||
# Make sure it's not loaded from env
|
||||
|
||||
@@ -123,7 +123,7 @@ class TestZillizDBCollection:
|
||||
# Mock the MilvusClient search method
|
||||
with patch.object(zilliz_db.client, "search") as mock_search:
|
||||
# Mock the search result
|
||||
mock_search.return_value = [[{"entity": {"text": "result_doc"}}]]
|
||||
mock_search.return_value = [[{"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"}}]]
|
||||
|
||||
# Call the query method with skip_embedding=True
|
||||
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True)
|
||||
@@ -133,11 +133,11 @@ class TestZillizDBCollection:
|
||||
collection_name=mock_config.collection_name,
|
||||
data=["query_text"],
|
||||
limit=1,
|
||||
output_fields=["text"],
|
||||
output_fields=["text", "url", "doc_id"],
|
||||
)
|
||||
|
||||
# Assert that the query result matches the expected result
|
||||
assert query_result == ["result_doc"]
|
||||
assert query_result == [("result_doc", "url_1", "doc_id_1")]
|
||||
|
||||
@patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
|
||||
@patch("embedchain.vectordb.zilliz.connections", autospec=True)
|
||||
@@ -162,7 +162,7 @@ class TestZillizDBCollection:
|
||||
mock_embedder.embedding_fn.return_value = ["query_vector"]
|
||||
|
||||
# Mock the search result
|
||||
mock_search.return_value = [[{"entity": {"text": "result_doc"}}]]
|
||||
mock_search.return_value = [[{"entity": {"text": "result_doc", "url": "url_1", "doc_id": "doc_id_1"}}]]
|
||||
|
||||
# Call the query method with skip_embedding=False
|
||||
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False)
|
||||
@@ -172,8 +172,8 @@ class TestZillizDBCollection:
|
||||
collection_name=mock_config.collection_name,
|
||||
data=["query_vector"],
|
||||
limit=1,
|
||||
output_fields=["text"],
|
||||
output_fields=["text", "url", "doc_id"],
|
||||
)
|
||||
|
||||
# Assert that the query result matches the expected result
|
||||
assert query_result == ["result_doc"]
|
||||
assert query_result == [("result_doc", "url_1", "doc_id_1")]
|
||||
|
||||
Reference in New Issue
Block a user