From 849de5e8abc3a3cabe638b4bbef15d8e8604e8e4 Mon Sep 17 00:00:00 2001
From: cachho
Date: Wed, 16 Aug 2023 21:57:01 +0200
Subject: [PATCH] feat: system prompt (#448)

---
 docs/advanced/configuration.mdx       |  2 +-
 docs/advanced/query_configuration.mdx |  2 ++
 embedchain/apps/App.py                |  2 ++
 embedchain/apps/CustomApp.py          | 18 +++++++++++-------
 embedchain/apps/Llama2App.py          |  6 ++++--
 embedchain/apps/OpenSourceApp.py      |  3 +++
 embedchain/config/ChatConfig.py       |  5 +++++
 embedchain/config/QueryConfig.py      |  5 +++++
 tests/embedchain/test_query.py        | 17 +++++++++++++++++
 9 files changed, 50 insertions(+), 10 deletions(-)

diff --git a/docs/advanced/configuration.mdx b/docs/advanced/configuration.mdx
index 8cafe6a0..baf5a9f7 100644
--- a/docs/advanced/configuration.mdx
+++ b/docs/advanced/configuration.mdx
@@ -68,7 +68,7 @@ einstein_chat_template = Template("""
         Human: $query
         Albert Einstein:""")
 
-query_config = QueryConfig(template=einstein_chat_template)
+query_config = QueryConfig(template=einstein_chat_template, system_prompt="You are Albert Einstein.")
 queries = [
         "Where did you complete your studies?",
         "Why did you win nobel prize?",
diff --git a/docs/advanced/query_configuration.mdx b/docs/advanced/query_configuration.mdx
index 6ab51d3d..23f47815 100644
--- a/docs/advanced/query_configuration.mdx
+++ b/docs/advanced/query_configuration.mdx
@@ -65,6 +65,8 @@ _coming soon_
 |top_p|Controls the diversity of words. Higher values (closer to 1) make word selection more diverse, lower values make words less diverse.|float|1|
 |history|include conversation history from your client or database.|any (recommendation: list[str])|None|
 |stream|control if response is streamed back to the user.|bool|False|
+|deployment_name|The name of the model deployment (Azure OpenAI).|str|None|
+|system_prompt|System prompt string. Unused if None.|str|None|
 
 ## ChatConfig
 
diff --git a/embedchain/apps/App.py b/embedchain/apps/App.py
index 057ad92c..3e6537e3 100644
--- a/embedchain/apps/App.py
+++ b/embedchain/apps/App.py
@@ -25,6 +25,8 @@ class App(EmbedChain):
 
     def get_llm_model_answer(self, prompt, config: ChatConfig):
         messages = []
+        if config.system_prompt:
+            messages.append({"role": "system", "content": config.system_prompt})
         messages.append({"role": "user", "content": prompt})
         response = openai.ChatCompletion.create(
             model=config.model or "gpt-3.5-turbo-0613",
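Note on the App.py change above: the optional system message is simply prepended to the OpenAI chat payload. A minimal standalone sketch of the assembled call, with illustrative values, assuming the pre-1.0 `openai` client that `openai.ChatCompletion` implies and an OPENAI_API_KEY in the environment:

    import openai  # pre-1.0 client, matching the openai.ChatCompletion usage above

    system_prompt = "You are Albert Einstein."  # illustrative
    prompt = "Where did you complete your studies?"  # illustrative

    messages = []
    if system_prompt:  # a system message is only prepended when one is configured
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": prompt})

    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo-0613",
        messages=messages,
    )
    answer = response["choices"][0]["message"]["content"]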
diff --git a/embedchain/apps/CustomApp.py b/embedchain/apps/CustomApp.py
index 6b3cf30e..a4cd78af 100644
--- a/embedchain/apps/CustomApp.py
+++ b/embedchain/apps/CustomApp.py
@@ -1,5 +1,5 @@
 import logging
-from typing import List
+from typing import List, Optional
 
 from langchain.schema import BaseMessage
 
@@ -84,7 +84,7 @@ class CustomApp(EmbedChain):
         if config.top_p and config.top_p != 1:
             logging.warning("Config option `top_p` is not supported by this model.")
 
-        messages = CustomApp._get_messages(prompt)
+        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)
 
         return chat(messages).content
 
@@ -97,7 +97,7 @@ class CustomApp(EmbedChain):
         if config.max_tokens and config.max_tokens != 1000:
             logging.warning("Config option `max_tokens` is not supported by this model.")
 
-        messages = CustomApp._get_messages(prompt)
+        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)
 
         return chat(messages).content
 
@@ -110,7 +110,7 @@ class CustomApp(EmbedChain):
         if config.top_p and config.top_p != 1:
             logging.warning("Config option `top_p` is not supported by this model.")
 
-        messages = CustomApp._get_messages(prompt)
+        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)
 
         return chat(messages).content
 
@@ -133,15 +133,19 @@ class CustomApp(EmbedChain):
         if config.top_p and config.top_p != 1:
             logging.warning("Config option `top_p` is not supported by this model.")
 
-        messages = CustomApp._get_messages(prompt)
+        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)
 
         return chat(messages).content
 
     @staticmethod
-    def _get_messages(prompt: str) -> List[BaseMessage]:
+    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
         from langchain.schema import HumanMessage, SystemMessage
 
-        return [SystemMessage(content="You are a helpful assistant."), HumanMessage(content=prompt)]
+        messages = []
+        if system_prompt:
+            messages.append(SystemMessage(content=system_prompt))
+        messages.append(HumanMessage(content=prompt))
+        return messages
 
     def _stream_llm_model_response(self, response):
         """
diff --git a/embedchain/apps/Llama2App.py b/embedchain/apps/Llama2App.py
index b1a13341..b9615cf8 100644
--- a/embedchain/apps/Llama2App.py
+++ b/embedchain/apps/Llama2App.py
@@ -2,7 +2,7 @@ import os
 
 from langchain.llms import Replicate
 
-from embedchain.config import AppConfig
+from embedchain.config import AppConfig, ChatConfig
 from embedchain.embedchain import EmbedChain
 
 
@@ -27,8 +27,10 @@ class Llama2App(EmbedChain):
 
         super().__init__(config)
 
-    def get_llm_model_answer(self, prompt, config: AppConfig = None):
+    def get_llm_model_answer(self, prompt, config: ChatConfig = None):
         # TODO: Move the model and other inputs into config
+        if config.system_prompt:
+            raise ValueError("Llama2App does not support `system_prompt`")
         llm = Replicate(
             model="a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5",
             input={"temperature": 0.75, "max_length": 500, "top_p": 1},
diff --git a/embedchain/apps/OpenSourceApp.py b/embedchain/apps/OpenSourceApp.py
index 803cf133..eaf04cab 100644
--- a/embedchain/apps/OpenSourceApp.py
+++ b/embedchain/apps/OpenSourceApp.py
@@ -55,6 +55,9 @@ class OpenSourceApp(EmbedChain):
                 "OpenSourceApp does not support switching models at runtime. Please create a new app instance."
            )
 
+        if config.system_prompt:
+            raise ValueError("OpenSourceApp does not support `system_prompt`")
+
         response = self.instance.generate(
             prompt=prompt,
             streaming=config.stream,
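The LangChain-backed CustomApp expresses the same idea with message objects instead of role dicts. A minimal sketch of the new `_get_messages` helper in isolation, assuming the `langchain` version whose `langchain.schema` module is imported in the patch:

    from typing import List, Optional

    from langchain.schema import BaseMessage, HumanMessage, SystemMessage

    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
        # The previous hardcoded "You are a helpful assistant." default is gone;
        # a SystemMessage is only included when a system_prompt is set.
        messages = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(HumanMessage(content=prompt))
        return messages

    print(_get_messages("Hello", system_prompt="You are Albert Einstein."))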
diff --git a/embedchain/config/ChatConfig.py b/embedchain/config/ChatConfig.py
index 0149fd3a..fd195a90 100644
--- a/embedchain/config/ChatConfig.py
+++ b/embedchain/config/ChatConfig.py
@@ -1,4 +1,5 @@
 from string import Template
+from typing import Optional
 
 from embedchain.config.QueryConfig import QueryConfig
 
@@ -34,6 +35,7 @@ class ChatConfig(QueryConfig):
         top_p=None,
         stream: bool = False,
         deployment_name=None,
+        system_prompt: Optional[str] = None,
     ):
         """
         Initializes the ChatConfig instance.
@@ -51,6 +53,8 @@ class ChatConfig(QueryConfig):
             (closer to 1) make word selection more diverse, lower values make words
             less diverse.
         :param stream: Optional. Control if response is streamed back to the user
+        :param deployment_name: Optional. The name of the model deployment (Azure OpenAI).
+        :param system_prompt: Optional. System prompt string. Unused if None.
         :raises ValueError: If the template is not valid as template should
             contain $context and $query and $history
         """
@@ -70,6 +74,7 @@ class ChatConfig(QueryConfig):
             history=[0],
             stream=stream,
             deployment_name=deployment_name,
+            system_prompt=system_prompt,
         )
 
     def set_history(self, history):
diff --git a/embedchain/config/QueryConfig.py b/embedchain/config/QueryConfig.py
index ef06f597..3a0ceef7 100644
--- a/embedchain/config/QueryConfig.py
+++ b/embedchain/config/QueryConfig.py
@@ -1,5 +1,6 @@
 import re
 from string import Template
+from typing import Optional
 
 from embedchain.config.BaseConfig import BaseConfig
 
@@ -63,6 +64,7 @@ class QueryConfig(BaseConfig):
         history=None,
         stream: bool = False,
         deployment_name=None,
+        system_prompt: Optional[str] = None,
     ):
         """
         Initializes the QueryConfig instance.
@@ -81,6 +83,8 @@ class QueryConfig(BaseConfig):
             diverse.
         :param history: Optional. A list of strings to consider as history.
         :param stream: Optional. Control if response is streamed back to user
+        :param deployment_name: Optional. The name of the model deployment (Azure OpenAI).
+        :param system_prompt: Optional. System prompt string. Unused if None.
         :raises ValueError: If the template is not valid as template should
             contain $context and $query (and optionally $history).
         """
@@ -108,6 +112,7 @@ class QueryConfig(BaseConfig):
         self.model = model
         self.top_p = top_p if top_p else 1
         self.deployment_name = deployment_name
+        self.system_prompt = system_prompt
         if self.validate_template(template):
             self.template = template
diff --git a/tests/embedchain/test_query.py b/tests/embedchain/test_query.py
index 5c47ce8c..da8f1b26 100644
--- a/tests/embedchain/test_query.py
+++ b/tests/embedchain/test_query.py
@@ -41,3 +41,20 @@ class TestApp(unittest.TestCase):
         self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
         self.assertIsInstance(mock_retrieve.call_args[0][1], QueryConfig)
         mock_answer.assert_called_once()
+
+    @patch("openai.ChatCompletion.create")
+    def test_query_config_passing(self, mock_create):
+        mock_create.return_value = {"choices": [{"message": {"content": "response"}}]}  # Mock response
+
+        config = AppConfig()
+        query_config = QueryConfig(system_prompt="Test system prompt")
+        app = App(config=config)
+
+        app.get_llm_model_answer("Test query", query_config)
+
+        # Test system_prompt: check that 'create' was called with the correct 'messages' argument
+        messages_arg = mock_create.call_args.kwargs["messages"]
+        self.assertEqual(messages_arg[0]["role"], "system")
+        self.assertEqual(messages_arg[0]["content"], "Test system prompt")
+
+        # TODO: Add tests for other config variables
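Taken together, callers opt in per query or chat by setting `system_prompt` on the config. An end-to-end sketch under stated assumptions (an OPENAI_API_KEY in the environment; the data source and the `add()` signature are illustrative and may differ across embedchain versions):

    from embedchain import App
    from embedchain.config import AppConfig, QueryConfig

    app = App(AppConfig())
    # Illustrative source; add() took (data_type, url) in embedchain of this era.
    app.add("web_page", "https://en.wikipedia.org/wiki/Albert_Einstein")

    query_config = QueryConfig(system_prompt="You are Albert Einstein.")
    print(app.query("Where did you complete your studies?", query_config))

Note that only App and the LangChain-backed CustomApp honor the new option; Llama2App and OpenSourceApp raise a ValueError when `system_prompt` is set, as shown in their diffs above.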