[Feature] Add citations flag in query and chat functions of App to return context along with the answer (#859)

This commit is contained in:
Deven Patel
2023-11-01 13:06:28 -07:00
committed by GitHub
parent 5022c1ae29
commit 930280f4ce
15 changed files with 279 additions and 112 deletions

View File

@@ -4,7 +4,7 @@ import logging
import os import os
import sqlite3 import sqlite3
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Tuple, Union
from dotenv import load_dotenv from dotenv import load_dotenv
from langchain.docstore.document import Document from langchain.docstore.document import Document
@@ -438,7 +438,9 @@ class EmbedChain(JSONSerializable):
) )
] ]
def retrieve_from_database(self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None) -> List[str]: def retrieve_from_database(
self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None, citations: bool = False
) -> Union[List[Tuple[str, str, str]], List[str]]:
""" """
Queries the vector database based on the given input query. Queries the vector database based on the given input query.
Gets relevant doc based on the query Gets relevant doc based on the query
@@ -449,6 +451,8 @@ class EmbedChain(JSONSerializable):
:type config: Optional[BaseLlmConfig], optional :type config: Optional[BaseLlmConfig], optional
:param where: A dictionary of key-value pairs to filter the database results, defaults to None :param where: A dictionary of key-value pairs to filter the database results, defaults to None
:type where: _type_, optional :type where: _type_, optional
:param citations: A boolean to indicate if db should fetch citation source
:type citations: bool
:return: List of contents of the document that matched your query :return: List of contents of the document that matched your query
:rtype: List[str] :rtype: List[str]
""" """
@@ -478,14 +482,19 @@ class EmbedChain(JSONSerializable):
n_results=query_config.number_documents, n_results=query_config.number_documents,
where=where, where=where,
skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"), skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"),
citations=citations,
) )
if len(contexts) > 0 and isinstance(contexts[0], tuple):
contexts = list(map(lambda x: x[0], contexts))
return contexts return contexts
def query(self, input_query: str, config: BaseLlmConfig = None, dry_run=False, where: Optional[Dict] = None) -> str: def query(
self,
input_query: str,
config: BaseLlmConfig = None,
dry_run=False,
where: Optional[Dict] = None,
**kwargs: Dict[str, Any],
) -> Union[Tuple[str, List[Tuple[str, str, str]]], str]:
""" """
Queries the vector database based on the given input query. Queries the vector database based on the given input query.
Gets relevant doc based on the query and then passes it to an Gets relevant doc based on the query and then passes it to an
@@ -501,15 +510,31 @@ class EmbedChain(JSONSerializable):
:type dry_run: bool, optional :type dry_run: bool, optional
:param where: A dictionary of key-value pairs to filter the database results., defaults to None :param where: A dictionary of key-value pairs to filter the database results., defaults to None
:type where: Optional[Dict[str, str]], optional :type where: Optional[Dict[str, str]], optional
:return: The answer to the query or the dry run result :param kwargs: To read more params for the query function. Ex. we use citations boolean
:rtype: str param to return context along with the answer
:type kwargs: Dict[str, Any]
:return: The answer to the query, with citations if the citation flag is True
or the dry run result
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
""" """
contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where) citations = kwargs.get("citations", False)
answer = self.llm.query(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run) contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where, citations=citations)
if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
else:
contexts_data_for_llm_query = contexts
answer = self.llm.query(
input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
)
# Send anonymous telemetry # Send anonymous telemetry
self.telemetry.capture(event_name="query", properties=self._telemetry_props) self.telemetry.capture(event_name="query", properties=self._telemetry_props)
return answer
if citations:
return answer, contexts
else:
return answer
def chat( def chat(
self, self,
@@ -517,6 +542,7 @@ class EmbedChain(JSONSerializable):
config: Optional[BaseLlmConfig] = None, config: Optional[BaseLlmConfig] = None,
dry_run=False, dry_run=False,
where: Optional[Dict[str, str]] = None, where: Optional[Dict[str, str]] = None,
**kwargs: Dict[str, Any],
) -> str: ) -> str:
""" """
Queries the vector database on the given input query. Queries the vector database on the given input query.
@@ -535,15 +561,31 @@ class EmbedChain(JSONSerializable):
:type dry_run: bool, optional :type dry_run: bool, optional
:param where: A dictionary of key-value pairs to filter the database results., defaults to None :param where: A dictionary of key-value pairs to filter the database results., defaults to None
:type where: Optional[Dict[str, str]], optional :type where: Optional[Dict[str, str]], optional
:return: The answer to the query or the dry run result :param kwargs: To read more params for the query function. Ex. we use citations boolean
:rtype: str param to return context along with the answer
:type kwargs: Dict[str, Any]
:return: The answer to the query, with citations if the citation flag is True
or the dry run result
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
""" """
contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where) citations = kwargs.get("citations", False)
answer = self.llm.chat(input_query=input_query, contexts=contexts, config=config, dry_run=dry_run) contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where, citations=citations)
if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
else:
contexts_data_for_llm_query = contexts
answer = self.llm.chat(
input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
)
# Send anonymous telemetry # Send anonymous telemetry
self.telemetry.capture(event_name="chat", properties=self._telemetry_props) self.telemetry.capture(event_name="chat", properties=self._telemetry_props)
return answer if citations:
return answer, contexts
else:
return answer
def set_collection_name(self, name: str): def set_collection_name(self, name: str):
""" """

View File

@@ -234,6 +234,7 @@ class Pipeline(EmbedChain):
n_results=num_documents, n_results=num_documents,
where=where, where=where,
skip_embedding=False, skip_embedding=False,
citations=True,
) )
result = [] result = []
for c in context: for c in context:

View File

@@ -1,5 +1,5 @@
import logging import logging
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple, Union
from chromadb import Collection, QueryResult from chromadb import Collection, QueryResult
from langchain.docstore.document import Document from langchain.docstore.document import Document
@@ -192,8 +192,13 @@ class ChromaDB(BaseVectorDB):
] ]
def query( def query(
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool self,
) -> List[Tuple[str, str, str]]: input_query: List[str],
n_results: int,
where: Dict[str, any],
skip_embedding: bool,
citations: bool = False,
) -> Union[List[Tuple[str, str, str]], List[str]]:
""" """
Query contents from vector database based on vector similarity Query contents from vector database based on vector similarity
@@ -205,9 +210,12 @@ class ChromaDB(BaseVectorDB):
:type where: Dict[str, Any] :type where: Dict[str, Any]
:param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded. :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
:type skip_embedding: bool :type skip_embedding: bool
:param citations: we use citations boolean param to return context along with the answer.
:type citations: bool, default is False.
:raises InvalidDimensionException: Dimensions do not match. :raises InvalidDimensionException: Dimensions do not match.
:return: The content of the document that matched your query, url of the source, doc_id :return: The content of the document that matched your query,
:rtype: List[Tuple[str,str,str]] along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
""" """
try: try:
if skip_embedding: if skip_embedding:
@@ -236,10 +244,13 @@ class ChromaDB(BaseVectorDB):
contexts = [] contexts = []
for result in results_formatted: for result in results_formatted:
context = result[0].page_content context = result[0].page_content
metadata = result[0].metadata if citations:
source = metadata["url"] metadata = result[0].metadata
doc_id = metadata["doc_id"] source = metadata["url"]
contexts.append((context, source, doc_id)) doc_id = metadata["doc_id"]
contexts.append((context, source, doc_id))
else:
contexts.append(context)
return contexts return contexts
def set_collection_name(self, name: str): def set_collection_name(self, name: str):

View File

@@ -1,5 +1,5 @@
import logging import logging
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple, Union
try: try:
from elasticsearch import Elasticsearch from elasticsearch import Elasticsearch
@@ -136,8 +136,13 @@ class ElasticsearchDB(BaseVectorDB):
self.client.indices.refresh(index=self._get_index()) self.client.indices.refresh(index=self._get_index())
def query( def query(
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool self,
) -> List[Tuple[str, str, str]]: input_query: List[str],
n_results: int,
where: Dict[str, any],
skip_embedding: bool,
citations: bool = False,
) -> Union[List[Tuple[str, str, str]], List[str]]:
""" """
query contents from vector data base based on vector similarity query contents from vector data base based on vector similarity
@@ -150,8 +155,11 @@ class ElasticsearchDB(BaseVectorDB):
:param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded. :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
:type skip_embedding: bool :type skip_embedding: bool
:return: The context of the document that matched your query, url of the source, doc_id :return: The context of the document that matched your query, url of the source, doc_id
:param citations: we use citations boolean param to return context along with the answer.
:rtype: List[Tuple[str,str,str]] :type citations: bool, default is False.
:return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
""" """
if skip_embedding: if skip_embedding:
query_vector = input_query query_vector = input_query
@@ -175,14 +183,17 @@ class ElasticsearchDB(BaseVectorDB):
_source = ["text", "metadata.url", "metadata.doc_id"] _source = ["text", "metadata.url", "metadata.doc_id"]
response = self.client.search(index=self._get_index(), query=query, _source=_source, size=n_results) response = self.client.search(index=self._get_index(), query=query, _source=_source, size=n_results)
docs = response["hits"]["hits"] docs = response["hits"]["hits"]
contents = [] contexts = []
for doc in docs: for doc in docs:
context = doc["_source"]["text"] context = doc["_source"]["text"]
metadata = doc["_source"]["metadata"] if citations:
source = metadata["url"] metadata = doc["_source"]["metadata"]
doc_id = metadata["doc_id"] source = metadata["url"]
contents.append(tuple((context, source, doc_id))) doc_id = metadata["doc_id"]
return contents contexts.append(tuple((context, source, doc_id)))
else:
contexts.append(context)
return contexts
def set_collection_name(self, name: str): def set_collection_name(self, name: str):
""" """

View File

@@ -1,5 +1,5 @@
import logging import logging
from typing import Dict, List, Optional, Set, Tuple from typing import Dict, List, Optional, Set, Tuple, Union
try: try:
from opensearchpy import OpenSearch from opensearchpy import OpenSearch
@@ -146,8 +146,13 @@ class OpenSearchDB(BaseVectorDB):
self.client.indices.refresh(index=self._get_index()) self.client.indices.refresh(index=self._get_index())
def query( def query(
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool self,
) -> List[Tuple[str, str, str]]: input_query: List[str],
n_results: int,
where: Dict[str, any],
skip_embedding: bool,
citations: bool = False,
) -> Union[List[Tuple[str, str, str]], List[str]]:
""" """
query contents from vector data base based on vector similarity query contents from vector data base based on vector similarity
@@ -159,8 +164,11 @@ class OpenSearchDB(BaseVectorDB):
:type where: Dict[str, any] :type where: Dict[str, any]
:param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded. :param skip_embedding: Optional. If True, then the input_query is assumed to be already embedded.
:type skip_embedding: bool :type skip_embedding: bool
:return: The content of the document that matched your query, url of the source, doc_id :param citations: we use citations boolean param to return context along with the answer.
:rtype: List[Tuple[str,str,str]] :type citations: bool, default is False.
:return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
""" """
# TODO(rupeshbansal, deshraj): Add support for skip embeddings here if already exists # TODO(rupeshbansal, deshraj): Add support for skip embeddings here if already exists
embeddings = OpenAIEmbeddings() embeddings = OpenAIEmbeddings()
@@ -188,13 +196,16 @@ class OpenSearchDB(BaseVectorDB):
k=n_results, k=n_results,
) )
contents = [] contexts = []
for doc in docs: for doc in docs:
context = doc.page_content context = doc.page_content
source = doc.metadata["url"] if citations:
doc_id = doc.metadata["doc_id"] source = doc.metadata["url"]
contents.append(tuple((context, source, doc_id))) doc_id = doc.metadata["doc_id"]
return contents contexts.append(tuple((context, source, doc_id)))
else:
contexts.append(context)
return contexts
def set_collection_name(self, name: str): def set_collection_name(self, name: str):
""" """

View File

@@ -1,5 +1,5 @@
import os import os
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple, Union
try: try:
import pinecone import pinecone
@@ -119,8 +119,13 @@ class PineconeDB(BaseVectorDB):
self.client.upsert(docs[i : i + self.BATCH_SIZE]) self.client.upsert(docs[i : i + self.BATCH_SIZE])
def query( def query(
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool self,
) -> List[Tuple[str, str, str]]: input_query: List[str],
n_results: int,
where: Dict[str, any],
skip_embedding: bool,
citations: bool = False,
) -> Union[List[Tuple[str, str, str]], List[str]]:
""" """
query contents from vector database based on vector similarity query contents from vector database based on vector similarity
:param input_query: list of query string :param input_query: list of query string
@@ -131,22 +136,28 @@ class PineconeDB(BaseVectorDB):
:type where: Dict[str, any] :type where: Dict[str, any]
:param skip_embedding: Optional. if True, input_query is already embedded :param skip_embedding: Optional. if True, input_query is already embedded
:type skip_embedding: bool :type skip_embedding: bool
:return: The content of the document that matched your query, url of the source, doc_id :param citations: we use citations boolean param to return context along with the answer.
:rtype: List[Tuple[str,str,str]] :type citations: bool, default is False.
:return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
""" """
if not skip_embedding: if not skip_embedding:
query_vector = self.embedder.embedding_fn([input_query])[0] query_vector = self.embedder.embedding_fn([input_query])[0]
else: else:
query_vector = input_query query_vector = input_query
data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True) data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True)
contents = [] contexts = []
for doc in data["matches"]: for doc in data["matches"]:
metadata = doc["metadata"] metadata = doc["metadata"]
context = metadata["text"] context = metadata["text"]
source = metadata["url"] if citations:
doc_id = metadata["doc_id"] source = metadata["url"]
contents.append(tuple((context, source, doc_id))) doc_id = metadata["doc_id"]
return contents contexts.append(tuple((context, source, doc_id)))
else:
contexts.append(context)
return contexts
def set_collection_name(self, name: str): def set_collection_name(self, name: str):
""" """

View File

@@ -1,7 +1,7 @@
import copy import copy
import os import os
import uuid import uuid
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple, Union
try: try:
from qdrant_client import QdrantClient from qdrant_client import QdrantClient
@@ -161,8 +161,13 @@ class QdrantDB(BaseVectorDB):
) )
def query( def query(
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool self,
) -> List[Tuple[str, str, str]]: input_query: List[str],
n_results: int,
where: Dict[str, any],
skip_embedding: bool,
citations: bool = False,
) -> Union[List[Tuple[str, str, str]], List[str]]:
""" """
query contents from vector database based on vector similarity query contents from vector database based on vector similarity
:param input_query: list of query string :param input_query: list of query string
@@ -174,8 +179,11 @@ class QdrantDB(BaseVectorDB):
:param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
generated or not generated or not
:type skip_embedding: bool :type skip_embedding: bool
:return: The context of the document that matched your query, url of the source, doc_id :param citations: we use citations boolean param to return context along with the answer.
:rtype: List[Tuple[str,str,str]] :type citations: bool, default is False.
:return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
""" """
if not skip_embedding: if not skip_embedding:
query_vector = self.embedder.embedding_fn([input_query])[0] query_vector = self.embedder.embedding_fn([input_query])[0]
@@ -202,14 +210,17 @@ class QdrantDB(BaseVectorDB):
limit=n_results, limit=n_results,
) )
response = [] contexts = []
for result in results: for result in results:
context = result.payload["text"] context = result.payload["text"]
metadata = result.payload["metadata"] if citations:
source = metadata["url"] metadata = result.payload["metadata"]
doc_id = metadata["doc_id"] source = metadata["url"]
response.append(tuple((context, source, doc_id))) doc_id = metadata["doc_id"]
return response contexts.append(tuple((context, source, doc_id)))
else:
contexts.append(context)
return contexts
def count(self) -> int: def count(self) -> int:
response = self.client.get_collection(collection_name=self.collection_name) response = self.client.get_collection(collection_name=self.collection_name)

View File

@@ -1,6 +1,6 @@
import copy import copy
import os import os
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple, Union
try: try:
import weaviate import weaviate
@@ -58,10 +58,14 @@ class WeaviateDB(BaseVectorDB):
raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.") raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
self.index_name = self._get_index_name() self.index_name = self._get_index_name()
self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id", "text"} self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id"}
if not self.client.schema.exists(self.index_name): if not self.client.schema.exists(self.index_name):
# id is a reserved field in Weaviate, hence we had to change the name of the id field to identifier # id is a reserved field in Weaviate, hence we had to change the name of the id field to identifier
# The none vectorizer is crucial as we have our own custom embedding function # The none vectorizer is crucial as we have our own custom embedding function
"""
TODO: wait for weaviate to add indexing on `object[]` data-type so that we can add filter while querying.
Once that is done, change `dataType` of "metadata" field to `object[]` and update the query below.
"""
class_obj = { class_obj = {
"classes": [ "classes": [
{ {
@@ -106,10 +110,6 @@ class WeaviateDB(BaseVectorDB):
"name": "app_id", "name": "app_id",
"dataType": ["text"], "dataType": ["text"],
}, },
{
"name": "text",
"dataType": ["text"],
},
], ],
}, },
] ]
@@ -195,8 +195,13 @@ class WeaviateDB(BaseVectorDB):
batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata") batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata")
def query( def query(
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool self,
) -> List[Tuple[str, str, str]]: input_query: List[str],
n_results: int,
where: Dict[str, any],
skip_embedding: bool,
citations: bool = False,
) -> Union[List[Tuple[str, str, str]], List[str]]:
""" """
query contents from vector database based on vector similarity query contents from vector database based on vector similarity
:param input_query: list of query string :param input_query: list of query string
@@ -208,15 +213,23 @@ class WeaviateDB(BaseVectorDB):
:param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
generated or not generated or not
:type skip_embedding: bool :type skip_embedding: bool
:return: The context of the document that matched your query, url of the source, doc_id :param citations: we use citations boolean param to return context along with the answer.
:rtype: List[Tuple[str,str,str]] :type citations: bool, default is False.
:return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
""" """
if not skip_embedding: if not skip_embedding:
query_vector = self.embedder.embedding_fn([input_query])[0] query_vector = self.embedder.embedding_fn([input_query])[0]
else: else:
query_vector = input_query query_vector = input_query
keys = set(where.keys() if where is not None else set()) keys = set(where.keys() if where is not None else set())
data_fields = ["text"] data_fields = ["text"]
if citations:
data_fields.append(weaviate.LinkTo("metadata", self.index_name + "_metadata", list(self.metadata_keys)))
if len(keys.intersection(self.metadata_keys)) != 0: if len(keys.intersection(self.metadata_keys)) != 0:
weaviate_where_operands = [] weaviate_where_operands = []
for key in keys: for key in keys:
@@ -247,7 +260,18 @@ class WeaviateDB(BaseVectorDB):
.with_limit(n_results) .with_limit(n_results)
.do() .do()
) )
contexts = results["data"]["Get"].get(self.index_name)
docs = results["data"]["Get"].get(self.index_name)
contexts = []
for doc in docs:
context = doc["text"]
if citations:
metadata = doc["metadata"][0]
source = metadata["url"]
doc_id = metadata["doc_id"]
contexts.append((context, source, doc_id))
else:
contexts.append(context)
return contexts return contexts
def set_collection_name(self, name: str): def set_collection_name(self, name: str):

View File

@@ -1,5 +1,5 @@
import logging import logging
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple, Union
from embedchain.config import ZillizDBConfig from embedchain.config import ZillizDBConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helper.json_serializable import register_deserializable
@@ -127,8 +127,13 @@ class ZillizVectorDB(BaseVectorDB):
self.client.flush(self.config.collection_name) self.client.flush(self.config.collection_name)
def query( def query(
self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool self,
) -> List[Tuple[str, str, str]]: input_query: List[str],
n_results: int,
where: Dict[str, any],
skip_embedding: bool,
citations: bool = False,
) -> Union[List[Tuple[str, str, str]], List[str]]:
""" """
Query contents from vector data base based on vector similarity Query contents from vector data base based on vector similarity
@@ -139,8 +144,11 @@ class ZillizVectorDB(BaseVectorDB):
:param where: to filter data :param where: to filter data
:type where: str :type where: str
:raises InvalidDimensionException: Dimensions do not match. :raises InvalidDimensionException: Dimensions do not match.
:return: The context of the document that matched your query, url of the source, doc_id :param citations: we use citations boolean param to return context along with the answer.
:rtype: List[Tuple[str,str,str]] :type citations: bool, default is False.
:return: The content of the document that matched your query,
along with url of the source and doc_id (if citations flag is true)
:rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
""" """
if self.collection.is_empty: if self.collection.is_empty:
@@ -170,14 +178,17 @@ class ZillizVectorDB(BaseVectorDB):
output_fields=output_fields, output_fields=output_fields,
) )
doc_list = [] contexts = []
for query in query_result: for query in query_result:
data = query[0]["entity"] data = query[0]["entity"]
context = data["text"] context = data["text"]
source = data["url"] if citations:
doc_id = data["doc_id"] source = data["url"]
doc_list.append(tuple((context, source, doc_id))) doc_id = data["doc_id"]
return doc_list contexts.append(tuple((context, source, doc_id)))
else:
contexts.append(context)
return contexts
def count(self) -> int: def count(self) -> int:
""" """

View File

@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "embedchain" name = "embedchain"
version = "0.0.88" version = "0.0.89"
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data" description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
authors = [ authors = [
"Taranjeet Singh <taranjeet@embedchain.ai>", "Taranjeet Singh <taranjeet@embedchain.ai>",

View File

@@ -163,10 +163,12 @@ def test_chroma_db_collection_add_with_skip_embedding(app_with_settings):
assert data == expected_value assert data == expected_value
data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True) data_without_citations = app_with_settings.db.query(
expected_value = [("document", "url_1", "doc_id_1")] input_query=[0, 0, 0], where={}, n_results=1, skip_embedding=True
)
expected_value_without_citations = ["document"]
assert data_without_citations == expected_value_without_citations
assert data == expected_value
app_with_settings.db.reset() app_with_settings.db.reset()
@@ -326,8 +328,16 @@ def test_chroma_db_collection_query(app_with_settings):
assert app_with_settings.db.count() == 2 assert app_with_settings.db.count() == 2
data = app_with_settings.db.query(input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True) data_without_citations = app_with_settings.db.query(
expected_value = [("document", "url_1", "doc_id_1"), ("document2", "url_2", "doc_id_2")] input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True
)
expected_value_without_citations = ["document", "document2"]
assert data_without_citations == expected_value_without_citations
data_with_citations = app_with_settings.db.query(
input_query=[0, 0, 0], where={}, n_results=2, skip_embedding=True, citations=True
)
expected_value_with_citations = [("document", "url_1", "doc_id_1"), ("document2", "url_2", "doc_id_2")]
assert data_with_citations == expected_value_with_citations
assert data == expected_value
app_with_settings.db.reset() app_with_settings.db.reset()

View File

@@ -60,12 +60,16 @@ class TestEsDB(unittest.TestCase):
# Query the database for the documents that are most similar to the query "This is a document". # Query the database for the documents that are most similar to the query "This is a document".
query = ["This is a document"] query = ["This is a document"]
results = self.db.query(query, n_results=2, where={}, skip_embedding=False) results_without_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False)
expected_results_without_citations = ["This is a document.", "This is another document."]
self.assertEqual(results_without_citations, expected_results_without_citations)
# Assert that the results are correct. results_with_citations = self.db.query(query, n_results=2, where={}, skip_embedding=False, citations=True)
self.assertEqual( expected_results_with_citations = [
results, [("This is a document.", "url_1", "doc_id_1"), ("This is another document.", "url_2", "doc_id_2")] ("This is a document.", "url_1", "doc_id_1"),
) ("This is another document.", "url_2", "doc_id_2"),
]
self.assertEqual(results_with_citations, expected_results_with_citations)
@patch("embedchain.vectordb.elasticsearch.Elasticsearch") @patch("embedchain.vectordb.elasticsearch.Elasticsearch")
def test_query_with_skip_embedding(self, mock_client): def test_query_with_skip_embedding(self, mock_client):
@@ -111,9 +115,7 @@ class TestEsDB(unittest.TestCase):
results = self.db.query(query, n_results=2, where={}, skip_embedding=True) results = self.db.query(query, n_results=2, where={}, skip_embedding=True)
# Assert that the results are correct. # Assert that the results are correct.
self.assertEqual( self.assertEqual(results, ["This is a document.", "This is another document."])
results, [("This is a document.", "url_1", "doc_id_1"), ("This is another document.", "url_2", "doc_id_2")]
)
def test_init_without_url(self): def test_init_without_url(self):
# Make sure it's not loaded from env # Make sure it's not loaded from env

View File

@@ -75,10 +75,6 @@ class TestWeaviateDb(unittest.TestCase):
"name": "app_id", "name": "app_id",
"dataType": ["text"], "dataType": ["text"],
}, },
{
"name": "text",
"dataType": ["text"],
},
], ],
}, },
] ]

View File

@@ -129,7 +129,7 @@ class TestZillizDBCollection:
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True) query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=True)
# Assert that MilvusClient.search was called with the correct parameters # Assert that MilvusClient.search was called with the correct parameters
mock_search.assert_called_once_with( mock_search.assert_called_with(
collection_name=mock_config.collection_name, collection_name=mock_config.collection_name,
data=["query_text"], data=["query_text"],
limit=1, limit=1,
@@ -137,7 +137,20 @@ class TestZillizDBCollection:
) )
# Assert that the query result matches the expected result # Assert that the query result matches the expected result
assert query_result == [("result_doc", "url_1", "doc_id_1")] assert query_result == ["result_doc"]
query_result_with_citations = zilliz_db.query(
input_query=["query_text"], n_results=1, where={}, skip_embedding=True, citations=True
)
mock_search.assert_called_with(
collection_name=mock_config.collection_name,
data=["query_text"],
limit=1,
output_fields=["text", "url", "doc_id"],
)
assert query_result_with_citations == [("result_doc", "url_1", "doc_id_1")]
@patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True) @patch("embedchain.vectordb.zilliz.MilvusClient", autospec=True)
@patch("embedchain.vectordb.zilliz.connections", autospec=True) @patch("embedchain.vectordb.zilliz.connections", autospec=True)
@@ -168,7 +181,7 @@ class TestZillizDBCollection:
query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False) query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={}, skip_embedding=False)
# Assert that MilvusClient.search was called with the correct parameters # Assert that MilvusClient.search was called with the correct parameters
mock_search.assert_called_once_with( mock_search.assert_called_with(
collection_name=mock_config.collection_name, collection_name=mock_config.collection_name,
data=["query_vector"], data=["query_vector"],
limit=1, limit=1,
@@ -176,4 +189,17 @@ class TestZillizDBCollection:
) )
# Assert that the query result matches the expected result # Assert that the query result matches the expected result
assert query_result == [("result_doc", "url_1", "doc_id_1")] assert query_result == ["result_doc"]
query_result_with_citations = zilliz_db.query(
input_query=["query_text"], n_results=1, where={}, skip_embedding=False, citations=True
)
mock_search.assert_called_with(
collection_name=mock_config.collection_name,
data=["query_vector"],
limit=1,
output_fields=["text", "url", "doc_id"],
)
assert query_result_with_citations == [("result_doc", "url_1", "doc_id_1")]