System prompt at App level (#484)

Co-authored-by: Taranjeet Singh <reachtotj@gmail.com>
Authored by Dev Khant, 2023-09-04 00:55:43 +05:30, committed by GitHub
parent 9f1f17a611
commit ec9f454ad1
6 changed files with 50 additions and 16 deletions

View File

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import openai

 from embedchain.config import AppConfig, ChatConfig
@@ -14,19 +16,27 @@ class App(EmbedChain):
     dry_run(query): test your prompt without consuming tokens.
     """

-    def __init__(self, config: AppConfig = None):
+    def __init__(self, config: AppConfig = None, system_prompt: Optional[str] = None):
         """
         :param config: AppConfig instance to load as configuration. Optional.
+        :param system_prompt: System prompt string. Optional.
         """
         if config is None:
             config = AppConfig()

-        super().__init__(config)
+        super().__init__(config, system_prompt)

     def get_llm_model_answer(self, prompt, config: ChatConfig):
         messages = []
-        if config.system_prompt:
-            messages.append({"role": "system", "content": config.system_prompt})
+        system_prompt = (
+            self.system_prompt
+            if self.system_prompt is not None
+            else config.system_prompt
+            if config.system_prompt is not None
+            else None
+        )
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
         messages.append({"role": "user", "content": prompt})
         response = openai.ChatCompletion.create(
             model=config.model or "gpt-3.5-turbo-0613",
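
With this change a system prompt can be set once on the App itself instead of on every ChatConfig; when both are present, the app-level value takes precedence (see the resolution in get_llm_model_answer above). A minimal usage sketch, assuming an OpenAI API key is configured and data is added as usual:

from embedchain import App
from embedchain.config import AppConfig

# The app-level prompt is sent as the "system" message on every OpenAI chat call.
app = App(config=AppConfig(), system_prompt="You are a concise technical assistant.")

# Per-query configs may still carry their own system_prompt, but it is only used
# when no app-level prompt was given.
answer = app.query("What does the App class do?")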

View File

@@ -18,10 +18,11 @@ class CustomApp(EmbedChain):
     dry_run(query): test your prompt without consuming tokens.
     """

-    def __init__(self, config: CustomAppConfig = None):
+    def __init__(self, config: CustomAppConfig = None, system_prompt: Optional[str] = None):
         """
         :param config: Optional. `CustomAppConfig` instance to load as configuration.
         :raises ValueError: Config must be provided for custom app
+        :param system_prompt: Optional. System prompt string.
         """
         if config is None:
             raise ValueError("Config must be provided for custom app")
@@ -34,7 +35,7 @@ class CustomApp(EmbedChain):
             # Because these models run locally, they should have an instance running when the custom app is created
             self.open_source_app = OpenSourceApp(config=config.open_source_app_config)

-        super().__init__(config)
+        super().__init__(config, system_prompt)

     def set_llm_model(self, provider: Providers):
         self.provider = provider
@@ -51,6 +52,9 @@ class CustomApp(EmbedChain):
                 "Streaming responses have not been implemented for this model yet. Please disable."
             )

+        if config.system_prompt is None and self.system_prompt is not None:
+            config.system_prompt = self.system_prompt
+
         try:
             if self.provider == Providers.OPENAI:
                 return CustomApp._get_openai_answer(prompt, config)

View File

@@ -1,4 +1,5 @@
 import os
+from typing import Optional

 from langchain.llms import Replicate

@@ -15,9 +16,10 @@ class Llama2App(EmbedChain):
     query(query): finds answer to the given query using vector database and LLM.
     """

-    def __init__(self, config: AppConfig = None):
+    def __init__(self, config: AppConfig = None, system_prompt: Optional[str] = None):
         """
         :param config: AppConfig instance to load as configuration. Optional.
+        :param system_prompt: System prompt string. Optional.
         """
         if "REPLICATE_API_TOKEN" not in os.environ:
             raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.")
@@ -25,11 +27,11 @@ class Llama2App(EmbedChain):
         if config is None:
             config = AppConfig()

-        super().__init__(config)
+        super().__init__(config, system_prompt)

     def get_llm_model_answer(self, prompt, config: ChatConfig = None):
         # TODO: Move the model and other inputs into config
-        if config.system_prompt:
+        if self.system_prompt or config.system_prompt:
             raise ValueError("Llama2App does not support `system_prompt`")
         llm = Replicate(
             model="a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5",

View File

@@ -1,5 +1,5 @@
 import logging
-from typing import Iterable, Union
+from typing import Iterable, Union, Optional

 from embedchain.config import ChatConfig, OpenSourceAppConfig
 from embedchain.embedchain import EmbedChain
@@ -18,10 +18,11 @@ class OpenSourceApp(EmbedChain):
     query(query): finds answer to the given query using vector database and LLM.
     """

-    def __init__(self, config: OpenSourceAppConfig = None):
+    def __init__(self, config: OpenSourceAppConfig = None, system_prompt: Optional[str] = None):
         """
         :param config: OpenSourceAppConfig instance to load as configuration. Optional.
         `ef` defaults to open source.
+        :param system_prompt: System prompt string. Optional.
         """
         logging.info("Loading open source embedding model. This may take some time...")  # noqa:E501
         if not config:
@@ -33,7 +34,7 @@ class OpenSourceApp(EmbedChain):
         self.instance = OpenSourceApp._get_instance(config.model)
         logging.info("Successfully loaded open source embedding model.")

-        super().__init__(config)
+        super().__init__(config, system_prompt)

     def get_llm_model_answer(self, prompt, config: ChatConfig):
         return self._get_gpt4all_answer(prompt=prompt, config=config)
@@ -55,7 +56,7 @@ class OpenSourceApp(EmbedChain):
                 "OpenSourceApp does not support switching models at runtime. Please create a new app instance."
             )

-        if config.system_prompt:
+        if self.system_prompt or config.system_prompt:
             raise ValueError("OpenSourceApp does not support `system_prompt`")

         response = self.instance.generate(

View File

@@ -33,15 +33,17 @@ CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")

 class EmbedChain:
-    def __init__(self, config: BaseAppConfig):
+    def __init__(self, config: BaseAppConfig, system_prompt: Optional[str] = None):
         """
         Initializes the EmbedChain instance, sets up a vector DB client and
         creates a collection.

         :param config: BaseAppConfig instance to load as configuration.
+        :param system_prompt: Optional. System prompt string.
         """
         self.config = config
+        self.system_prompt = system_prompt
         self.collection = self.config.db._get_or_create_collection(self.config.collection_name)
         self.db = self.config.db
         self.user_asks = []
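
Since the base class now only stores the prompt, each app type decides how to apply self.system_prompt against the per-query config. A small standalone sketch of the resolution rule the OpenAI-backed App uses (illustrative helper name, not part of this commit):

from typing import Optional

def resolve_system_prompt(app_prompt: Optional[str], config_prompt: Optional[str]) -> Optional[str]:
    # App-level prompt wins; otherwise fall back to the per-query config; else send no system message.
    return app_prompt if app_prompt is not None else config_prompt

assert resolve_system_prompt("app-level", "query-level") == "app-level"
assert resolve_system_prompt(None, "query-level") == "query-level"
assert resolve_system_prompt(None, None) is None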

View File

@@ -43,7 +43,7 @@ class TestApp(unittest.TestCase):
         mock_answer.assert_called_once()

     @patch("openai.ChatCompletion.create")
-    def test_query_config_passing(self, mock_create):
+    def test_query_config_app_passing(self, mock_create):
         mock_create.return_value = {"choices": [{"message": {"content": "response"}}]}  # Mock response

         config = AppConfig()
@@ -52,9 +52,24 @@
         app.get_llm_model_answer("Test query", chat_config)

-        # Test systemp_prompt: Check that the 'create' method was called with the correct 'messages' argument
+        # Test system_prompt: Check that the 'create' method was called with the correct 'messages' argument
         messages_arg = mock_create.call_args.kwargs["messages"]
         self.assertEqual(messages_arg[0]["role"], "system")
         self.assertEqual(messages_arg[0]["content"], "Test system prompt")

         # TODO: Add tests for other config variables
+
+    @patch("openai.ChatCompletion.create")
+    def test_app_passing(self, mock_create):
+        mock_create.return_value = {"choices": [{"message": {"content": "response"}}]}  # Mock response
+
+        config = AppConfig()
+        chat_config = QueryConfig()
+        app = App(config=config, system_prompt="Test system prompt")
+
+        app.get_llm_model_answer("Test query", chat_config)
+
+        # Test system_prompt: Check that the 'create' method was called with the correct 'messages' argument
+        messages_arg = mock_create.call_args.kwargs["messages"]
+        self.assertEqual(messages_arg[0]["role"], "system")
+        self.assertEqual(messages_arg[0]["content"], "Test system prompt")