[Feature] Update db.query to return source of context (#831)

This commit is contained in:
Deven Patel
2023-10-25 22:20:32 -07:00
committed by GitHub
parent a27eeb3255
commit d77e8da3f3
13 changed files with 195 additions and 73 deletions

View File

@@ -1,7 +1,7 @@
import copy
import os
import uuid
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple
try:
from qdrant_client import QdrantClient
@@ -160,7 +160,9 @@ class QdrantDB(BaseVectorDB):
),
)
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
def query(
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
) -> List[Tuple[str, str, str]]:
"""
query contents from vector database based on vector similarity
:param input_query: list of query string
@@ -172,8 +174,8 @@ class QdrantDB(BaseVectorDB):
:param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
generated or not
:type skip_embedding: bool
:return: Database contents that are the result of the query
:rtype: List[str]
:return: The context of the document that matched your query, url of the source, doc_id
:rtype: List[Tuple[str,str,str]]
"""
if not skip_embedding:
query_vector = self.embedder.embedding_fn([input_query])[0]
@@ -199,9 +201,14 @@ class QdrantDB(BaseVectorDB):
query_vector=query_vector,
limit=n_results,
)
response = []
for result in results:
response.append(result.payload.get("text", ""))
context = result.payload["text"]
metadata = result.payload["metadata"]
source = metadata["url"]
doc_id = metadata["doc_id"]
response.append(tuple((context, source, doc_id)))
return response
def count(self) -> int:
@@ -211,3 +218,15 @@ class QdrantDB(BaseVectorDB):
def reset(self):
self.client.delete_collection(collection_name=self.collection_name)
self._initialize()
def set_collection_name(self, name: str):
"""
Set the name of the collection. A collection is an isolated space for vectors.
:param name: Name of the collection.
:type name: str
"""
if not isinstance(name, str):
raise TypeError("Collection name must be a string")
self.config.collection_name = name
self.collection_name = self._get_or_create_collection()