From cce6d5ddabd5d544a12e9d2d889daab478714368 Mon Sep 17 00:00:00 2001 From: aaishikdutta <107566376+aaishikdutta@users.noreply.github.com> Date: Wed, 26 Jul 2023 22:04:11 +0530 Subject: [PATCH] fix: Personapp not working with config (#368) --- embedchain/apps/PersonApp.py | 55 +++++++++++++++++++++++------------- 1 file changed, 35 insertions(+), 20 deletions(-) diff --git a/embedchain/apps/PersonApp.py b/embedchain/apps/PersonApp.py index 3a0378c1..a229c8db 100644 --- a/embedchain/apps/PersonApp.py +++ b/embedchain/apps/PersonApp.py @@ -22,6 +22,33 @@ class EmbedChainPersonApp: self.person_prompt = f"You are {person}. Whatever you say, you will always say in {person} style." # noqa:E501 super().__init__(config) + def add_person_template_to_config(self, default_prompt: str, config: ChatConfig = None): + """ + This method checks if the config object contains a prompt template + if yes it adds the person prompt to it and return the updated config + else it creates a config object with the default prompt added to the person prompt + + :param default_prompt: it is the default prompt for query or chat methods + :param config: Optional. The `ChatConfig` instance to use as + configuration options. + """ + template = Template(self.person_prompt + " " + default_prompt) + + if config: + if config.template: + # Add person prompt to custom user template + config.template = Template(self.person_prompt + " " + config.template.template) + else: + # If no user template is present, use person prompt with the default template + config.template = template + else: + # if no config is present at all, initialize the config with person prompt and default template + config = QueryConfig( + template=template, + ) + + return config + class PersonApp(EmbedChainPersonApp, App): """ @@ -30,18 +57,12 @@ class PersonApp(EmbedChainPersonApp, App): """ def query(self, input_query, config: QueryConfig = None, dry_run=False): - self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT) - query_config = QueryConfig( - template=self.template, - ) - return super().query(input_query, query_config, dry_run) + config = self.add_person_template_to_config(DEFAULT_PROMPT, config) + return super().query(input_query, config, dry_run) def chat(self, input_query, config: ChatConfig = None, dry_run=False): - self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT_WITH_HISTORY) - chat_config = ChatConfig( - template=self.template, - ) - return super().chat(input_query, chat_config, dry_run) + config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config) + return super().chat(input_query, config, dry_run) class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp): @@ -51,15 +72,9 @@ class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp): """ def query(self, input_query, config: QueryConfig = None, dry_run=False): - self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT) - query_config = QueryConfig( - template=self.template, - ) - return super().query(input_query, query_config, dry_run) + config = self.add_person_template_to_config(DEFAULT_PROMPT, config) + return super().query(input_query, config, dry_run) def chat(self, input_query, config: ChatConfig = None, dry_run=False): - self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT_WITH_HISTORY) - chat_config = ChatConfig( - template=self.template, - ) - return super().chat(input_query, chat_config, dry_run) + config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config) + return super().chat(input_query, config, dry_run)