fix: use template from tempory LlmConfig (#590)

This commit is contained in:
cachho
2023-09-12 18:03:58 +02:00
committed by GitHub
parent 2bd6881361
commit 0f9a10c598

View File

@@ -174,27 +174,37 @@ class BaseLlm(JSONSerializable):
:return: The answer to the query or the dry run result
:rtype: str
"""
query_config = config or self.config
try:
if config:
# A config instance passed to this method will only be applied temporarily, for one call.
# So we will save the previous config and restore it at the end of the execution.
# For this we use the serializer.
prev_config = self.config.serialize()
self.config = config
if self.is_docs_site_instance:
query_config.template = DOCS_SITE_PROMPT_TEMPLATE
query_config.number_documents = 5
k = {}
if self.online:
k["web_search_result"] = self.access_search_and_get_results(input_query)
prompt = self.generate_prompt(input_query, contexts, **k)
logging.info(f"Prompt: {prompt}")
if self.is_docs_site_instance:
self.config.template = DOCS_SITE_PROMPT_TEMPLATE
self.config.number_documents = 5
k = {}
if self.online:
k["web_search_result"] = self.access_search_and_get_results(input_query)
prompt = self.generate_prompt(input_query, contexts, **k)
logging.info(f"Prompt: {prompt}")
if dry_run:
return prompt
if dry_run:
return prompt
answer = self.get_answer_from_llm(prompt)
answer = self.get_answer_from_llm(prompt)
if isinstance(answer, str):
logging.info(f"Answer: {answer}")
return answer
else:
return self._stream_query_response(answer)
if isinstance(answer, str):
logging.info(f"Answer: {answer}")
return answer
else:
return self._stream_query_response(answer)
finally:
if config:
# Restore previous config
self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
def chat(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
"""
@@ -217,39 +227,49 @@ class BaseLlm(JSONSerializable):
:return: The answer to the query or the dry run result
:rtype: str
"""
query_config = config or self.config
try:
if config:
# A config instance passed to this method will only be applied temporarily, for one call.
# So we will save the previous config and restore it at the end of the execution.
# For this we use the serializer.
prev_config = self.config.serialize()
self.config = config
if self.is_docs_site_instance:
query_config.template = DOCS_SITE_PROMPT_TEMPLATE
query_config.number_documents = 5
k = {}
if self.online:
k["web_search_result"] = self.access_search_and_get_results(input_query)
if self.is_docs_site_instance:
self.config.template = DOCS_SITE_PROMPT_TEMPLATE
self.config.number_documents = 5
k = {}
if self.online:
k["web_search_result"] = self.access_search_and_get_results(input_query)
self.update_history()
prompt = self.generate_prompt(input_query, contexts, **k)
logging.info(f"Prompt: {prompt}")
if dry_run:
return prompt
answer = self.get_answer_from_llm(prompt)
self.memory.chat_memory.add_user_message(input_query)
if isinstance(answer, str):
self.memory.chat_memory.add_ai_message(answer)
logging.info(f"Answer: {answer}")
# NOTE: Adding to history before and after. This could be seen as redundant.
# If we change it, we have to change the tests (no big deal).
self.update_history()
return answer
else:
# this is a streamed response and needs to be handled differently.
return self._stream_chat_response(answer)
prompt = self.generate_prompt(input_query, contexts, **k)
logging.info(f"Prompt: {prompt}")
if dry_run:
return prompt
answer = self.get_answer_from_llm(prompt)
self.memory.chat_memory.add_user_message(input_query)
if isinstance(answer, str):
self.memory.chat_memory.add_ai_message(answer)
logging.info(f"Answer: {answer}")
# NOTE: Adding to history before and after. This could be seen as redundant.
# If we change it, we have to change the tests (no big deal).
self.update_history()
return answer
else:
# this is a streamed response and needs to be handled differently.
return self._stream_chat_response(answer)
finally:
if config:
# Restore previous config
self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
@staticmethod
def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]: