[Refactor] Improve logging package wide (#1315)
This commit is contained in:
@@ -6,6 +6,8 @@ from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class AnthropicLlm(BaseLlm):
|
||||
@@ -26,7 +28,7 @@ class AnthropicLlm(BaseLlm):
|
||||
)
|
||||
|
||||
if config.max_tokens and config.max_tokens != 1000:
|
||||
logging.warning("Config option `max_tokens` is not supported by this model.")
|
||||
logger.warning("Config option `max_tokens` is not supported by this model.")
|
||||
|
||||
messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
|
||||
|
||||
|
||||
@@ -38,7 +38,8 @@ class AWSBedrockLlm(BaseLlm):
|
||||
}
|
||||
|
||||
if config.stream:
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
from langchain.callbacks.streaming_stdout import \
|
||||
StreamingStdOutCallbackHandler
|
||||
|
||||
callbacks = [StreamingStdOutCallbackHandler()]
|
||||
llm = Bedrock(**kwargs, streaming=config.stream, callbacks=callbacks)
|
||||
|
||||
@@ -5,6 +5,8 @@ from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class AzureOpenAILlm(BaseLlm):
|
||||
@@ -31,7 +33,7 @@ class AzureOpenAILlm(BaseLlm):
|
||||
)
|
||||
|
||||
if config.top_p and config.top_p != 1:
|
||||
logging.warning("Config option `top_p` is not supported by this model.")
|
||||
logger.warning("Config option `top_p` is not supported by this model.")
|
||||
|
||||
messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@ from embedchain.helpers.json_serializable import JSONSerializable
|
||||
from embedchain.memory.base import ChatHistory
|
||||
from embedchain.memory.message import ChatMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseLlm(JSONSerializable):
|
||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||
@@ -108,7 +110,7 @@ class BaseLlm(JSONSerializable):
|
||||
)
|
||||
else:
|
||||
# If we can't swap in the default, we still proceed but tell users that the history is ignored.
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
"Your bot contains a history, but prompt does not include `$history` key. History is ignored."
|
||||
)
|
||||
prompt = self.config.prompt.substitute(context=context_string, query=input_query)
|
||||
@@ -159,7 +161,7 @@ class BaseLlm(JSONSerializable):
|
||||
'Searching requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
|
||||
) from None
|
||||
search = DuckDuckGoSearchRun()
|
||||
logging.info(f"Access search to get answers for {input_query}")
|
||||
logger.info(f"Access search to get answers for {input_query}")
|
||||
return search.run(input_query)
|
||||
|
||||
@staticmethod
|
||||
@@ -175,7 +177,7 @@ class BaseLlm(JSONSerializable):
|
||||
for chunk in answer:
|
||||
streamed_answer = streamed_answer + chunk
|
||||
yield chunk
|
||||
logging.info(f"Answer: {streamed_answer}")
|
||||
logger.info(f"Answer: {streamed_answer}")
|
||||
|
||||
def query(self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False):
|
||||
"""
|
||||
@@ -214,13 +216,13 @@ class BaseLlm(JSONSerializable):
|
||||
if self.online:
|
||||
k["web_search_result"] = self.access_search_and_get_results(input_query)
|
||||
prompt = self.generate_prompt(input_query, contexts, **k)
|
||||
logging.info(f"Prompt: {prompt}")
|
||||
logger.info(f"Prompt: {prompt}")
|
||||
if dry_run:
|
||||
return prompt
|
||||
|
||||
answer = self.get_answer_from_llm(prompt)
|
||||
if isinstance(answer, str):
|
||||
logging.info(f"Answer: {answer}")
|
||||
logger.info(f"Answer: {answer}")
|
||||
return answer
|
||||
else:
|
||||
return self._stream_response(answer)
|
||||
@@ -270,14 +272,14 @@ class BaseLlm(JSONSerializable):
|
||||
k["web_search_result"] = self.access_search_and_get_results(input_query)
|
||||
|
||||
prompt = self.generate_prompt(input_query, contexts, **k)
|
||||
logging.info(f"Prompt: {prompt}")
|
||||
logger.info(f"Prompt: {prompt}")
|
||||
|
||||
if dry_run:
|
||||
return prompt
|
||||
|
||||
answer = self.get_answer_from_llm(prompt)
|
||||
if isinstance(answer, str):
|
||||
logging.info(f"Answer: {answer}")
|
||||
logger.info(f"Answer: {answer}")
|
||||
return answer
|
||||
else:
|
||||
# this is a streamed response and needs to be handled differently.
|
||||
|
||||
@@ -10,6 +10,8 @@ from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class GoogleLlm(BaseLlm):
|
||||
@@ -36,7 +38,7 @@ class GoogleLlm(BaseLlm):
|
||||
|
||||
def _get_answer(self, prompt: str) -> Union[str, Generator[Any, Any, None]]:
|
||||
model_name = self.config.model or "gemini-pro"
|
||||
logging.info(f"Using Google LLM model: {model_name}")
|
||||
logger.info(f"Using Google LLM model: {model_name}")
|
||||
model = genai.GenerativeModel(model_name=model_name)
|
||||
|
||||
generation_config_params = {
|
||||
|
||||
@@ -11,6 +11,8 @@ from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class HuggingFaceLlm(BaseLlm):
|
||||
@@ -58,7 +60,7 @@ class HuggingFaceLlm(BaseLlm):
|
||||
raise ValueError("`top_p` must be > 0.0 and < 1.0")
|
||||
|
||||
model = config.model
|
||||
logging.info(f"Using HuggingFaceHub with model {model}")
|
||||
logger.info(f"Using HuggingFaceHub with model {model}")
|
||||
llm = HuggingFaceHub(
|
||||
huggingfacehub_api_token=os.environ["HUGGINGFACE_ACCESS_TOKEN"],
|
||||
repo_id=model,
|
||||
|
||||
@@ -65,7 +65,8 @@ class OpenAILlm(BaseLlm):
|
||||
messages: list[BaseMessage],
|
||||
) -> str:
|
||||
from langchain.output_parsers.openai_tools import JsonOutputToolsParser
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from langchain_core.utils.function_calling import \
|
||||
convert_to_openai_tool
|
||||
|
||||
openai_tools = [convert_to_openai_tool(tools)]
|
||||
chat = chat.bind(tools=openai_tools).pipe(JsonOutputToolsParser())
|
||||
|
||||
@@ -9,6 +9,8 @@ from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class VertexAILlm(BaseLlm):
|
||||
@@ -28,7 +30,7 @@ class VertexAILlm(BaseLlm):
|
||||
@staticmethod
|
||||
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
|
||||
if config.top_p and config.top_p != 1:
|
||||
logging.warning("Config option `top_p` is not supported by this model.")
|
||||
logger.warning("Config option `top_p` is not supported by this model.")
|
||||
|
||||
messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user