Migrate from template to prompt arg while keeping backward compatibility (#1066)

This commit is contained in:
Sidharth Mohanty
2023-12-28 23:36:33 +05:30
committed by GitHub
parent 12e6eaf802
commit d9d529987e
9 changed files with 56 additions and 42 deletions

View File

@@ -1,3 +1,4 @@
import logging
import re
from string import Template
from typing import Any, Dict, List, Optional
@@ -59,6 +60,7 @@ class BaseLlmConfig(BaseConfig):
self,
number_documents: int = 3,
template: Optional[Template] = None,
prompt: Optional[Template] = None,
model: Optional[str] = None,
temperature: float = 0,
max_tokens: int = 1000,
@@ -80,8 +82,11 @@ class BaseLlmConfig(BaseConfig):
context, defaults to 1
:type number_documents: int, optional
:param template: The `Template` instance to use as a template for
prompt, defaults to None
prompt, defaults to None (deprecated)
:type template: Optional[Template], optional
:param prompt: The `Template` instance to use as a template for
prompt, defaults to None
:type prompt: Optional[Template], optional
:param model: Controls the OpenAI model used, defaults to None
:type model: Optional[str], optional
:param temperature: Controls the randomness of the model's output.
@@ -106,8 +111,16 @@ class BaseLlmConfig(BaseConfig):
contain $context and $query (and optionally $history)
:raises ValueError: Stream is not boolean
"""
if template is None:
template = DEFAULT_PROMPT_TEMPLATE
if template is not None:
logging.warning(
"The `template` argument is deprecated and will be removed in a future version. "
+ "Please use `prompt` instead."
)
if prompt is None:
prompt = template
if prompt is None:
prompt = DEFAULT_PROMPT_TEMPLATE
self.number_documents = number_documents
self.temperature = temperature
@@ -120,37 +133,37 @@ class BaseLlmConfig(BaseConfig):
self.callbacks = callbacks
self.api_key = api_key
if type(template) is str:
template = Template(template)
if type(prompt) is str:
prompt = Template(prompt)
if self.validate_template(template):
self.template = template
if self.validate_prompt(prompt):
self.prompt = prompt
else:
raise ValueError("`template` should have `query` and `context` keys and potentially `history` (if used).")
raise ValueError("The 'prompt' should have 'query' and 'context' keys and potentially 'history' (if used).")
if not isinstance(stream, bool):
raise ValueError("`stream` should be bool")
self.stream = stream
self.where = where
def validate_template(self, template: Template) -> bool:
def validate_prompt(self, prompt: Template) -> bool:
"""
validate the template
validate the prompt
:param template: the template to validate
:type template: Template
:param prompt: the prompt to validate
:type prompt: Template
:return: valid (true) or invalid (false)
:rtype: bool
"""
return re.search(query_re, template.template) and re.search(context_re, template.template)
return re.search(query_re, prompt.template) and re.search(context_re, prompt.template)
def _validate_template_history(self, template: Template) -> bool:
def _validate_prompt_history(self, prompt: Template) -> bool:
"""
validate the template with history
validate the prompt with history
:param template: the template to validate
:type template: Template
:param prompt: the prompt to validate
:type prompt: Template
:return: valid (true) or invalid (false)
:rtype: bool
"""
return re.search(history_re, template.template)
return re.search(history_re, prompt.template)