[Feature] Update db.query to return source of context (#831)

This commit is contained in:
Deven Patel
2023-10-25 22:20:32 -07:00
committed by GitHub
parent a27eeb3255
commit d77e8da3f3
13 changed files with 195 additions and 73 deletions

View File

@@ -1,5 +1,5 @@
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
try:
from elasticsearch import Elasticsearch
@@ -135,7 +135,9 @@ class ElasticsearchDB(BaseVectorDB):
bulk(self.client, docs)
self.client.indices.refresh(index=self._get_index())
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
def query(
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
) -> List[Tuple[str, str, str]]:
"""
Query contents from the vector database based on vector similarity
@@ -147,8 +149,9 @@ class ElasticsearchDB(BaseVectorDB):
:type where: Dict[str, any]
:param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
:type skip_embedding: bool
:return: Database contents that are the result of the query
:rtype: List[str]
:return: A list of (context, source url, doc_id) tuples, one per document that matched your query
:rtype: List[Tuple[str,str,str]]
"""
if skip_embedding:
query_vector = input_query
@@ -156,6 +159,7 @@ class ElasticsearchDB(BaseVectorDB):
input_query_vector = self.embedder.embedding_fn(input_query)
query_vector = input_query_vector[0]
# `https://www.elastic.co/guide/en/elasticsearch/reference/7.17/query-dsl-script-score-query.html`
query = {
"script_score": {
"query": {"bool": {"must": [{"exists": {"field": "text"}}]}},
@@ -167,11 +171,17 @@ class ElasticsearchDB(BaseVectorDB):
}
if "app_id" in where:
app_id = where["app_id"]
query["script_score"]["query"]["bool"]["must"] = [{"term": {"metadata.app_id": app_id}}]
_source = ["text"]
query["script_score"]["query"] = {"match": {"metadata.app_id": app_id}}
_source = ["text", "metadata.url", "metadata.doc_id"]
response = self.client.search(index=self._get_index(), query=query, _source=_source, size=n_results)
docs = response["hits"]["hits"]
contents = [doc["_source"]["text"] for doc in docs]
contents = []
for doc in docs:
context = doc["_source"]["text"]
metadata = doc["_source"]["metadata"]
source = metadata["url"]
doc_id = metadata["doc_id"]
contents.append(tuple((context, source, doc_id)))
return contents
def set_collection_name(self, name: str):