diff --git a/embedchain/vectordb/chroma.py b/embedchain/vectordb/chroma.py index f3b87b7f..31dc2615 100644 --- a/embedchain/vectordb/chroma.py +++ b/embedchain/vectordb/chroma.py @@ -183,7 +183,7 @@ class ChromaDB(BaseVectorDB): def query( self, - input_query: list[str], + input_query: str, n_results: int, where: Optional[dict[str, any]] = None, raw_filter: Optional[dict[str, any]] = None, @@ -193,8 +193,8 @@ class ChromaDB(BaseVectorDB): """ Query contents from vector database based on vector similarity - :param input_query: list of query string - :type input_query: list[str] + :param input_query: query string + :type input_query: str :param n_results: no of similar documents to fetch from database :type n_results: int :param where: to filter data diff --git a/embedchain/vectordb/elasticsearch.py b/embedchain/vectordb/elasticsearch.py index ba0b2510..5611af3e 100644 --- a/embedchain/vectordb/elasticsearch.py +++ b/embedchain/vectordb/elasticsearch.py @@ -163,7 +163,7 @@ class ElasticsearchDB(BaseVectorDB): def query( self, - input_query: list[str], + input_query: str, n_results: int, where: dict[str, any], citations: bool = False, @@ -172,8 +172,8 @@ class ElasticsearchDB(BaseVectorDB): """ query contents from vector database based on vector similarity - :param input_query: list of query string - :type input_query: list[str] + :param input_query: query string + :type input_query: str :param n_results: no of similar documents to fetch from database :type n_results: int :param where: Optional. to filter data @@ -185,7 +185,7 @@ class ElasticsearchDB(BaseVectorDB): along with url of the source and doc_id (if citations flag is true) :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]] """ - input_query_vector = self.embedder.embedding_fn(input_query) + input_query_vector = self.embedder.embedding_fn([input_query]) query_vector = input_query_vector[0] # `https://www.elastic.co/guide/en/elasticsearch/reference/7.17/query-dsl-script-score-query.html` diff --git a/embedchain/vectordb/opensearch.py b/embedchain/vectordb/opensearch.py index 99d73300..a5339b3a 100644 --- a/embedchain/vectordb/opensearch.py +++ b/embedchain/vectordb/opensearch.py @@ -146,7 +146,7 @@ class OpenSearchDB(BaseVectorDB): def query( self, - input_query: list[str], + input_query: str, n_results: int, where: dict[str, any], citations: bool = False, @@ -155,8 +155,8 @@ class OpenSearchDB(BaseVectorDB): """ query contents from vector database based on vector similarity - :param input_query: list of query string - :type input_query: list[str] + :param input_query: query string + :type input_query: str :param n_results: no of similar documents to fetch from database :type n_results: int :param where: Optional. to filter data diff --git a/embedchain/vectordb/pinecone.py b/embedchain/vectordb/pinecone.py index 710e0611..bbf5b2ca 100644 --- a/embedchain/vectordb/pinecone.py +++ b/embedchain/vectordb/pinecone.py @@ -150,7 +150,7 @@ class PineconeDB(BaseVectorDB): def query( self, - input_query: list[str], + input_query: str, n_results: int, where: Optional[dict[str, any]] = None, raw_filter: Optional[dict[str, any]] = None, @@ -162,7 +162,7 @@ class PineconeDB(BaseVectorDB): Query contents from vector database based on vector similarity. Args: - input_query (list[str]): List of query strings. + input_query (str): query string. n_results (int): Number of similar documents to fetch from the database. where (dict[str, any], optional): Filter criteria for the search. raw_filter (dict[str, any], optional): Advanced raw filter criteria for the search. diff --git a/embedchain/vectordb/qdrant.py b/embedchain/vectordb/qdrant.py index 9ed58898..1295b91c 100644 --- a/embedchain/vectordb/qdrant.py +++ b/embedchain/vectordb/qdrant.py @@ -161,7 +161,7 @@ class QdrantDB(BaseVectorDB): def query( self, - input_query: list[str], + input_query: str, n_results: int, where: dict[str, any], citations: bool = False, @@ -169,8 +169,8 @@ class QdrantDB(BaseVectorDB): ) -> Union[list[tuple[str, dict]], list[str]]: """ query contents from vector database based on vector similarity - :param input_query: list of query string - :type input_query: list[str] + :param input_query: query string + :type input_query: str :param n_results: no of similar documents to fetch from database :type n_results: int :param where: Optional. to filter data diff --git a/embedchain/vectordb/weaviate.py b/embedchain/vectordb/weaviate.py index 13693f61..b5a76b84 100644 --- a/embedchain/vectordb/weaviate.py +++ b/embedchain/vectordb/weaviate.py @@ -219,12 +219,12 @@ class WeaviateDB(BaseVectorDB): ) def query( - self, input_query: list[str], n_results: int, where: dict[str, any], citations: bool = False + self, input_query: str, n_results: int, where: dict[str, any], citations: bool = False ) -> Union[list[tuple[str, dict]], list[str]]: """ query contents from vector database based on vector similarity - :param input_query: list of query string - :type input_query: list[str] + :param input_query: query string + :type input_query: str :param n_results: no of similar documents to fetch from database :type n_results: int :param where: Optional. to filter data diff --git a/embedchain/vectordb/zilliz.py b/embedchain/vectordb/zilliz.py index 14663614..0a30de9d 100644 --- a/embedchain/vectordb/zilliz.py +++ b/embedchain/vectordb/zilliz.py @@ -138,7 +138,7 @@ class ZillizVectorDB(BaseVectorDB): def query( self, - input_query: list[str], + input_query: str, n_results: int, where: dict[str, Any], citations: bool = False, @@ -147,8 +147,8 @@ class ZillizVectorDB(BaseVectorDB): """ Query contents from vector database based on vector similarity - :param input_query: list of query string - :type input_query: list[str] + :param input_query: query string + :type input_query: str :param n_results: no of similar documents to fetch from database :type n_results: int :param where: to filter data diff --git a/tests/vectordb/test_elasticsearch_db.py b/tests/vectordb/test_elasticsearch_db.py index 5445dbb6..2cb42c42 100644 --- a/tests/vectordb/test_elasticsearch_db.py +++ b/tests/vectordb/test_elasticsearch_db.py @@ -58,7 +58,7 @@ class TestEsDB(unittest.TestCase): mock_client.return_value.search.return_value = search_response # Query the database for the documents that are most similar to the query "This is a document". - query = ["This is a document"] + query = "This is a document" results_without_citations = self.db.query(query, n_results=2, where={}) expected_results_without_citations = ["This is a document.", "This is another document."] self.assertEqual(results_without_citations, expected_results_without_citations) diff --git a/tests/vectordb/test_qdrant.py b/tests/vectordb/test_qdrant.py index 2bc6cb86..288874e1 100644 --- a/tests/vectordb/test_qdrant.py +++ b/tests/vectordb/test_qdrant.py @@ -114,7 +114,7 @@ class TestQdrantDB(unittest.TestCase): App(config=app_config, db=db, embedding_model=embedder) # Query for the document. - db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}) + db.query(input_query="This is a test document.", n_results=1, where={"doc_id": "123"}) qdrant_client_mock.return_value.search.assert_called_once_with( collection_name="embedchain-store-1536", diff --git a/tests/vectordb/test_weaviate.py b/tests/vectordb/test_weaviate.py index 66e75e76..805e4aaf 100644 --- a/tests/vectordb/test_weaviate.py +++ b/tests/vectordb/test_weaviate.py @@ -161,7 +161,7 @@ class TestWeaviateDb(unittest.TestCase): App(config=app_config, db=db, embedding_model=embedder) # Query for the document. - db.query(input_query=["This is a test document."], n_results=1, where={}) + db.query(input_query="This is a test document.", n_results=1, where={}) weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"]) weaviate_client_query_get_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]}) @@ -185,7 +185,7 @@ class TestWeaviateDb(unittest.TestCase): App(config=app_config, db=db, embedding_model=embedder) # Query for the document. - db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}) + db.query(input_query="This is a test document.", n_results=1, where={"doc_id": "123"}) weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"]) weaviate_client_query_get_mock.with_where.assert_called_once_with( diff --git a/tests/vectordb/test_zilliz_db.py b/tests/vectordb/test_zilliz_db.py index e529695f..7d636044 100644 --- a/tests/vectordb/test_zilliz_db.py +++ b/tests/vectordb/test_zilliz_db.py @@ -139,7 +139,7 @@ class TestZillizDBCollection: ] ] - query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}) + query_result = zilliz_db.query(input_query="query_text", n_results=1, where={}) # Assert that MilvusClient.search was called with the correct parameters mock_search.assert_called_with( @@ -154,7 +154,7 @@ class TestZillizDBCollection: assert query_result == ["result_doc"] query_result_with_citations = zilliz_db.query( - input_query=["query_text"], n_results=1, where={}, citations=True + input_query="query_text", n_results=1, where={}, citations=True ) mock_search.assert_called_with(