refactor: classes and configs (#528)

This commit is contained in:
cachho
2023-09-05 10:12:58 +02:00
committed by GitHub
parent 387b042a49
commit 344e7470f6
50 changed files with 1221 additions and 997 deletions

View File

@@ -1,12 +1,11 @@
import logging
from typing import List, Optional
from typing import Optional
from langchain.schema import BaseMessage
from embedchain.config import ChatConfig, CustomAppConfig
from embedchain.config import CustomAppConfig
from embedchain.embedchain import EmbedChain
from embedchain.embedder.base_embedder import BaseEmbedder
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.models import Providers
from embedchain.llm.base_llm import BaseLlm
from embedchain.vectordb.base_vector_db import BaseVectorDB
@register_deserializable
@@ -20,143 +19,49 @@ class CustomApp(EmbedChain):
dry_run(query): test your prompt without consuming tokens.
"""
def __init__(self, config: CustomAppConfig = None, system_prompt: Optional[str] = None):
def __init__(
self,
config: CustomAppConfig = None,
llm: BaseLlm = None,
db: BaseVectorDB = None,
embedder: BaseEmbedder = None,
system_prompt: Optional[str] = None,
):
"""
:param config: Optional. `CustomAppConfig` instance to load as configuration.
:raises ValueError: Config must be provided for custom app
:param system_prompt: Optional. System prompt string.
"""
# Config is not required, it has a default
if config is None:
raise ValueError("Config must be provided for custom app")
config = CustomAppConfig()
self.provider = config.provider
if llm is None:
raise ValueError("LLM must be provided for custom app. Please import from `embedchain.llm`.")
if db is None:
raise ValueError("Database must be provided for custom app. Please import from `embedchain.vectordb`.")
if embedder is None:
raise ValueError("Embedder must be provided for custom app. Please import from `embedchain.embedder`.")
if config.provider == Providers.GPT4ALL:
from embedchain import OpenSourceApp
# Because these models run locally, they should have an instance running when the custom app is created
self.open_source_app = OpenSourceApp(config=config.open_source_app_config)
super().__init__(config, system_prompt)
def set_llm_model(self, provider: Providers):
self.provider = provider
if provider == Providers.GPT4ALL:
raise ValueError(
"GPT4ALL needs to be instantiated with the model known, please create a new app instance instead"
if not isinstance(config, CustomAppConfig):
raise TypeError(
"Config is not a `CustomAppConfig` instance. "
"Please make sure the type is right and that you are passing an instance."
)
if not isinstance(llm, BaseLlm):
raise TypeError(
"LLM is not a `BaseLlm` instance. "
"Please make sure the type is right and that you are passing an instance."
)
if not isinstance(db, BaseVectorDB):
raise TypeError(
"Database is not a `BaseVectorDB` instance. "
"Please make sure the type is right and that you are passing an instance."
)
if not isinstance(embedder, BaseEmbedder):
raise TypeError(
"Embedder is not a `BaseEmbedder` instance. "
"Please make sure the type is right and that you are passing an instance."
)
def get_llm_model_answer(self, prompt, config: ChatConfig):
# TODO: Quitting the streaming response here for now.
# Idea: https://gist.github.com/jvelezmagic/03ddf4c452d011aae36b2a0f73d72f68
if config.stream:
raise NotImplementedError(
"Streaming responses have not been implemented for this model yet. Please disable."
)
if config.system_prompt is None and self.system_prompt is not None:
config.system_prompt = self.system_prompt
try:
if self.provider == Providers.OPENAI:
return CustomApp._get_openai_answer(prompt, config)
if self.provider == Providers.ANTHROPHIC:
return CustomApp._get_athrophic_answer(prompt, config)
if self.provider == Providers.VERTEX_AI:
return CustomApp._get_vertex_answer(prompt, config)
if self.provider == Providers.GPT4ALL:
return self.open_source_app._get_gpt4all_answer(prompt, config)
if self.provider == Providers.AZURE_OPENAI:
return CustomApp._get_azure_openai_answer(prompt, config)
except ImportError as e:
raise ModuleNotFoundError(e.msg) from None
@staticmethod
def _get_openai_answer(prompt: str, config: ChatConfig) -> str:
from langchain.chat_models import ChatOpenAI
chat = ChatOpenAI(
temperature=config.temperature,
model=config.model or "gpt-3.5-turbo",
max_tokens=config.max_tokens,
streaming=config.stream,
)
if config.top_p and config.top_p != 1:
logging.warning("Config option `top_p` is not supported by this model.")
messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)
return chat(messages).content
@staticmethod
def _get_athrophic_answer(prompt: str, config: ChatConfig) -> str:
from langchain.chat_models import ChatAnthropic
chat = ChatAnthropic(temperature=config.temperature, model=config.model)
if config.max_tokens and config.max_tokens != 1000:
logging.warning("Config option `max_tokens` is not supported by this model.")
messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)
return chat(messages).content
@staticmethod
def _get_vertex_answer(prompt: str, config: ChatConfig) -> str:
from langchain.chat_models import ChatVertexAI
chat = ChatVertexAI(temperature=config.temperature, model=config.model, max_output_tokens=config.max_tokens)
if config.top_p and config.top_p != 1:
logging.warning("Config option `top_p` is not supported by this model.")
messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)
return chat(messages).content
@staticmethod
def _get_azure_openai_answer(prompt: str, config: ChatConfig) -> str:
from langchain.chat_models import AzureChatOpenAI
if not config.deployment_name:
raise ValueError("Deployment name must be provided for Azure OpenAI")
chat = AzureChatOpenAI(
deployment_name=config.deployment_name,
openai_api_version="2023-05-15",
model_name=config.model or "gpt-3.5-turbo",
temperature=config.temperature,
max_tokens=config.max_tokens,
streaming=config.stream,
)
if config.top_p and config.top_p != 1:
logging.warning("Config option `top_p` is not supported by this model.")
messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)
return chat(messages).content
@staticmethod
def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
from langchain.schema import HumanMessage, SystemMessage
messages = []
if system_prompt:
messages.append(SystemMessage(content=system_prompt))
messages.append(HumanMessage(content=prompt))
return messages
def _stream_llm_model_response(self, response):
"""
This is a generator for streaming response from the OpenAI completions API
"""
for line in response:
chunk = line["choices"][0].get("delta", {}).get("content", "")
yield chunk
super().__init__(config=config, llm=llm, db=db, embedder=embedder, system_prompt=system_prompt)