[Refactor] Improve logging package wide (#1315)

Deshraj Yadav
2024-03-13 17:13:30 -07:00
committed by GitHub
parent ef69c91b60
commit 3616eaadb4
54 changed files with 263 additions and 231 deletions
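Across these files the pattern is the same: direct calls on the root logger (logging.warning, logging.info) are replaced by a module-level logger created with logging.getLogger(__name__). A minimal sketch of the pattern; the check_config function is illustrative, not taken from the diff:

import logging

# One logger per module; its name mirrors the import path, so applications
# can raise or lower verbosity for individual subpackages.
logger = logging.getLogger(__name__)


def check_config(max_tokens: int) -> None:
    if max_tokens and max_tokens != 1000:
        # Emitted through the named logger rather than the root logger.
        logger.warning("Config option `max_tokens` is not supported by this model.")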

View File

@@ -6,6 +6,8 @@ from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
+logger = logging.getLogger(__name__)
@register_deserializable
class AnthropicLlm(BaseLlm):
@@ -26,7 +28,7 @@ class AnthropicLlm(BaseLlm):
)
if config.max_tokens and config.max_tokens != 1000:
logging.warning("Config option `max_tokens` is not supported by this model.")
logger.warning("Config option `max_tokens` is not supported by this model.")
messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
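One practical effect of the switch: module-level helpers such as logging.warning() configure the root logger implicitly (they call basicConfig() when no handler is set), while a named logger leaves configuration to the application. For example, a consumer could silence just these compatibility warnings; the module path below is an assumption, since file names are not shown in this view:

import logging

# Hypothetical application-side tuning: drop WARNING-level noise from one
# submodule without touching the rest of the package.
logging.getLogger("embedchain.llm.anthropic").setLevel(logging.ERROR)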

View File

@@ -38,7 +38,8 @@ class AWSBedrockLlm(BaseLlm):
}
if config.stream:
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain.callbacks.streaming_stdout import \
+    StreamingStdOutCallbackHandler
callbacks = [StreamingStdOutCallbackHandler()]
llm = Bedrock(**kwargs, streaming=config.stream, callbacks=callbacks)
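StreamingStdOutCallbackHandler prints tokens straight to stdout. A sketch of an alternative that routes streamed tokens through the module logger instead; it assumes the BaseCallbackHandler interface from langchain.callbacks.base and is not part of this diff:

import logging

from langchain.callbacks.base import BaseCallbackHandler

logger = logging.getLogger(__name__)


class LoggingStreamHandler(BaseCallbackHandler):
    # Hypothetical handler: forward each streamed token to the logger at
    # DEBUG level instead of writing it to stdout.
    def on_llm_new_token(self, token: str, **kwargs) -> None:
        logger.debug("token: %s", token)

Passed via callbacks=[LoggingStreamHandler()], it would slot into the same place as the stdout handler above.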

View File

@@ -5,6 +5,8 @@ from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
+logger = logging.getLogger(__name__)
@register_deserializable
class AzureOpenAILlm(BaseLlm):
@@ -31,7 +33,7 @@ class AzureOpenAILlm(BaseLlm):
)
if config.top_p and config.top_p != 1:
logging.warning("Config option `top_p` is not supported by this model.")
logger.warning("Config option `top_p` is not supported by this model.")
messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
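Because the warning now goes through a named logger, it can be asserted on in tests with pytest's caplog fixture. A self-contained sketch; the logger name and the check_config helper are illustrative stand-ins for the library code:

import logging

logger = logging.getLogger("embedchain.llm.azure_openai")  # assumed module path


def check_config(top_p: float) -> None:
    # Stand-in for the guard above.
    if top_p and top_p != 1:
        logger.warning("Config option `top_p` is not supported by this model.")


def test_top_p_warning_is_emitted(caplog):
    with caplog.at_level(logging.WARNING, logger="embedchain.llm.azure_openai"):
        check_config(0.5)
    assert "`top_p` is not supported" in caplog.text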

View File

@@ -12,6 +12,8 @@ from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.memory.base import ChatHistory
from embedchain.memory.message import ChatMessage
+logger = logging.getLogger(__name__)
class BaseLlm(JSONSerializable):
def __init__(self, config: Optional[BaseLlmConfig] = None):
@@ -108,7 +110,7 @@ class BaseLlm(JSONSerializable):
)
else:
# If we can't swap in the default, we still proceed but tell users that the history is ignored.
-logging.warning(
+logger.warning(
"Your bot contains a history, but prompt does not include `$history` key. History is ignored."
)
prompt = self.config.prompt.substitute(context=context_string, query=input_query)
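The prompt object substituted here behaves like string.Template: $context and $query are always filled, and history is only used when the template contains a $history key. A minimal sketch of that branching, not the library's exact code:

import logging
from string import Template

logger = logging.getLogger(__name__)


def build_prompt(template: Template, context: str, query: str, history: str = "") -> str:
    if "$history" in template.template:
        return template.substitute(context=context, query=query, history=history)
    if history:
        logger.warning(
            "Your bot contains a history, but prompt does not include `$history` key. History is ignored."
        )
    return template.substitute(context=context, query=query)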
@@ -159,7 +161,7 @@ class BaseLlm(JSONSerializable):
'Searching requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
) from None
search = DuckDuckGoSearchRun()
logging.info(f"Access search to get answers for {input_query}")
logger.info(f"Access search to get answers for {input_query}")
return search.run(input_query)
@staticmethod
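The web-search path wraps langchain's DuckDuckGoSearchRun tool. A standalone sketch of the same call sequence; the import path is an assumption (embedchain installs the dependency through the embedchain[dataloaders] extra mentioned above):

import logging

from langchain_community.tools import DuckDuckGoSearchRun  # assumed import path

logger = logging.getLogger(__name__)


def web_search(query: str) -> str:
    search = DuckDuckGoSearchRun()
    logger.info("Access search to get answers for %s", query)
    return search.run(query)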
@@ -175,7 +177,7 @@ class BaseLlm(JSONSerializable):
for chunk in answer:
streamed_answer = streamed_answer + chunk
yield chunk
logging.info(f"Answer: {streamed_answer}")
logger.info(f"Answer: {streamed_answer}")
def query(self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False):
"""
@@ -214,13 +216,13 @@ class BaseLlm(JSONSerializable):
if self.online:
k["web_search_result"] = self.access_search_and_get_results(input_query)
prompt = self.generate_prompt(input_query, contexts, **k)
logging.info(f"Prompt: {prompt}")
logger.info(f"Prompt: {prompt}")
if dry_run:
return prompt
answer = self.get_answer_from_llm(prompt)
if isinstance(answer, str):
logging.info(f"Answer: {answer}")
logger.info(f"Answer: {answer}")
return answer
else:
return self._stream_response(answer)
@@ -270,14 +272,14 @@ class BaseLlm(JSONSerializable):
k["web_search_result"] = self.access_search_and_get_results(input_query)
prompt = self.generate_prompt(input_query, contexts, **k)
logging.info(f"Prompt: {prompt}")
logger.info(f"Prompt: {prompt}")
if dry_run:
return prompt
answer = self.get_answer_from_llm(prompt)
if isinstance(answer, str):
logging.info(f"Answer: {answer}")
logger.info(f"Answer: {answer}")
return answer
else:
# this is a streamed response and needs to be handled differently.
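Both query paths branch on whether the LLM returned a plain string or a generator; _stream_response above yields chunks as they arrive and logs the assembled answer at the end. A compact sketch of that pattern:

import logging
from typing import Generator, Union

logger = logging.getLogger(__name__)


def stream_response(answer: Generator[str, None, None]) -> Generator[str, None, None]:
    streamed_answer = ""
    for chunk in answer:
        streamed_answer += chunk
        yield chunk  # hand each chunk to the caller as soon as it arrives
    logger.info("Answer: %s", streamed_answer)


def handle_answer(answer: Union[str, Generator[str, None, None]]):
    if isinstance(answer, str):
        logger.info("Answer: %s", answer)
        return answer
    # Streamed responses are returned lazily and only logged once consumed.
    return stream_response(answer)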

View File

@@ -10,6 +10,8 @@ from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
+logger = logging.getLogger(__name__)
@register_deserializable
class GoogleLlm(BaseLlm):
@@ -36,7 +38,7 @@ class GoogleLlm(BaseLlm):
def _get_answer(self, prompt: str) -> Union[str, Generator[Any, Any, None]]:
model_name = self.config.model or "gemini-pro"
logging.info(f"Using Google LLM model: {model_name}")
logger.info(f"Using Google LLM model: {model_name}")
model = genai.GenerativeModel(model_name=model_name)
generation_config_params = {
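For context, a standalone sketch of the google-generativeai calls used in this hunk; the API key handling and generation parameters are assumptions, only GenerativeModel and the model name come from the diff:

import logging
import os

import google.generativeai as genai

logger = logging.getLogger(__name__)

genai.configure(api_key=os.environ["GOOGLE_API_KEY"])  # assumed key handling

model_name = "gemini-pro"
logger.info("Using Google LLM model: %s", model_name)
model = genai.GenerativeModel(model_name=model_name)

response = model.generate_content(
    "Summarise what embedchain does.",
    generation_config={"max_output_tokens": 512, "temperature": 0.5},
)
print(response.text)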

View File

@@ -11,6 +11,8 @@ from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
+logger = logging.getLogger(__name__)
@register_deserializable
class HuggingFaceLlm(BaseLlm):
@@ -58,7 +60,7 @@ class HuggingFaceLlm(BaseLlm):
raise ValueError("`top_p` must be > 0.0 and < 1.0")
model = config.model
logging.info(f"Using HuggingFaceHub with model {model}")
logger.info(f"Using HuggingFaceHub with model {model}")
llm = HuggingFaceHub(
huggingfacehub_api_token=os.environ["HUGGINGFACE_ACCESS_TOKEN"],
repo_id=model,
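The guard above rejects out-of-range sampling parameters before the hub model is created. A small, self-contained sketch of that validate-then-log sequence (the helper name is illustrative):

import logging

logger = logging.getLogger(__name__)


def validate_and_announce(model: str, top_p: float) -> None:
    # `top_p` must lie strictly between 0 and 1, per the ValueError above.
    if not 0.0 < top_p < 1.0:
        raise ValueError("`top_p` must be > 0.0 and < 1.0")
    logger.info("Using HuggingFaceHub with model %s", model)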

View File

@@ -65,7 +65,8 @@ class OpenAILlm(BaseLlm):
messages: list[BaseMessage],
) -> str:
from langchain.output_parsers.openai_tools import JsonOutputToolsParser
-from langchain_core.utils.function_calling import convert_to_openai_tool
+from langchain_core.utils.function_calling import \
+    convert_to_openai_tool
openai_tools = [convert_to_openai_tool(tools)]
chat = chat.bind(tools=openai_tools).pipe(JsonOutputToolsParser())
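The tool-calling branch converts a Python callable (or pydantic model) into an OpenAI tool schema, binds it to the chat model, and parses the tool call back into JSON. A sketch of that chain; ChatOpenAI's import path and the multiply example are assumptions, the conversion and binding calls come from the hunk:

from langchain.output_parsers.openai_tools import JsonOutputToolsParser
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_openai import ChatOpenAI  # assumed import path


def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


chat = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
openai_tools = [convert_to_openai_tool(multiply)]

# Bind the tool schema to the model, then parse the returned tool call.
chain = chat.bind(tools=openai_tools).pipe(JsonOutputToolsParser())
print(chain.invoke("What is 6 times 7?"))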

View File

@@ -9,6 +9,8 @@ from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
+logger = logging.getLogger(__name__)
@register_deserializable
class VertexAILlm(BaseLlm):
@@ -28,7 +30,7 @@ class VertexAILlm(BaseLlm):
@staticmethod
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
if config.top_p and config.top_p != 1:
logging.warning("Config option `top_p` is not supported by this model.")
logger.warning("Config option `top_p` is not supported by this model.")
messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
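With every module now logging through logging.getLogger(__name__), an application embedding the library controls output in one place. A minimal sketch, using "embedchain" as the package root shown in the imports above:

import logging

# One handler and format for the whole process; then opt in to INFO-level
# messages from the library's package logger.
logging.basicConfig(format="%(asctime)s %(name)s %(levelname)s %(message)s")
logging.getLogger("embedchain").setLevel(logging.INFO)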