refactor: classes and configs (#528)

This commit is contained in:
cachho
2023-09-05 10:12:58 +02:00
committed by GitHub
parent 387b042a49
commit 344e7470f6
50 changed files with 1221 additions and 997 deletions

View File

@@ -2,9 +2,10 @@ from string import Template
from embedchain.apps.App import App
from embedchain.apps.OpenSourceApp import OpenSourceApp
from embedchain.config import ChatConfig, QueryConfig
from embedchain.config import BaseLlmConfig
from embedchain.config.apps.BaseAppConfig import BaseAppConfig
from embedchain.config.QueryConfig import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY
from embedchain.config.llm.base_llm_config import (DEFAULT_PROMPT,
DEFAULT_PROMPT_WITH_HISTORY)
from embedchain.helper_classes.json_serializable import register_deserializable
@@ -23,7 +24,7 @@ class EmbedChainPersonApp:
self.person_prompt = f"You are {person}. Whatever you say, you will always say in {person} style." # noqa:E501
super().__init__(config)
def add_person_template_to_config(self, default_prompt: str, config: ChatConfig = None):
def add_person_template_to_config(self, default_prompt: str, config: BaseLlmConfig = None):
"""
This method checks if the config object contains a prompt template
if yes it adds the person prompt to it and return the updated config
@@ -44,7 +45,7 @@ class EmbedChainPersonApp:
config.template = template
else:
# if no config is present at all, initialize the config with person prompt and default template
config = QueryConfig(
config = BaseLlmConfig(
template=template,
)
@@ -58,11 +59,11 @@ class PersonApp(EmbedChainPersonApp, App):
Extends functionality from EmbedChainPersonApp and App
"""
def query(self, input_query, config: QueryConfig = None, dry_run=False):
def query(self, input_query, config: BaseLlmConfig = None, dry_run=False):
config = self.add_person_template_to_config(DEFAULT_PROMPT, config, where=None)
return super().query(input_query, config, dry_run, where=None)
def chat(self, input_query, config: ChatConfig = None, dry_run=False, where=None):
def chat(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None):
config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config)
return super().chat(input_query, config, dry_run, where)
@@ -74,10 +75,10 @@ class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp):
Extends functionality from EmbedChainPersonApp and OpenSourceApp
"""
def query(self, input_query, config: QueryConfig = None, dry_run=False):
def query(self, input_query, config: BaseLlmConfig = None, dry_run=False):
config = self.add_person_template_to_config(DEFAULT_PROMPT, config)
return super().query(input_query, config, dry_run)
def chat(self, input_query, config: ChatConfig = None, dry_run=False):
def chat(self, input_query, config: BaseLlmConfig = None, dry_run=False):
config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config)
return super().chat(input_query, config, dry_run)