[Feature] Add citations flag in query and chat functions of App to return context along with the answer (#859)

This commit is contained in:
Deven Patel
2023-11-01 13:06:28 -07:00
committed by GitHub
parent 5022c1ae29
commit 930280f4ce
15 changed files with 279 additions and 112 deletions

View File

@@ -1,5 +1,5 @@
import logging
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union
try:
from elasticsearch import Elasticsearch
@@ -136,8 +136,13 @@ class ElasticsearchDB(BaseVectorDB):
self.client.indices.refresh(index=self._get_index())
def query(
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
) -> List[Tuple[str, str, str]]:
self,
input_query: List[str],
n_results: int,
where: Dict[str, any],
skip_embedding: bool,
citations: bool = False,
) -> Union[List[Tuple[str, str, str]], List[str]]:
"""
query contents from vector data base based on vector similarity
@@ -150,8 +155,11 @@ class ElasticsearchDB(BaseVectorDB):
:param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
:type skip_embedding: bool
:return: The context of the document that matched your query, url of the source, doc_id
:rtype: List[Tuple[str,str,str]]
:param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False.
:return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
"""
if skip_embedding:
query_vector = input_query
@@ -175,14 +183,17 @@ class ElasticsearchDB(BaseVectorDB):
_source = ["text", "metadata.url", "metadata.doc_id"]
response = self.client.search(index=self._get_index(), query=query, _source=_source, size=n_results)
docs = response["hits"]["hits"]
contents = []
contexts = []
for doc in docs:
context = doc["_source"]["text"]
metadata = doc["_source"]["metadata"]
source = metadata["url"]
doc_id = metadata["doc_id"]
contents.append(tuple((context, source, doc_id)))
return contents
if citations:
metadata = doc["_source"]["metadata"]
source = metadata["url"]
doc_id = metadata["doc_id"]
contexts.append(tuple((context, source, doc_id)))
else:
contexts.append(context)
return contexts
def set_collection_name(self, name: str):
"""