Change the `input_query` parameter type from list[str] to str in the vector DB `query` methods (#1388)
This commit is contained in:
@@ -183,7 +183,7 @@ class ChromaDB(BaseVectorDB):
|
|||||||
|
|
||||||
def query(
|
def query(
|
||||||
self,
|
self,
|
||||||
input_query: list[str],
|
input_query: str,
|
||||||
n_results: int,
|
n_results: int,
|
||||||
where: Optional[dict[str, any]] = None,
|
where: Optional[dict[str, any]] = None,
|
||||||
raw_filter: Optional[dict[str, any]] = None,
|
raw_filter: Optional[dict[str, any]] = None,
|
||||||
@@ -193,8 +193,8 @@ class ChromaDB(BaseVectorDB):
|
|||||||
"""
|
"""
|
||||||
Query contents from vector database based on vector similarity
|
Query contents from vector database based on vector similarity
|
||||||
|
|
||||||
:param input_query: list of query string
|
:param input_query: query string
|
||||||
:type input_query: list[str]
|
:type input_query: str
|
||||||
:param n_results: no of similar documents to fetch from database
|
:param n_results: no of similar documents to fetch from database
|
||||||
:type n_results: int
|
:type n_results: int
|
||||||
:param where: to filter data
|
:param where: to filter data
|
||||||
|
|||||||
@@ -163,7 +163,7 @@ class ElasticsearchDB(BaseVectorDB):
|
|||||||
|
|
||||||
def query(
|
def query(
|
||||||
self,
|
self,
|
||||||
input_query: list[str],
|
input_query: str,
|
||||||
n_results: int,
|
n_results: int,
|
||||||
where: dict[str, any],
|
where: dict[str, any],
|
||||||
citations: bool = False,
|
citations: bool = False,
|
||||||
@@ -172,8 +172,8 @@ class ElasticsearchDB(BaseVectorDB):
|
|||||||
"""
|
"""
|
||||||
query contents from vector database based on vector similarity
|
query contents from vector database based on vector similarity
|
||||||
|
|
||||||
:param input_query: list of query string
|
:param input_query: query string
|
||||||
:type input_query: list[str]
|
:type input_query: str
|
||||||
:param n_results: no of similar documents to fetch from database
|
:param n_results: no of similar documents to fetch from database
|
||||||
:type n_results: int
|
:type n_results: int
|
||||||
:param where: Optional. to filter data
|
:param where: Optional. to filter data
|
||||||
@@ -185,7 +185,7 @@ class ElasticsearchDB(BaseVectorDB):
|
|||||||
along with url of the source and doc_id (if citations flag is true)
|
along with url of the source and doc_id (if citations flag is true)
|
||||||
:rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
|
:rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
|
||||||
"""
|
"""
|
||||||
input_query_vector = self.embedder.embedding_fn(input_query)
|
input_query_vector = self.embedder.embedding_fn([input_query])
|
||||||
query_vector = input_query_vector[0]
|
query_vector = input_query_vector[0]
|
||||||
|
|
||||||
# `https://www.elastic.co/guide/en/elasticsearch/reference/7.17/query-dsl-script-score-query.html`
|
# `https://www.elastic.co/guide/en/elasticsearch/reference/7.17/query-dsl-script-score-query.html`
|
||||||
|
|||||||
@@ -146,7 +146,7 @@ class OpenSearchDB(BaseVectorDB):
|
|||||||
|
|
||||||
def query(
|
def query(
|
||||||
self,
|
self,
|
||||||
input_query: list[str],
|
input_query: str,
|
||||||
n_results: int,
|
n_results: int,
|
||||||
where: dict[str, any],
|
where: dict[str, any],
|
||||||
citations: bool = False,
|
citations: bool = False,
|
||||||
@@ -155,8 +155,8 @@ class OpenSearchDB(BaseVectorDB):
|
|||||||
"""
|
"""
|
||||||
query contents from vector database based on vector similarity
|
query contents from vector database based on vector similarity
|
||||||
|
|
||||||
:param input_query: list of query string
|
:param input_query: query string
|
||||||
:type input_query: list[str]
|
:type input_query: str
|
||||||
:param n_results: no of similar documents to fetch from database
|
:param n_results: no of similar documents to fetch from database
|
||||||
:type n_results: int
|
:type n_results: int
|
||||||
:param where: Optional. to filter data
|
:param where: Optional. to filter data
|
||||||
|
|||||||
@@ -150,7 +150,7 @@ class PineconeDB(BaseVectorDB):
|
|||||||
|
|
||||||
def query(
|
def query(
|
||||||
self,
|
self,
|
||||||
input_query: list[str],
|
input_query: str,
|
||||||
n_results: int,
|
n_results: int,
|
||||||
where: Optional[dict[str, any]] = None,
|
where: Optional[dict[str, any]] = None,
|
||||||
raw_filter: Optional[dict[str, any]] = None,
|
raw_filter: Optional[dict[str, any]] = None,
|
||||||
@@ -162,7 +162,7 @@ class PineconeDB(BaseVectorDB):
|
|||||||
Query contents from vector database based on vector similarity.
|
Query contents from vector database based on vector similarity.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_query (list[str]): List of query strings.
|
input_query (str): query string.
|
||||||
n_results (int): Number of similar documents to fetch from the database.
|
n_results (int): Number of similar documents to fetch from the database.
|
||||||
where (dict[str, any], optional): Filter criteria for the search.
|
where (dict[str, any], optional): Filter criteria for the search.
|
||||||
raw_filter (dict[str, any], optional): Advanced raw filter criteria for the search.
|
raw_filter (dict[str, any], optional): Advanced raw filter criteria for the search.
|
||||||
|
|||||||
@@ -161,7 +161,7 @@ class QdrantDB(BaseVectorDB):
|
|||||||
|
|
||||||
def query(
|
def query(
|
||||||
self,
|
self,
|
||||||
input_query: list[str],
|
input_query: str,
|
||||||
n_results: int,
|
n_results: int,
|
||||||
where: dict[str, any],
|
where: dict[str, any],
|
||||||
citations: bool = False,
|
citations: bool = False,
|
||||||
@@ -169,8 +169,8 @@ class QdrantDB(BaseVectorDB):
|
|||||||
) -> Union[list[tuple[str, dict]], list[str]]:
|
) -> Union[list[tuple[str, dict]], list[str]]:
|
||||||
"""
|
"""
|
||||||
query contents from vector database based on vector similarity
|
query contents from vector database based on vector similarity
|
||||||
:param input_query: list of query string
|
:param input_query: query string
|
||||||
:type input_query: list[str]
|
:type input_query: str
|
||||||
:param n_results: no of similar documents to fetch from database
|
:param n_results: no of similar documents to fetch from database
|
||||||
:type n_results: int
|
:type n_results: int
|
||||||
:param where: Optional. to filter data
|
:param where: Optional. to filter data
|
||||||
|
|||||||
@@ -219,12 +219,12 @@ class WeaviateDB(BaseVectorDB):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def query(
|
def query(
|
||||||
self, input_query: list[str], n_results: int, where: dict[str, any], citations: bool = False
|
self, input_query: str, n_results: int, where: dict[str, any], citations: bool = False
|
||||||
) -> Union[list[tuple[str, dict]], list[str]]:
|
) -> Union[list[tuple[str, dict]], list[str]]:
|
||||||
"""
|
"""
|
||||||
query contents from vector database based on vector similarity
|
query contents from vector database based on vector similarity
|
||||||
:param input_query: list of query string
|
:param input_query: query string
|
||||||
:type input_query: list[str]
|
:type input_query: str
|
||||||
:param n_results: no of similar documents to fetch from database
|
:param n_results: no of similar documents to fetch from database
|
||||||
:type n_results: int
|
:type n_results: int
|
||||||
:param where: Optional. to filter data
|
:param where: Optional. to filter data
|
||||||
|
|||||||
@@ -138,7 +138,7 @@ class ZillizVectorDB(BaseVectorDB):
|
|||||||
|
|
||||||
def query(
|
def query(
|
||||||
self,
|
self,
|
||||||
input_query: list[str],
|
input_query: str,
|
||||||
n_results: int,
|
n_results: int,
|
||||||
where: dict[str, Any],
|
where: dict[str, Any],
|
||||||
citations: bool = False,
|
citations: bool = False,
|
||||||
@@ -147,8 +147,8 @@ class ZillizVectorDB(BaseVectorDB):
|
|||||||
"""
|
"""
|
||||||
Query contents from vector database based on vector similarity
|
Query contents from vector database based on vector similarity
|
||||||
|
|
||||||
:param input_query: list of query string
|
:param input_query: query string
|
||||||
:type input_query: list[str]
|
:type input_query: str
|
||||||
:param n_results: no of similar documents to fetch from database
|
:param n_results: no of similar documents to fetch from database
|
||||||
:type n_results: int
|
:type n_results: int
|
||||||
:param where: to filter data
|
:param where: to filter data
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ class TestEsDB(unittest.TestCase):
|
|||||||
mock_client.return_value.search.return_value = search_response
|
mock_client.return_value.search.return_value = search_response
|
||||||
|
|
||||||
# Query the database for the documents that are most similar to the query "This is a document".
|
# Query the database for the documents that are most similar to the query "This is a document".
|
||||||
query = ["This is a document"]
|
query = "This is a document"
|
||||||
results_without_citations = self.db.query(query, n_results=2, where={})
|
results_without_citations = self.db.query(query, n_results=2, where={})
|
||||||
expected_results_without_citations = ["This is a document.", "This is another document."]
|
expected_results_without_citations = ["This is a document.", "This is another document."]
|
||||||
self.assertEqual(results_without_citations, expected_results_without_citations)
|
self.assertEqual(results_without_citations, expected_results_without_citations)
|
||||||
|
|||||||
@@ -114,7 +114,7 @@ class TestQdrantDB(unittest.TestCase):
|
|||||||
App(config=app_config, db=db, embedding_model=embedder)
|
App(config=app_config, db=db, embedding_model=embedder)
|
||||||
|
|
||||||
# Query for the document.
|
# Query for the document.
|
||||||
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"})
|
db.query(input_query="This is a test document.", n_results=1, where={"doc_id": "123"})
|
||||||
|
|
||||||
qdrant_client_mock.return_value.search.assert_called_once_with(
|
qdrant_client_mock.return_value.search.assert_called_once_with(
|
||||||
collection_name="embedchain-store-1536",
|
collection_name="embedchain-store-1536",
|
||||||
|
|||||||
@@ -161,7 +161,7 @@ class TestWeaviateDb(unittest.TestCase):
|
|||||||
App(config=app_config, db=db, embedding_model=embedder)
|
App(config=app_config, db=db, embedding_model=embedder)
|
||||||
|
|
||||||
# Query for the document.
|
# Query for the document.
|
||||||
db.query(input_query=["This is a test document."], n_results=1, where={})
|
db.query(input_query="This is a test document.", n_results=1, where={})
|
||||||
|
|
||||||
weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
|
weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
|
||||||
weaviate_client_query_get_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
|
weaviate_client_query_get_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
|
||||||
@@ -185,7 +185,7 @@ class TestWeaviateDb(unittest.TestCase):
|
|||||||
App(config=app_config, db=db, embedding_model=embedder)
|
App(config=app_config, db=db, embedding_model=embedder)
|
||||||
|
|
||||||
# Query for the document.
|
# Query for the document.
|
||||||
db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"})
|
db.query(input_query="This is a test document.", n_results=1, where={"doc_id": "123"})
|
||||||
|
|
||||||
weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
|
weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
|
||||||
weaviate_client_query_get_mock.with_where.assert_called_once_with(
|
weaviate_client_query_get_mock.with_where.assert_called_once_with(
|
||||||
|
|||||||
@@ -139,7 +139,7 @@ class TestZillizDBCollection:
|
|||||||
]
|
]
|
||||||
]
|
]
|
||||||
|
|
||||||
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={})
|
query_result = zilliz_db.query(input_query="query_text", n_results=1, where={})
|
||||||
|
|
||||||
# Assert that MilvusClient.search was called with the correct parameters
|
# Assert that MilvusClient.search was called with the correct parameters
|
||||||
mock_search.assert_called_with(
|
mock_search.assert_called_with(
|
||||||
@@ -154,7 +154,7 @@ class TestZillizDBCollection:
|
|||||||
assert query_result == ["result_doc"]
|
assert query_result == ["result_doc"]
|
||||||
|
|
||||||
query_result_with_citations = zilliz_db.query(
|
query_result_with_citations = zilliz_db.query(
|
||||||
input_query=["query_text"], n_results=1, where={}, citations=True
|
input_query="query_text", n_results=1, where={}, citations=True
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_search.assert_called_with(
|
mock_search.assert_called_with(
|
||||||
|
|||||||
Reference in New Issue
Block a user