[Feature] Update db.query to return source of context (#831)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import copy
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
try:
|
||||
import weaviate
|
||||
@@ -194,7 +194,9 @@ class WeaviateDB(BaseVectorDB):
|
||||
)
|
||||
batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata")
|
||||
|
||||
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
|
||||
def query(
|
||||
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
|
||||
) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
query contents from vector database based on vector similarity
|
||||
:param input_query: list of query string
|
||||
@@ -206,14 +208,15 @@ class WeaviateDB(BaseVectorDB):
|
||||
:param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
|
||||
generated or not
|
||||
:type skip_embedding: bool
|
||||
:return: Database contents that are the result of the query
|
||||
:rtype: List[str]
|
||||
:return: The context of the document that matched your query, url of the source, doc_id
|
||||
:rtype: List[Tuple[str,str,str]]
|
||||
"""
|
||||
if not skip_embedding:
|
||||
query_vector = self.embedder.embedding_fn([input_query])[0]
|
||||
else:
|
||||
query_vector = input_query
|
||||
keys = set(where.keys() if where is not None else set())
|
||||
data_fields = ["text"]
|
||||
if len(keys.intersection(self.metadata_keys)) != 0:
|
||||
weaviate_where_operands = []
|
||||
for key in keys:
|
||||
@@ -231,7 +234,7 @@ class WeaviateDB(BaseVectorDB):
|
||||
weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands}
|
||||
|
||||
results = (
|
||||
self.client.query.get(self.index_name, ["text"])
|
||||
self.client.query.get(self.index_name, data_fields)
|
||||
.with_where(weaviate_where_clause)
|
||||
.with_near_vector({"vector": query_vector})
|
||||
.with_limit(n_results)
|
||||
@@ -239,16 +242,13 @@ class WeaviateDB(BaseVectorDB):
|
||||
)
|
||||
else:
|
||||
results = (
|
||||
self.client.query.get(self.index_name, ["text"])
|
||||
self.client.query.get(self.index_name, data_fields)
|
||||
.with_near_vector({"vector": query_vector})
|
||||
.with_limit(n_results)
|
||||
.do()
|
||||
)
|
||||
matched_tokens = []
|
||||
for result in results["data"]["Get"].get(self.index_name):
|
||||
matched_tokens.append(result["text"])
|
||||
|
||||
return matched_tokens
|
||||
contexts = results["data"]["Get"].get(self.index_name)
|
||||
return contexts
|
||||
|
||||
def set_collection_name(self, name: str):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user