[Feature] Add citations flag in query and chat functions of App to return context along with the answer (#859)

This commit is contained in:
Deven Patel
2023-11-01 13:06:28 -07:00
committed by GitHub
parent 5022c1ae29
commit 930280f4ce
15 changed files with 279 additions and 112 deletions

View File

@@ -1,6 +1,6 @@
import copy
import os
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union
try:
import weaviate
@@ -58,10 +58,14 @@ class WeaviateDB(BaseVectorDB):
raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
self.index_name = self._get_index_name()
self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id", "text"}
self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id"}
if not self.client.schema.exists(self.index_name):
# id is a reserved field in Weaviate, hence we had to change the name of the id field to identifier
# The none vectorizer is crucial as we have our own custom embedding function
"""
TODO: wait for weaviate to add indexing on `object[]` data-type so that we can add filter while querying.
Once that is done, change `dataType` of "metadata" field to `object[]` and update the query below.
"""
class_obj = {
"classes": [
{
@@ -106,10 +110,6 @@ class WeaviateDB(BaseVectorDB):
"name": "app_id",
"dataType": ["text"],
},
{
"name": "text",
"dataType": ["text"],
},
],
},
]
@@ -195,8 +195,13 @@ class WeaviateDB(BaseVectorDB):
batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata")
def query(
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
) -> List[Tuple[str, str, str]]:
self,
input_query: List[str],
n_results: int,
where: Dict[str, any],
skip_embedding: bool,
citations: bool = False,
) -> Union[List[Tuple[str, str, str]], List[str]]:
"""
query contents from vector database based on vector similarity
:param input_query: list of query string
@@ -208,15 +213,23 @@ class WeaviateDB(BaseVectorDB):
:param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
generated or not
:type skip_embedding: bool
:return: The context of the document that matched your query, url of the source, doc_id
:rtype: List[Tuple[str,str,str]]
:param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False.
:return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
"""
if not skip_embedding:
query_vector = self.embedder.embedding_fn([input_query])[0]
else:
query_vector = input_query
keys = set(where.keys() if where is not None else set())
data_fields = ["text"]
if citations:
data_fields.append(weaviate.LinkTo("metadata", self.index_name + "_metadata", list(self.metadata_keys)))
if len(keys.intersection(self.metadata_keys)) != 0:
weaviate_where_operands = []
for key in keys:
@@ -247,7 +260,18 @@ class WeaviateDB(BaseVectorDB):
.with_limit(n_results)
.do()
)
contexts = results["data"]["Get"].get(self.index_name)
docs = results["data"]["Get"].get(self.index_name)
contexts = []
for doc in docs:
context = doc["text"]
if citations:
metadata = doc["metadata"][0]
source = metadata["url"]
doc_id = metadata["doc_id"]
contexts.append((context, source, doc_id))
else:
contexts.append(context)
return contexts
def set_collection_name(self, name: str):