From a9d13839091ab346468a7a6869294d119dba1936 Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Sun, 19 Jan 2025 04:36:49 +0530 Subject: [PATCH] Fix pytests (#2157) --- tests/vector_stores/test_elasticsearch.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/tests/vector_stores/test_elasticsearch.py b/tests/vector_stores/test_elasticsearch.py index c4c7a170..ed1872b6 100644 --- a/tests/vector_stores/test_elasticsearch.py +++ b/tests/vector_stores/test_elasticsearch.py @@ -196,27 +196,31 @@ class TestElasticsearchDB(unittest.TestCase): } } self.client_mock.search.return_value = mock_response - + # Perform search query_vector = [0.1] * 1536 results = self.es_db.search(query=query_vector, limit=5) - + # Verify search call self.client_mock.search.assert_called_once() search_args = self.client_mock.search.call_args[1] - + # Verify search parameters self.assertEqual(search_args["index"], "test_collection") body = search_args["body"] - self.assertIn("script_score", body["query"]) - self.assertEqual( - body["query"]["script_score"]["script"]["params"]["query_vector"], - query_vector - ) + # Verify KNN query structure + self.assertIn("query", body) + self.assertIn("bool", body["query"]) + self.assertIn("must", body["query"]["bool"]) + + # Verify KNN parameters + knn_query = body["query"]["bool"]["must"][-1]["knn"]["vector"] + self.assertEqual(knn_query["vector"], query_vector) + self.assertEqual(knn_query["k"], 5) + # Verify results self.assertEqual(len(results), 1) - self.assertIsInstance(results[0], OutputData) self.assertEqual(results[0].id, "id1") self.assertEqual(results[0].score, 0.8) self.assertEqual(results[0].payload, {"key1": "value1"})