From 0f9a10c598ebeb0e6988175ec3a6996a58ee9ce1 Mon Sep 17 00:00:00 2001 From: cachho Date: Tue, 12 Sep 2023 18:03:58 +0200 Subject: [PATCH] fix: use template from tempory `LlmConfig` (#590) --- embedchain/llm/base.py | 112 ++++++++++++++++++++++++----------------- 1 file changed, 66 insertions(+), 46 deletions(-) diff --git a/embedchain/llm/base.py b/embedchain/llm/base.py index 486c66e4..cc0179ee 100644 --- a/embedchain/llm/base.py +++ b/embedchain/llm/base.py @@ -174,27 +174,37 @@ class BaseLlm(JSONSerializable): :return: The answer to the query or the dry run result :rtype: str """ - query_config = config or self.config + try: + if config: + # A config instance passed to this method will only be applied temporarily, for one call. + # So we will save the previous config and restore it at the end of the execution. + # For this we use the serializer. + prev_config = self.config.serialize() + self.config = config - if self.is_docs_site_instance: - query_config.template = DOCS_SITE_PROMPT_TEMPLATE - query_config.number_documents = 5 - k = {} - if self.online: - k["web_search_result"] = self.access_search_and_get_results(input_query) - prompt = self.generate_prompt(input_query, contexts, **k) - logging.info(f"Prompt: {prompt}") + if self.is_docs_site_instance: + self.config.template = DOCS_SITE_PROMPT_TEMPLATE + self.config.number_documents = 5 + k = {} + if self.online: + k["web_search_result"] = self.access_search_and_get_results(input_query) + prompt = self.generate_prompt(input_query, contexts, **k) + logging.info(f"Prompt: {prompt}") - if dry_run: - return prompt + if dry_run: + return prompt - answer = self.get_answer_from_llm(prompt) + answer = self.get_answer_from_llm(prompt) - if isinstance(answer, str): - logging.info(f"Answer: {answer}") - return answer - else: - return self._stream_query_response(answer) + if isinstance(answer, str): + logging.info(f"Answer: {answer}") + return answer + else: + return self._stream_query_response(answer) + finally: + if config: + # Restore previous config + self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config) def chat(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False): """ @@ -217,39 +227,49 @@ class BaseLlm(JSONSerializable): :return: The answer to the query or the dry run result :rtype: str """ - query_config = config or self.config + try: + if config: + # A config instance passed to this method will only be applied temporarily, for one call. + # So we will save the previous config and restore it at the end of the execution. + # For this we use the serializer. + prev_config = self.config.serialize() + self.config = config - if self.is_docs_site_instance: - query_config.template = DOCS_SITE_PROMPT_TEMPLATE - query_config.number_documents = 5 - k = {} - if self.online: - k["web_search_result"] = self.access_search_and_get_results(input_query) + if self.is_docs_site_instance: + self.config.template = DOCS_SITE_PROMPT_TEMPLATE + self.config.number_documents = 5 + k = {} + if self.online: + k["web_search_result"] = self.access_search_and_get_results(input_query) - self.update_history() - - prompt = self.generate_prompt(input_query, contexts, **k) - logging.info(f"Prompt: {prompt}") - - if dry_run: - return prompt - - answer = self.get_answer_from_llm(prompt) - - self.memory.chat_memory.add_user_message(input_query) - - if isinstance(answer, str): - self.memory.chat_memory.add_ai_message(answer) - logging.info(f"Answer: {answer}") - - # NOTE: Adding to history before and after. This could be seen as redundant. - # If we change it, we have to change the tests (no big deal). self.update_history() - return answer - else: - # this is a streamed response and needs to be handled differently. - return self._stream_chat_response(answer) + prompt = self.generate_prompt(input_query, contexts, **k) + logging.info(f"Prompt: {prompt}") + + if dry_run: + return prompt + + answer = self.get_answer_from_llm(prompt) + + self.memory.chat_memory.add_user_message(input_query) + + if isinstance(answer, str): + self.memory.chat_memory.add_ai_message(answer) + logging.info(f"Answer: {answer}") + + # NOTE: Adding to history before and after. This could be seen as redundant. + # If we change it, we have to change the tests (no big deal). + self.update_history() + + return answer + else: + # this is a streamed response and needs to be handled differently. + return self._stream_chat_response(answer) + finally: + if config: + # Restore previous config + self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config) @staticmethod def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]: