Migrate from template to prompt arg while keeping backward compatibility (#1066)

Sidharth Mohanty
2023-12-28 23:36:33 +05:30
committed by GitHub
parent 12e6eaf802
commit d9d529987e
9 changed files with 56 additions and 42 deletions
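At the config surface the migration is a one-word rename: the query prompt now lives under `prompt`, and `template` is kept only as a deprecated alias. A minimal sketch of the new spelling as a Python dict, built from the keys that appear in the docs hunks below (the dict itself is illustrative, not lifted verbatim from the docs):

```python
llm_config = {
    "max_tokens": 1000,
    "top_p": 1,
    "stream": False,
    # Before this commit the next key was spelled "template".
    "prompt": (
        "Use the following pieces of context to answer the query at the end.\n"
        "If you don't know the answer, just say that you don't know, don't try to make up an answer.\n"
        "$context\n\nQuery: $query\n\nHelpful Answer:"
    ),
    "system_prompt": "Act as William Shakespeare. Answer the following questions in the style of William Shakespeare.",
}
```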

View File

@@ -15,7 +15,7 @@ llm:
max_tokens: 1000
top_p: 1
stream: false
template: |
prompt: |
Use the following pieces of context to answer the query at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

View File

@@ -26,7 +26,7 @@ llm:
top_p: 1
stream: false
api_key: sk-xxx
template: |
prompt: |
Use the following pieces of context to answer the query at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
@@ -73,7 +73,7 @@ chunker:
"max_tokens": 1000,
"top_p": 1,
"stream": false,
"template": "Use the following pieces of context to answer the query at the end.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n$context\n\nQuery: $query\n\nHelpful Answer:",
"prompt": "Use the following pieces of context to answer the query at the end.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n$context\n\nQuery: $query\n\nHelpful Answer:",
"system_prompt": "Act as William Shakespeare. Answer the following questions in the style of William Shakespeare.",
"api_key": "sk-xxx"
}
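A quick, library-independent way to sanity-check a JSON config after the rename is to parse it and confirm the `prompt` value still carries the placeholders the prompt validation expects (a sketch, not part of this commit):

```python
import json
from string import Template

raw = '{"prompt": "Use the following pieces of context to answer the query at the end.\\n$context\\n\\nQuery: $query\\n\\nHelpful Answer:"}'
llm_config = json.loads(raw)

prompt = Template(llm_config["prompt"])
# The schema only checks the type; BaseLlmConfig requires both $context and $query.
print("$context" in prompt.template and "$query" in prompt.template)  # True
```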
@@ -117,7 +117,7 @@ config = {
'max_tokens': 1000,
'top_p': 1,
'stream': False,
'template': (
'prompt': (
"Use the following pieces of context to answer the query at the end.\n"
"If you don't know the answer, just say that you don't know, don't try to make up an answer.\n"
"$context\n\nQuery: $query\n\nHelpful Answer:"
@@ -170,7 +170,7 @@ Alright, let's dive into what each key means in the yaml config above:
- `max_tokens` (Integer): Controls how many tokens are used in the response.
- `top_p` (Float): Controls the diversity of word selection. A higher value (closer to 1) makes word selection more diverse.
- `stream` (Boolean): Controls if the response is streamed back to the user (set to false).
- `template` (String): A custom template for the prompt that the model uses to generate responses.
- `prompt` (String): A prompt for the model to follow when generating responses, requires $context and $query variables.
- `system_prompt` (String): A system prompt for the model to follow when generating responses, in this case, it's set to the style of William Shakespeare.
- `stream` (Boolean): Controls if the response is streamed back to the user (set to false).
- `number_documents` (Integer): Number of documents to pull from the vectordb as context, defaults to 1
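The `$context` and `$query` requirement called out for `prompt` above is plain `string.Template` syntax, so a custom prompt can be exercised with nothing but the standard library (a sketch; the prompt text is illustrative):

```python
from string import Template

prompt = Template(
    "Answer using only this context:\n$context\n\n"
    "Question: $query\n"
    "Conversation so far: $history\n"  # $history is optional; include it only if past turns should be injected
    "Answer:"
)

print(prompt.substitute(
    context="Paris is the capital of France.",
    query="What is the capital of France?",
    history="- no history -",  # the same fallback string the library substitutes when history is empty
))
```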

View File

@@ -37,7 +37,7 @@ llm:
max_tokens: 1000
top_p: 1
stream: false
template: |
prompt: |
Use the following pieces of context to answer the query at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

View File

@@ -1,3 +1,4 @@
import logging
import re
from string import Template
from typing import Any, Dict, List, Optional
@@ -59,6 +60,7 @@ class BaseLlmConfig(BaseConfig):
self,
number_documents: int = 3,
template: Optional[Template] = None,
prompt: Optional[Template] = None,
model: Optional[str] = None,
temperature: float = 0,
max_tokens: int = 1000,
@@ -80,8 +82,11 @@ class BaseLlmConfig(BaseConfig):
context, defaults to 1
:type number_documents: int, optional
:param template: The `Template` instance to use as a template for
prompt, defaults to None
prompt, defaults to None (deprecated)
:type template: Optional[Template], optional
:param prompt: The `Template` instance to use as a template for
prompt, defaults to None
:type prompt: Optional[Template], optional
:param model: Controls the OpenAI model used, defaults to None
:type model: Optional[str], optional
:param temperature: Controls the randomness of the model's output.
@@ -106,8 +111,16 @@ class BaseLlmConfig(BaseConfig):
contain $context and $query (and optionally $history)
:raises ValueError: Stream is not boolean
"""
if template is None:
    template = DEFAULT_PROMPT_TEMPLATE
if template is not None:
    logging.warning(
        "The `template` argument is deprecated and will be removed in a future version. "
        "Please use `prompt` instead."
    )
    if prompt is None:
        prompt = template
if prompt is None:
    prompt = DEFAULT_PROMPT_TEMPLATE
self.number_documents = number_documents
self.temperature = temperature
@@ -120,37 +133,37 @@ class BaseLlmConfig(BaseConfig):
self.callbacks = callbacks
self.api_key = api_key
if type(template) is str:
template = Template(template)
if type(prompt) is str:
prompt = Template(prompt)
if self.validate_template(template):
self.template = template
if self.validate_prompt(prompt):
self.prompt = prompt
else:
raise ValueError("`template` should have `query` and `context` keys and potentially `history` (if used).")
raise ValueError("The 'prompt' should have 'query' and 'context' keys and potentially 'history' (if used).")
if not isinstance(stream, bool):
raise ValueError("`stream` should be bool")
self.stream = stream
self.where = where
def validate_template(self, template: Template) -> bool:
def validate_prompt(self, prompt: Template) -> bool:
"""
validate the template
validate the prompt
:param template: the template to validate
:type template: Template
:param prompt: the prompt to validate
:type prompt: Template
:return: valid (true) or invalid (false)
:rtype: bool
"""
return re.search(query_re, template.template) and re.search(context_re, template.template)
return re.search(query_re, prompt.template) and re.search(context_re, prompt.template)
def _validate_template_history(self, template: Template) -> bool:
def _validate_prompt_history(self, prompt: Template) -> bool:
"""
validate the template with history
validate the prompt with history
:param template: the template to validate
:type template: Template
:param prompt: the prompt to validate
:type prompt: Template
:return: valid (true) or invalid (false)
:rtype: bool
"""
return re.search(history_re, template.template)
return re.search(history_re, prompt.template)
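Taken together, the new `__init__` keeps old call sites working: `template=` still builds a usable config, it just logs the deprecation warning and lands in `config.prompt`. A small sketch of the observable behavior (the import path `embedchain.config.BaseLlmConfig` is assumed):

```python
import logging
from embedchain.config import BaseLlmConfig  # import path assumed

logging.basicConfig(level=logging.WARNING)

# New argument: no warning.
new_cfg = BaseLlmConfig(prompt="Context: $context\nQuery: $query\nAnswer:")

# Deprecated argument: still accepted, warns, and is copied into `prompt`.
old_cfg = BaseLlmConfig(template="Context: $context\nQuery: $query\nAnswer:")

assert new_cfg.prompt.template == old_cfg.prompt.template
```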

View File

@@ -74,19 +74,19 @@ class BaseLlm(JSONSerializable):
if web_search_result:
context_string = self._append_search_and_context(context_string, web_search_result)
template_contains_history = self.config._validate_template_history(self.config.template)
if template_contains_history:
# Template contains history
prompt_contains_history = self.config._validate_prompt_history(self.config.prompt)
if prompt_contains_history:
# Prompt contains history
# If there is no history yet, we insert `- no history -`
prompt = self.config.template.substitute(
prompt = self.config.prompt.substitute(
context=context_string, query=input_query, history=self.history or "- no history -"
)
elif self.history and not template_contains_history:
# History is present, but not included in the template.
# check if it's the default template without history
elif self.history and not prompt_contains_history:
# History is present, but not included in the prompt.
# check if it's the default prompt without history
if (
not self.config._validate_template_history(self.config.template)
and self.config.template.template == DEFAULT_PROMPT
not self.config._validate_prompt_history(self.config.prompt)
and self.config.prompt.template == DEFAULT_PROMPT
):
# swap in the template with history
prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
@@ -95,12 +95,12 @@ class BaseLlm(JSONSerializable):
else:
# If we can't swap in the default, we still proceed but tell users that the history is ignored.
logging.warning(
"Your bot contains a history, but template does not include `$history` key. History is ignored."
"Your bot contains a history, but prompt does not include `$history` key. History is ignored."
)
prompt = self.config.template.substitute(context=context_string, query=input_query)
prompt = self.config.prompt.substitute(context=context_string, query=input_query)
else:
# basic use case, no history.
prompt = self.config.template.substitute(context=context_string, query=input_query)
prompt = self.config.prompt.substitute(context=context_string, query=input_query)
return prompt
def _append_search_and_context(self, context: str, web_search_result: str) -> str:
@@ -191,7 +191,7 @@ class BaseLlm(JSONSerializable):
return contexts
if self.is_docs_site_instance:
self.config.template = DOCS_SITE_PROMPT_TEMPLATE
self.config.prompt = DOCS_SITE_PROMPT_TEMPLATE
self.config.number_documents = 5
k = {}
if self.online:
@@ -242,7 +242,7 @@ class BaseLlm(JSONSerializable):
self.config = config
if self.is_docs_site_instance:
self.config.template = DOCS_SITE_PROMPT_TEMPLATE
self.config.prompt = DOCS_SITE_PROMPT_TEMPLATE
self.config.number_documents = 5
k = {}
if self.online:
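The renames in this file do not change the history behavior, they only re-key it on `config.prompt`; the three branches of `generate_prompt` still resolve as follows (a standard-library sketch of the outcomes, not the library's code):

```python
from string import Template

history = "Q: earlier question\nA: earlier answer"

# 1. Prompt declares $history: history (or the "- no history -" fallback) is substituted in.
custom_with_history = Template("Context: $context | Query: $query | History: $history")
print(custom_with_history.substitute(context="ctx", query="q", history=history or "- no history -"))

# 2. Prompt is the stock default (no $history) and history exists:
#    DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE is swapped in instead.

# 3. Prompt is a custom one without $history and history exists:
#    a warning is logged and the history is ignored.
custom_without_history = Template("Context: $context | Query: $query")
print(custom_without_history.substitute(context="ctx", query="q"))
```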

View File

@@ -396,6 +396,7 @@ def validate_config(config_data):
Optional("top_p"): Or(float, int),
Optional("stream"): bool,
Optional("template"): str,
Optional("prompt"): str,
Optional("system_prompt"): str,
Optional("deployment_name"): str,
Optional("where"): dict,

View File

@@ -76,4 +76,4 @@ class TestJsonSerializable(unittest.TestCase):
config = BaseLlmConfig(template=Template("My custom template with $query, $context and $history."))
s = config.serialize()
new_config: BaseLlmConfig = BaseLlmConfig.deserialize(s)
self.assertEqual(config.template.template, new_config.template.template)
self.assertEqual(config.prompt.template, new_config.prompt.template)

View File

@@ -25,7 +25,7 @@ def test_is_stream_bool():
def test_template_string_gets_converted_to_Template_instance():
config = BaseLlmConfig(template="test value $query $context")
llm = BaseLlm(config=config)
assert isinstance(llm.config.template, Template)
assert isinstance(llm.config.prompt, Template)
def test_is_get_llm_model_answer_implemented():

View File

@@ -53,7 +53,7 @@ class TestGeneratePrompt(unittest.TestCase):
result = self.app.llm.generate_prompt(input_query, contexts)
# Assert
expected_result = config.template.substitute(context="Context 1 | Context 2 | Context 3", query=input_query)
expected_result = config.prompt.substitute(context="Context 1 | Context 2 | Context 3", query=input_query)
self.assertEqual(result, expected_result)
def test_generate_prompt_with_history(self):
@@ -61,7 +61,7 @@ class TestGeneratePrompt(unittest.TestCase):
Test the 'generate_prompt' method with BaseLlmConfig containing a history attribute.
"""
config = BaseLlmConfig()
config.template = Template("Context: $context | Query: $query | History: $history")
config.prompt = Template("Context: $context | Query: $query | History: $history")
self.app.llm.config = config
self.app.llm.set_history(["Past context 1", "Past context 2"])
prompt = self.app.llm.generate_prompt("Test query", ["Test context"])