diff --git a/configs/full-stack.yaml b/configs/full-stack.yaml
index 1da28209..dc337957 100644
--- a/configs/full-stack.yaml
+++ b/configs/full-stack.yaml
@@ -15,7 +15,7 @@ llm:
     max_tokens: 1000
     top_p: 1
     stream: false
-    template: |
+    prompt: |
       Use the following pieces of context to answer the query at the end.
       If you don't know the answer, just say that you don't know, don't try to make up an answer.
diff --git a/docs/api-reference/advanced/configuration.mdx b/docs/api-reference/advanced/configuration.mdx
index c1faf4d7..11eab691 100644
--- a/docs/api-reference/advanced/configuration.mdx
+++ b/docs/api-reference/advanced/configuration.mdx
@@ -26,7 +26,7 @@ llm:
     top_p: 1
     stream: false
     api_key: sk-xxx
-    template: |
+    prompt: |
       Use the following pieces of context to answer the query at the end.
       If you don't know the answer, just say that you don't know, don't try to make up an answer.
@@ -73,7 +73,7 @@ chunker:
         "max_tokens": 1000,
         "top_p": 1,
         "stream": false,
-        "template": "Use the following pieces of context to answer the query at the end.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n$context\n\nQuery: $query\n\nHelpful Answer:",
+        "prompt": "Use the following pieces of context to answer the query at the end.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n$context\n\nQuery: $query\n\nHelpful Answer:",
         "system_prompt": "Act as William Shakespeare. Answer the following questions in the style of William Shakespeare.",
         "api_key": "sk-xxx"
     }
@@ -117,7 +117,7 @@ config = {
             'max_tokens': 1000,
             'top_p': 1,
             'stream': False,
-            'template': (
+            'prompt': (
                 "Use the following pieces of context to answer the query at the end.\n"
                 "If you don't know the answer, just say that you don't know, don't try to make up an answer.\n"
                 "$context\n\nQuery: $query\n\nHelpful Answer:"
             )
@@ -170,7 +170,7 @@ Alright, let's dive into what each key means in the yaml config above:
   - `max_tokens` (Integer): Controls how many tokens are used in the response.
   - `top_p` (Float): Controls the diversity of word selection. A higher value (closer to 1) makes word selection more diverse.
   - `stream` (Boolean): Controls if the response is streamed back to the user (set to false).
-  - `template` (String): A custom template for the prompt that the model uses to generate responses.
+  - `prompt` (String): A prompt for the model to follow when generating responses, requires $context and $query variables.
   - `system_prompt` (String): A system prompt for the model to follow when generating responses, in this case, it's set to the style of William Shakespeare.
   - `stream` (Boolean): Controls if the response is streamed back to the user (set to false).
   - `number_documents` (Integer): Number of documents to pull from the vectordb as context, defaults to 1
diff --git a/docs/examples/rest-api/create.mdx b/docs/examples/rest-api/create.mdx
index 82cf5d73..6736bf7a 100644
--- a/docs/examples/rest-api/create.mdx
+++ b/docs/examples/rest-api/create.mdx
@@ -37,7 +37,7 @@ llm:
     max_tokens: 1000
     top_p: 1
     stream: false
-    template: |
+    prompt: |
       Use the following pieces of context to answer the query at the end.
       If you don't know the answer, just say that you don't know, don't try to make up an answer.
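Note on the renamed key above: the docs now state that the `prompt` string must contain the `$context` and `$query` placeholders. As the config code further down in this patch shows, the string is wrapped in Python's `string.Template` and both placeholders are substituted at query time. A minimal stdlib-only sketch of that substitution (the prompt text here is an invented example, not one taken from the repo):

    from string import Template

    # Hypothetical custom prompt; it must contain $context and $query,
    # otherwise BaseLlmConfig rejects it with a ValueError (see base.py below).
    custom_prompt = Template(
        "Use the following context to answer the query.\n"
        "$context\n\nQuery: $query\n\nHelpful Answer:"
    )

    # Mirrors how the prompt is filled in before being sent to the LLM.
    print(custom_prompt.substitute(context="<retrieved documents>", query="What is Embedchain?"))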
diff --git a/embedchain/config/llm/base.py b/embedchain/config/llm/base.py
index bba4a20f..635bfd45 100644
--- a/embedchain/config/llm/base.py
+++ b/embedchain/config/llm/base.py
@@ -1,3 +1,4 @@
+import logging
 import re
 from string import Template
 from typing import Any, Dict, List, Optional
@@ -59,6 +60,7 @@ class BaseLlmConfig(BaseConfig):
         self,
         number_documents: int = 3,
         template: Optional[Template] = None,
+        prompt: Optional[Template] = None,
         model: Optional[str] = None,
         temperature: float = 0,
         max_tokens: int = 1000,
@@ -80,8 +82,11 @@ class BaseLlmConfig(BaseConfig):
             context, defaults to 1
         :type number_documents: int, optional
         :param template: The `Template` instance to use as a template for
-            prompt, defaults to None
+            prompt, defaults to None (deprecated)
         :type template: Optional[Template], optional
+        :param prompt: The `Template` instance to use as a template for
+            prompt, defaults to None
+        :type prompt: Optional[Template], optional
         :param model: Controls the OpenAI model used, defaults to None
         :type model: Optional[str], optional
         :param temperature: Controls the randomness of the model's output.
@@ -106,8 +111,16 @@
             contain $context and $query (and optionally $history)
         :raises ValueError: Stream is not boolean
         """
-        if template is None:
-            template = DEFAULT_PROMPT_TEMPLATE
+        if template is not None:
+            logging.warning(
+                "The `template` argument is deprecated and will be removed in a future version. "
+                + "Please use `prompt` instead."
+            )
+            if prompt is None:
+                prompt = template
+
+        if prompt is None:
+            prompt = DEFAULT_PROMPT_TEMPLATE
 
         self.number_documents = number_documents
         self.temperature = temperature
@@ -120,37 +133,37 @@
         self.callbacks = callbacks
         self.api_key = api_key
 
-        if type(template) is str:
-            template = Template(template)
+        if type(prompt) is str:
+            prompt = Template(prompt)
 
-        if self.validate_template(template):
-            self.template = template
+        if self.validate_prompt(prompt):
+            self.prompt = prompt
         else:
-            raise ValueError("`template` should have `query` and `context` keys and potentially `history` (if used).")
+            raise ValueError("The 'prompt' should have 'query' and 'context' keys and potentially 'history' (if used).")
 
         if not isinstance(stream, bool):
             raise ValueError("`stream` should be bool")
         self.stream = stream
         self.where = where
 
-    def validate_template(self, template: Template) -> bool:
+    def validate_prompt(self, prompt: Template) -> bool:
         """
-        validate the template
+        validate the prompt
 
-        :param template: the template to validate
-        :type template: Template
+        :param prompt: the prompt to validate
+        :type prompt: Template
         :return: valid (true) or invalid (false)
         :rtype: bool
         """
-        return re.search(query_re, template.template) and re.search(context_re, template.template)
+        return re.search(query_re, prompt.template) and re.search(context_re, prompt.template)
 
-    def _validate_template_history(self, template: Template) -> bool:
+    def _validate_prompt_history(self, prompt: Template) -> bool:
         """
-        validate the template with history
+        validate the prompt with history
 
-        :param template: the template to validate
-        :type template: Template
+        :param prompt: the prompt to validate
+        :type prompt: Template
         :return: valid (true) or invalid (false)
         :rtype: bool
         """
-        return re.search(history_re, template.template)
+        return re.search(history_re, prompt.template)
diff --git a/embedchain/llm/base.py b/embedchain/llm/base.py
index d01c7b64..ff2d5d17 100644
--- a/embedchain/llm/base.py
+++ b/embedchain/llm/base.py
@@ -74,19 +74,19 @@ class BaseLlm(JSONSerializable):
         if web_search_result:
             context_string = self._append_search_and_context(context_string, web_search_result)
 
-        template_contains_history = self.config._validate_template_history(self.config.template)
-        if template_contains_history:
-            # Template contains history
+        prompt_contains_history = self.config._validate_prompt_history(self.config.prompt)
+        if prompt_contains_history:
+            # Prompt contains history
             # If there is no history yet, we insert `- no history -`
-            prompt = self.config.template.substitute(
+            prompt = self.config.prompt.substitute(
                 context=context_string, query=input_query, history=self.history or "- no history -"
             )
-        elif self.history and not template_contains_history:
-            # History is present, but not included in the template.
-            # check if it's the default template without history
+        elif self.history and not prompt_contains_history:
+            # History is present, but not included in the prompt.
+            # check if it's the default prompt without history
             if (
-                not self.config._validate_template_history(self.config.template)
-                and self.config.template.template == DEFAULT_PROMPT
+                not self.config._validate_prompt_history(self.config.prompt)
+                and self.config.prompt.template == DEFAULT_PROMPT
             ):
                 # swap in the template with history
                 prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
@@ -95,12 +95,12 @@
             else:
                 # If we can't swap in the default, we still proceed but tell users that the history is ignored.
                 logging.warning(
-                    "Your bot contains a history, but template does not include `$history` key. History is ignored."
+                    "Your bot contains a history, but prompt does not include `$history` key. History is ignored."
                 )
-                prompt = self.config.template.substitute(context=context_string, query=input_query)
+                prompt = self.config.prompt.substitute(context=context_string, query=input_query)
         else:
             # basic use case, no history.
-            prompt = self.config.template.substitute(context=context_string, query=input_query)
+            prompt = self.config.prompt.substitute(context=context_string, query=input_query)
         return prompt
 
     def _append_search_and_context(self, context: str, web_search_result: str) -> str:
@@ -191,7 +191,7 @@ class BaseLlm(JSONSerializable):
             return contexts
 
         if self.is_docs_site_instance:
-            self.config.template = DOCS_SITE_PROMPT_TEMPLATE
+            self.config.prompt = DOCS_SITE_PROMPT_TEMPLATE
             self.config.number_documents = 5
         k = {}
         if self.online:
@@ -242,7 +242,7 @@
             self.config = config
 
         if self.is_docs_site_instance:
-            self.config.template = DOCS_SITE_PROMPT_TEMPLATE
+            self.config.prompt = DOCS_SITE_PROMPT_TEMPLATE
             self.config.number_documents = 5
         k = {}
         if self.online:
diff --git a/embedchain/utils.py b/embedchain/utils.py
index b41030f3..0ebb2349 100644
--- a/embedchain/utils.py
+++ b/embedchain/utils.py
@@ -396,6 +396,7 @@ def validate_config(config_data):
                     Optional("top_p"): Or(float, int),
                     Optional("stream"): bool,
                     Optional("template"): str,
+                    Optional("prompt"): str,
                     Optional("system_prompt"): str,
                     Optional("deployment_name"): str,
                     Optional("where"): dict,
diff --git a/tests/helper_classes/test_json_serializable.py b/tests/helper_classes/test_json_serializable.py
index 3cbe2763..adcfe5f7 100644
--- a/tests/helper_classes/test_json_serializable.py
+++ b/tests/helper_classes/test_json_serializable.py
@@ -76,4 +76,4 @@ class TestJsonSerializable(unittest.TestCase):
         config = BaseLlmConfig(template=Template("My custom template with $query, $context and $history."))
         s = config.serialize()
         new_config: BaseLlmConfig = BaseLlmConfig.deserialize(s)
-        self.assertEqual(config.template.template, new_config.template.template)
+        self.assertEqual(config.prompt.template, new_config.prompt.template)
diff --git a/tests/llm/test_base_llm.py b/tests/llm/test_base_llm.py
index e2e56bdc..61e0ee01 100644
--- a/tests/llm/test_base_llm.py
+++ b/tests/llm/test_base_llm.py
@@ -25,7 +25,7 @@ def test_is_stream_bool():
 def test_template_string_gets_converted_to_Template_instance():
     config = BaseLlmConfig(template="test value $query $context")
     llm = BaseLlm(config=config)
-    assert isinstance(llm.config.template, Template)
+    assert isinstance(llm.config.prompt, Template)
 
 
 def test_is_get_llm_model_answer_implemented():
diff --git a/tests/llm/test_generate_prompt.py b/tests/llm/test_generate_prompt.py
index 0058032d..31f5fddd 100644
--- a/tests/llm/test_generate_prompt.py
+++ b/tests/llm/test_generate_prompt.py
@@ -53,7 +53,7 @@ class TestGeneratePrompt(unittest.TestCase):
         result = self.app.llm.generate_prompt(input_query, contexts)
 
         # Assert
-        expected_result = config.template.substitute(context="Context 1 | Context 2 | Context 3", query=input_query)
+        expected_result = config.prompt.substitute(context="Context 1 | Context 2 | Context 3", query=input_query)
         self.assertEqual(result, expected_result)
 
     def test_generate_prompt_with_history(self):
@@ -61,7 +61,7 @@ class TestGeneratePrompt(unittest.TestCase):
         """
         Test the 'generate_prompt' method with BaseLlmConfig containing a history attribute.
         """
         config = BaseLlmConfig()
-        config.template = Template("Context: $context | Query: $query | History: $history")
+        config.prompt = Template("Context: $context | Query: $query | History: $history")
         self.app.llm.config = config
         self.app.llm.set_history(["Past context 1", "Past context 2"])
         prompt = self.app.llm.generate_prompt("Test query", ["Test context"])
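Taken together, the patch renames the user-facing `template` option to `prompt` while keeping the old argument working behind a deprecation warning. A minimal usage sketch of the new surface (assuming `BaseLlmConfig` is importable from `embedchain.config`; that import path is not shown in this diff):

    from string import Template

    from embedchain.config import BaseLlmConfig

    # New-style: pass `prompt` directly; a plain string is converted to string.Template.
    config = BaseLlmConfig(prompt="Answer from this context:\n$context\n\nQuery: $query\n\nHelpful Answer:")
    assert isinstance(config.prompt, Template)

    # Old-style: `template=` still works, but per this patch it logs a deprecation
    # warning and its value is copied into `config.prompt`.
    legacy = BaseLlmConfig(template=Template("$context\n\nQuery: $query\n\nHelpful Answer:"))
    assert legacy.prompt.template == "$context\n\nQuery: $query\n\nHelpful Answer:"

Either way, validate_prompt() requires both `$context` and `$query` to be present, so a prompt missing either placeholder raises a ValueError at construction time.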