fixed dry_run not working in PersonApp (#357)

This commit is contained in:
aaishikdutta
2023-07-22 11:29:20 +05:30
committed by GitHub
parent acbdb800d3
commit c9c56a4b26

View File

@@ -29,19 +29,19 @@ class PersonApp(EmbedChainPersonApp, App):
Extends functionality from EmbedChainPersonApp and App Extends functionality from EmbedChainPersonApp and App
""" """
def query(self, input_query, config: QueryConfig = None): def query(self, input_query, config: QueryConfig = None, dry_run=False):
self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT) self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT)
query_config = QueryConfig( query_config = QueryConfig(
template=self.template, template=self.template,
) )
return super().query(input_query, query_config) return super().query(input_query, query_config, dry_run)
def chat(self, input_query, config: ChatConfig = None): def chat(self, input_query, config: ChatConfig = None, dry_run=False):
self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT_WITH_HISTORY) self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT_WITH_HISTORY)
chat_config = ChatConfig( chat_config = ChatConfig(
template=self.template, template=self.template,
) )
return super().chat(input_query, chat_config) return super().chat(input_query, chat_config, dry_run)
class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp): class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp):
@@ -50,16 +50,16 @@ class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp):
Extends functionality from EmbedChainPersonApp and OpenSourceApp Extends functionality from EmbedChainPersonApp and OpenSourceApp
""" """
def query(self, input_query, config: QueryConfig = None): def query(self, input_query, config: QueryConfig = None, dry_run=False):
self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT) self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT)
query_config = QueryConfig( query_config = QueryConfig(
template=self.template, template=self.template,
) )
return super().query(input_query, query_config) return super().query(input_query, query_config, dry_run)
def chat(self, input_query, config: ChatConfig = None): def chat(self, input_query, config: ChatConfig = None, dry_run=False):
self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT_WITH_HISTORY) self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT_WITH_HISTORY)
chat_config = ChatConfig( chat_config = ChatConfig(
template=self.template, template=self.template,
) )
return super().chat(input_query, chat_config) return super().chat(input_query, chat_config, dry_run)