fix: Personapp not working with config (#368)
This commit is contained in:
@@ -22,6 +22,33 @@ class EmbedChainPersonApp:
|
|||||||
self.person_prompt = f"You are {person}. Whatever you say, you will always say in {person} style." # noqa:E501
|
self.person_prompt = f"You are {person}. Whatever you say, you will always say in {person} style." # noqa:E501
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
|
def add_person_template_to_config(self, default_prompt: str, config: ChatConfig = None):
|
||||||
|
"""
|
||||||
|
This method checks if the config object contains a prompt template
|
||||||
|
if yes it adds the person prompt to it and return the updated config
|
||||||
|
else it creates a config object with the default prompt added to the person prompt
|
||||||
|
|
||||||
|
:param default_prompt: it is the default prompt for query or chat methods
|
||||||
|
:param config: Optional. The `ChatConfig` instance to use as
|
||||||
|
configuration options.
|
||||||
|
"""
|
||||||
|
template = Template(self.person_prompt + " " + default_prompt)
|
||||||
|
|
||||||
|
if config:
|
||||||
|
if config.template:
|
||||||
|
# Add person prompt to custom user template
|
||||||
|
config.template = Template(self.person_prompt + " " + config.template.template)
|
||||||
|
else:
|
||||||
|
# If no user template is present, use person prompt with the default template
|
||||||
|
config.template = template
|
||||||
|
else:
|
||||||
|
# if no config is present at all, initialize the config with person prompt and default template
|
||||||
|
config = QueryConfig(
|
||||||
|
template=template,
|
||||||
|
)
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
class PersonApp(EmbedChainPersonApp, App):
|
class PersonApp(EmbedChainPersonApp, App):
|
||||||
"""
|
"""
|
||||||
@@ -30,18 +57,12 @@ class PersonApp(EmbedChainPersonApp, App):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def query(self, input_query, config: QueryConfig = None, dry_run=False):
|
def query(self, input_query, config: QueryConfig = None, dry_run=False):
|
||||||
self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT)
|
config = self.add_person_template_to_config(DEFAULT_PROMPT, config)
|
||||||
query_config = QueryConfig(
|
return super().query(input_query, config, dry_run)
|
||||||
template=self.template,
|
|
||||||
)
|
|
||||||
return super().query(input_query, query_config, dry_run)
|
|
||||||
|
|
||||||
def chat(self, input_query, config: ChatConfig = None, dry_run=False):
|
def chat(self, input_query, config: ChatConfig = None, dry_run=False):
|
||||||
self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT_WITH_HISTORY)
|
config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config)
|
||||||
chat_config = ChatConfig(
|
return super().chat(input_query, config, dry_run)
|
||||||
template=self.template,
|
|
||||||
)
|
|
||||||
return super().chat(input_query, chat_config, dry_run)
|
|
||||||
|
|
||||||
|
|
||||||
class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp):
|
class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp):
|
||||||
@@ -51,15 +72,9 @@ class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def query(self, input_query, config: QueryConfig = None, dry_run=False):
|
def query(self, input_query, config: QueryConfig = None, dry_run=False):
|
||||||
self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT)
|
config = self.add_person_template_to_config(DEFAULT_PROMPT, config)
|
||||||
query_config = QueryConfig(
|
return super().query(input_query, config, dry_run)
|
||||||
template=self.template,
|
|
||||||
)
|
|
||||||
return super().query(input_query, query_config, dry_run)
|
|
||||||
|
|
||||||
def chat(self, input_query, config: ChatConfig = None, dry_run=False):
|
def chat(self, input_query, config: ChatConfig = None, dry_run=False):
|
||||||
self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT_WITH_HISTORY)
|
config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config)
|
||||||
chat_config = ChatConfig(
|
return super().chat(input_query, config, dry_run)
|
||||||
template=self.template,
|
|
||||||
)
|
|
||||||
return super().chat(input_query, chat_config, dry_run)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user