Show details for query tokens (#1392)

Dev Khant
2024-07-05 00:10:56 +05:30
committed by GitHub
parent ea09b5f7f0
commit 4880557d51
25 changed files with 1825 additions and 517 deletions

embedchain/embedchain.py

@@ -6,9 +6,7 @@ from typing import Any, Optional, Union
 from dotenv import load_dotenv
 from langchain.docstore.document import Document
-from embedchain.cache import (adapt, get_gptcache_session,
-                              gptcache_data_convert,
-                              gptcache_update_cache_callback)
+from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.base_app_config import BaseAppConfig
@@ -18,8 +16,7 @@ from embedchain.embedder.base import BaseEmbedder
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.llm.base import BaseLlm
 from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import (DataType, DirectDataType,
-                                         IndirectDataType, SpecialDataType)
+from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
 from embedchain.utils.misc import detect_datatype, is_valid_json_string
 from embedchain.vectordb.base import BaseVectorDB
@@ -478,7 +475,7 @@ class EmbedChain(JSONSerializable):
         where: Optional[dict] = None,
         citations: bool = False,
         **kwargs: dict[str, Any],
-    ) -> Union[tuple[str, list[tuple[str, dict]]], str]:
+    ) -> Union[tuple[str, list[tuple[str, dict]]], str, dict[str, Any]]:
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an
@@ -501,7 +498,9 @@
         :type kwargs: dict[str, Any]
         :return: The answer to the query, with citations if the citation flag is True
             or the dry run result
-        :rtype: str, if citations is False, otherwise tuple[str, list[tuple[str,str,str]]]
+        :rtype: str if citations and token_usage are both False; tuple[str, list[tuple[str, dict]]]
+            if citations is True and token_usage is False; dict[str, Any] with keys "answer" and
+            "usage" (plus "contexts" when citations is True) if token_usage is True
         """
         contexts = self._retrieve_from_database(
             input_query=input_query, config=config, where=where, citations=citations, **kwargs
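
In plain terms, the two flags now select among four return shapes, as the hunks below implement:

citations=False, token_usage=False -> answer (str)
citations=True,  token_usage=False -> (answer, contexts)
citations=True,  token_usage=True  -> {"answer": ..., "contexts": ..., "usage": token_info}
citations=False, token_usage=True  -> {"answer": ..., "usage": token_info}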
@@ -524,17 +523,29 @@
                 dry_run=dry_run,
             )
         else:
-            answer = self.llm.query(
-                input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
-            )
+            if self.llm.config.token_usage:
+                answer, token_info = self.llm.query(
+                    input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+                )
+            else:
+                answer = self.llm.query(
+                    input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+                )

         # Send anonymous telemetry
         self.telemetry.capture(event_name="query", properties=self._telemetry_props)

         if citations:
+            if self.llm.config.token_usage:
+                return {"answer": answer, "contexts": contexts, "usage": token_info}
             return answer, contexts
-        else:
-            return answer
+        if self.llm.config.token_usage:
+            return {"answer": answer, "usage": token_info}
+        logger.warning(
+            "Starting from v0.1.125 the return type of query method will be changed to tuple containing `answer`."
+        )
+        return answer

     def chat(
         self,
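
As a usage sketch (not part of this commit): the hunk above reads the flag from self.llm.config.token_usage, so it is presumably set through the LLM config; the exact config path and the model name below are illustrative assumptions.

from embedchain import App

# Hypothetical config; where the token_usage key lives inside the
# llm config block is assumed, not shown in this diff.
app = App.from_config(
    config={
        "llm": {
            "provider": "openai",
            "config": {"model": "gpt-3.5-turbo", "token_usage": True},
        }
    }
)

result = app.query("What is Embedchain?")           # citations=False
answer, usage = result["answer"], result["usage"]   # dict, per the hunk above

result = app.query("What is Embedchain?", citations=True)
answer, contexts, usage = result["answer"], result["contexts"], result["usage"]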
@@ -545,7 +556,7 @@
         where: Optional[dict[str, str]] = None,
         citations: bool = False,
         **kwargs: dict[str, Any],
-    ) -> Union[tuple[str, list[tuple[str, dict]]], str]:
+    ) -> Union[tuple[str, list[tuple[str, dict]]], str, dict[str, Any]]:
         """
         Queries the vector database on the given input query.
         Gets relevant doc based on the query and then passes it to an
@@ -572,7 +583,9 @@
         :type kwargs: dict[str, Any]
         :return: The answer to the query, with citations if the citation flag is True
             or the dry run result
-        :rtype: str, if citations is False, otherwise tuple[str, list[tuple[str,str,str]]]
+        :rtype: str if citations and token_usage are both False; tuple[str, list[tuple[str, dict]]]
+            if citations is True and token_usage is False; dict[str, Any] with keys "answer" and
+            "usage" (plus "contexts" when citations is True) if token_usage is True
         """
         contexts = self._retrieve_from_database(
             input_query=input_query, config=config, where=where, citations=citations, **kwargs
@@ -600,9 +613,14 @@
             )
         else:
             logger.debug("Cache disabled. Running chat without cache.")
-            answer = self.llm.chat(
-                input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
-            )
+            if self.llm.config.token_usage:
+                answer, token_info = self.llm.chat(
+                    input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+                )
+            else:
+                answer = self.llm.chat(
+                    input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
+                )

         # add conversation in memory
         self.llm.add_history(self.config.id, input_query, answer, session_id=session_id)
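
The contents of token_info are not visible in these hunks. A plausible shape, assuming the LLM layer relays OpenAI-style counters (every key name here is an assumption, not taken from this diff):

# Assumed illustration only; the actual keys depend on the provider
# integration that builds token_info inside llm.query()/llm.chat().
token_info = {
    "prompt_tokens": 412,      # tokens in the rendered prompt
    "completion_tokens": 57,   # tokens in the generated answer
    "total_tokens": 469,       # sum of the two
}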
@@ -611,9 +629,16 @@
         self.telemetry.capture(event_name="chat", properties=self._telemetry_props)

         if citations:
+            if self.llm.config.token_usage:
+                return {"answer": answer, "contexts": contexts, "usage": token_info}
             return answer, contexts
-        else:
-            return answer
+        if self.llm.config.token_usage:
+            return {"answer": answer, "usage": token_info}
+        logger.warning(
+            "Starting from v0.1.125 the return type of chat method will be changed to tuple containing `answer`."
+        )
+        return answer

     def search(self, query, num_documents=3, where=None, raw_filter=None, namespace=None):
         """