[Feature] Update db.query to return source of context (#831)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
try:
|
||||
import pinecone
|
||||
@@ -118,7 +118,9 @@ class PineconeDB(BaseVectorDB):
|
||||
for i in range(0, len(docs), self.BATCH_SIZE):
|
||||
self.client.upsert(docs[i : i + self.BATCH_SIZE])
|
||||
|
||||
def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
|
||||
def query(
|
||||
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
|
||||
) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
query contents from vector database based on vector similarity
|
||||
:param input_query: list of query string
|
||||
@@ -129,16 +131,22 @@ class PineconeDB(BaseVectorDB):
|
||||
:type where: Dict[str, any]
|
||||
:param skip_embedding: Optional. if True, input_query is already embedded
|
||||
:type skip_embedding: bool
|
||||
:return: Database contents that are the result of the query
|
||||
:rtype: List[str]
|
||||
:return: The content of the document that matched your query, url of the source, doc_id
|
||||
:rtype: List[Tuple[str,str,str]]
|
||||
"""
|
||||
if not skip_embedding:
|
||||
query_vector = self.embedder.embedding_fn([input_query])[0]
|
||||
else:
|
||||
query_vector = input_query
|
||||
contents = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True)
|
||||
embeddings = list(map(lambda content: content["metadata"]["text"], contents["matches"]))
|
||||
return embeddings
|
||||
data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True)
|
||||
contents = []
|
||||
for doc in data["matches"]:
|
||||
metadata = doc["metadata"]
|
||||
context = metadata["text"]
|
||||
source = metadata["url"]
|
||||
doc_id = metadata["doc_id"]
|
||||
contents.append(tuple((context, source, doc_id)))
|
||||
return contents
|
||||
|
||||
def set_collection_name(self, name: str):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user