Change list[str] -> str for vectordbs (#1388)
This commit is contained in:
@@ -58,7 +58,7 @@ class TestEsDB(unittest.TestCase):
|
||||
mock_client.return_value.search.return_value = search_response
|
||||
|
||||
# Query the database for the documents that are most similar to the query "This is a document".
|
||||
query = ["This is a document"]
|
||||
query = "This is a document"
|
||||
results_without_citations = self.db.query(query, n_results=2, where={})
|
||||
expected_results_without_citations = ["This is a document.", "This is another document."]
|
||||
self.assertEqual(results_without_citations, expected_results_without_citations)
|
||||
|
||||
@@ -114,7 +114,7 @@ class TestQdrantDB(unittest.TestCase):
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
# Query for the document.
|
||||
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"})
|
||||
db.query(input_query="This is a test document.", n_results=1, where={"doc_id": "123"})
|
||||
|
||||
qdrant_client_mock.return_value.search.assert_called_once_with(
|
||||
collection_name="embedchain-store-1536",
|
||||
|
||||
@@ -161,7 +161,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
# Query for the document.
|
||||
db.query(input_query=["This is a test document."], n_results=1, where={})
|
||||
db.query(input_query="This is a test document.", n_results=1, where={})
|
||||
|
||||
weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
|
||||
weaviate_client_query_get_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
|
||||
@@ -185,7 +185,7 @@ class TestWeaviateDb(unittest.TestCase):
|
||||
App(config=app_config, db=db, embedding_model=embedder)
|
||||
|
||||
# Query for the document.
|
||||
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"})
|
||||
db.query(input_query="This is a test document.", n_results=1, where={"doc_id": "123"})
|
||||
|
||||
weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
|
||||
weaviate_client_query_get_mock.with_where.assert_called_once_with(
|
||||
|
||||
@@ -139,7 +139,7 @@ class TestZillizDBCollection:
|
||||
]
|
||||
]
|
||||
|
||||
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={})
|
||||
query_result = zilliz_db.query(input_query="query_text", n_results=1, where={})
|
||||
|
||||
# Assert that MilvusClient.search was called with the correct parameters
|
||||
mock_search.assert_called_with(
|
||||
@@ -154,7 +154,7 @@ class TestZillizDBCollection:
|
||||
assert query_result == ["result_doc"]
|
||||
|
||||
query_result_with_citations = zilliz_db.query(
|
||||
input_query=["query_text"], n_results=1, where={}, citations=True
|
||||
input_query="query_text", n_results=1, where={}, citations=True
|
||||
)
|
||||
|
||||
mock_search.assert_called_with(
|
||||
|
||||
Reference in New Issue
Block a user