Migrate from template to prompt arg while keeping backward compatibility (#1066)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import re
|
||||
from string import Template
|
||||
from typing import Any, Dict, List, Optional
|
||||
@@ -59,6 +60,7 @@ class BaseLlmConfig(BaseConfig):
|
||||
self,
|
||||
number_documents: int = 3,
|
||||
template: Optional[Template] = None,
|
||||
prompt: Optional[Template] = None,
|
||||
model: Optional[str] = None,
|
||||
temperature: float = 0,
|
||||
max_tokens: int = 1000,
|
||||
@@ -80,8 +82,11 @@ class BaseLlmConfig(BaseConfig):
|
||||
context, defaults to 1
|
||||
:type number_documents: int, optional
|
||||
:param template: The `Template` instance to use as a template for
|
||||
prompt, defaults to None
|
||||
prompt, defaults to None (deprecated)
|
||||
:type template: Optional[Template], optional
|
||||
:param prompt: The `Template` instance to use as a template for
|
||||
prompt, defaults to None
|
||||
:type prompt: Optional[Template], optional
|
||||
:param model: Controls the OpenAI model used, defaults to None
|
||||
:type model: Optional[str], optional
|
||||
:param temperature: Controls the randomness of the model's output.
|
||||
@@ -106,8 +111,16 @@ class BaseLlmConfig(BaseConfig):
|
||||
contain $context and $query (and optionally $history)
|
||||
:raises ValueError: Stream is not boolean
|
||||
"""
|
||||
if template is None:
|
||||
template = DEFAULT_PROMPT_TEMPLATE
|
||||
if template is not None:
|
||||
logging.warning(
|
||||
"The `template` argument is deprecated and will be removed in a future version. "
|
||||
+ "Please use `prompt` instead."
|
||||
)
|
||||
if prompt is None:
|
||||
prompt = template
|
||||
|
||||
if prompt is None:
|
||||
prompt = DEFAULT_PROMPT_TEMPLATE
|
||||
|
||||
self.number_documents = number_documents
|
||||
self.temperature = temperature
|
||||
@@ -120,37 +133,37 @@ class BaseLlmConfig(BaseConfig):
|
||||
self.callbacks = callbacks
|
||||
self.api_key = api_key
|
||||
|
||||
if type(template) is str:
|
||||
template = Template(template)
|
||||
if type(prompt) is str:
|
||||
prompt = Template(prompt)
|
||||
|
||||
if self.validate_template(template):
|
||||
self.template = template
|
||||
if self.validate_prompt(prompt):
|
||||
self.prompt = prompt
|
||||
else:
|
||||
raise ValueError("`template` should have `query` and `context` keys and potentially `history` (if used).")
|
||||
raise ValueError("The 'prompt' should have 'query' and 'context' keys and potentially 'history' (if used).")
|
||||
|
||||
if not isinstance(stream, bool):
|
||||
raise ValueError("`stream` should be bool")
|
||||
self.stream = stream
|
||||
self.where = where
|
||||
|
||||
def validate_template(self, template: Template) -> bool:
|
||||
def validate_prompt(self, prompt: Template) -> bool:
|
||||
"""
|
||||
validate the template
|
||||
validate the prompt
|
||||
|
||||
:param template: the template to validate
|
||||
:type template: Template
|
||||
:param prompt: the prompt to validate
|
||||
:type prompt: Template
|
||||
:return: valid (true) or invalid (false)
|
||||
:rtype: bool
|
||||
"""
|
||||
return re.search(query_re, template.template) and re.search(context_re, template.template)
|
||||
return re.search(query_re, prompt.template) and re.search(context_re, prompt.template)
|
||||
|
||||
def _validate_template_history(self, template: Template) -> bool:
|
||||
def _validate_prompt_history(self, prompt: Template) -> bool:
|
||||
"""
|
||||
validate the template with history
|
||||
validate the prompt with history
|
||||
|
||||
:param template: the template to validate
|
||||
:type template: Template
|
||||
:param prompt: the prompt to validate
|
||||
:type prompt: Template
|
||||
:return: valid (true) or invalid (false)
|
||||
:rtype: bool
|
||||
"""
|
||||
return re.search(history_re, template.template)
|
||||
return re.search(history_re, prompt.template)
|
||||
|
||||
@@ -74,19 +74,19 @@ class BaseLlm(JSONSerializable):
|
||||
if web_search_result:
|
||||
context_string = self._append_search_and_context(context_string, web_search_result)
|
||||
|
||||
template_contains_history = self.config._validate_template_history(self.config.template)
|
||||
if template_contains_history:
|
||||
# Template contains history
|
||||
prompt_contains_history = self.config._validate_prompt_history(self.config.prompt)
|
||||
if prompt_contains_history:
|
||||
# Prompt contains history
|
||||
# If there is no history yet, we insert `- no history -`
|
||||
prompt = self.config.template.substitute(
|
||||
prompt = self.config.prompt.substitute(
|
||||
context=context_string, query=input_query, history=self.history or "- no history -"
|
||||
)
|
||||
elif self.history and not template_contains_history:
|
||||
# History is present, but not included in the template.
|
||||
# check if it's the default template without history
|
||||
elif self.history and not prompt_contains_history:
|
||||
# History is present, but not included in the prompt.
|
||||
# check if it's the default prompt without history
|
||||
if (
|
||||
not self.config._validate_template_history(self.config.template)
|
||||
and self.config.template.template == DEFAULT_PROMPT
|
||||
not self.config._validate_prompt_history(self.config.prompt)
|
||||
and self.config.prompt.template == DEFAULT_PROMPT
|
||||
):
|
||||
# swap in the template with history
|
||||
prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
|
||||
@@ -95,12 +95,12 @@ class BaseLlm(JSONSerializable):
|
||||
else:
|
||||
# If we can't swap in the default, we still proceed but tell users that the history is ignored.
|
||||
logging.warning(
|
||||
"Your bot contains a history, but template does not include `$history` key. History is ignored."
|
||||
"Your bot contains a history, but prompt does not include `$history` key. History is ignored."
|
||||
)
|
||||
prompt = self.config.template.substitute(context=context_string, query=input_query)
|
||||
prompt = self.config.prompt.substitute(context=context_string, query=input_query)
|
||||
else:
|
||||
# basic use case, no history.
|
||||
prompt = self.config.template.substitute(context=context_string, query=input_query)
|
||||
prompt = self.config.prompt.substitute(context=context_string, query=input_query)
|
||||
return prompt
|
||||
|
||||
def _append_search_and_context(self, context: str, web_search_result: str) -> str:
|
||||
@@ -191,7 +191,7 @@ class BaseLlm(JSONSerializable):
|
||||
return contexts
|
||||
|
||||
if self.is_docs_site_instance:
|
||||
self.config.template = DOCS_SITE_PROMPT_TEMPLATE
|
||||
self.config.prompt = DOCS_SITE_PROMPT_TEMPLATE
|
||||
self.config.number_documents = 5
|
||||
k = {}
|
||||
if self.online:
|
||||
@@ -242,7 +242,7 @@ class BaseLlm(JSONSerializable):
|
||||
self.config = config
|
||||
|
||||
if self.is_docs_site_instance:
|
||||
self.config.template = DOCS_SITE_PROMPT_TEMPLATE
|
||||
self.config.prompt = DOCS_SITE_PROMPT_TEMPLATE
|
||||
self.config.number_documents = 5
|
||||
k = {}
|
||||
if self.online:
|
||||
|
||||
@@ -396,6 +396,7 @@ def validate_config(config_data):
|
||||
Optional("top_p"): Or(float, int),
|
||||
Optional("stream"): bool,
|
||||
Optional("template"): str,
|
||||
Optional("prompt"): str,
|
||||
Optional("system_prompt"): str,
|
||||
Optional("deployment_name"): str,
|
||||
Optional("where"): dict,
|
||||
|
||||
Reference in New Issue
Block a user