[Feature] Update db.query to return source of context (#831)

This commit is contained in:
Deven Patel
2023-10-25 22:20:32 -07:00
committed by GitHub
parent a27eeb3255
commit d77e8da3f3
13 changed files with 195 additions and 73 deletions

View File

@@ -1,4 +1,5 @@
from typing import Dict, List, Optional
import logging
from typing import Dict, List, Optional, Tuple
from embedchain.config import ZillizDBConfig
from embedchain.helper.json_serializable import register_deserializable
@@ -61,6 +62,7 @@ class ZillizVectorDB(BaseVectorDB):
:type name: str
"""
if utility.has_collection(name):
logging.info(f"[ZillizDB]: found an existing collection {name}, make sure the auto-id is disabled.")
self.collection = Collection(name)
else:
fields = [
@@ -124,7 +126,9 @@ class ZillizVectorDB(BaseVectorDB):
self.collection.flush()
self.client.flush(self.config.collection_name)
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
def query(
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
) -> List[Tuple[str, str, str]]:
"""
Query contents from vector data base based on vector similarity
@@ -135,8 +139,8 @@ class ZillizVectorDB(BaseVectorDB):
:param where: to filter data
:type where: str
:raises InvalidDimensionException: Dimensions do not match.
:return: The content of the document that matched your query.
:rtype: List[str]
:return: The context of the document that matched your query, url of the source, doc_id
:rtype: List[Tuple[str,str,str]]
"""
if self.collection.is_empty:
@@ -145,13 +149,14 @@ class ZillizVectorDB(BaseVectorDB):
if not isinstance(where, str):
where = None
output_fields = ["text", "url", "doc_id"]
if skip_embedding:
query_vector = input_query
query_result = self.client.search(
collection_name=self.config.collection_name,
data=query_vector,
limit=n_results,
output_fields=["text"],
output_fields=output_fields,
)
else:
@@ -162,13 +167,16 @@ class ZillizVectorDB(BaseVectorDB):
collection_name=self.config.collection_name,
data=[query_vector],
limit=n_results,
output_fields=["text"],
output_fields=output_fields,
)
doc_list = []
for query in query_result:
doc_list.append(query[0]["entity"]["text"])
data = query[0]["entity"]
context = data["text"]
source = data["url"]
doc_id = data["doc_id"]
doc_list.append(tuple((context, source, doc_id)))
return doc_list
def count(self) -> int: