refactor: classes and configs (#528)
@@ -1,10 +1,12 @@
from typing import Optional

-import openai
-
-from embedchain.config import AppConfig, ChatConfig
+from embedchain.config import (AppConfig, BaseEmbedderConfig, BaseLlmConfig,
+                               ChromaDbConfig)
from embedchain.embedchain import EmbedChain
+from embedchain.embedder.openai_embedder import OpenAiEmbedder
from embedchain.helper_classes.json_serializable import register_deserializable
+from embedchain.llm.openai_llm import OpenAiLlm
+from embedchain.vectordb.chroma_db import ChromaDB


@register_deserializable
@@ -18,7 +20,13 @@ class App(EmbedChain):
    dry_run(query): test your prompt without consuming tokens.
    """

-    def __init__(self, config: AppConfig = None, system_prompt: Optional[str] = None):
+    def __init__(
+        self,
+        config: AppConfig = None,
+        llm_config: BaseLlmConfig = None,
+        chromadb_config: Optional[ChromaDbConfig] = None,
+        system_prompt: Optional[str] = None,
+    ):
        """
        :param config: AppConfig instance to load as configuration. Optional.
        :param system_prompt: System prompt string. Optional.
@@ -26,38 +34,8 @@ class App(EmbedChain):
        if config is None:
            config = AppConfig()

-        super().__init__(config, system_prompt)
+        llm = OpenAiLlm(config=llm_config)
+        embedder = OpenAiEmbedder(config=BaseEmbedderConfig(model="text-embedding-ada-002"))
+        database = ChromaDB(config=chromadb_config)

-    def get_llm_model_answer(self, prompt, config: ChatConfig):
-        messages = []
-        system_prompt = (
-            self.system_prompt
-            if self.system_prompt is not None
-            else config.system_prompt
-            if config.system_prompt is not None
-            else None
-        )
-        if system_prompt:
-            messages.append({"role": "system", "content": system_prompt})
-        messages.append({"role": "user", "content": prompt})
-        response = openai.ChatCompletion.create(
-            model=config.model or "gpt-3.5-turbo-0613",
-            messages=messages,
-            temperature=config.temperature,
-            max_tokens=config.max_tokens,
-            top_p=config.top_p,
-            stream=config.stream,
-        )
-
-        if config.stream:
-            return self._stream_llm_model_response(response)
-        else:
-            return response["choices"][0]["message"]["content"]
-
-    def _stream_llm_model_response(self, response):
-        """
-        This is a generator for streaming response from the OpenAI completions API
-        """
-        for line in response:
-            chunk = line["choices"][0].get("delta", {}).get("content", "")
-            yield chunk
+        super().__init__(config, llm, db=database, embedder=embedder, system_prompt=system_prompt)
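Usage sketch (not part of the diff): with the refactored `App`, LLM and vector-db options are passed in as config objects instead of being handled inside the class. The snippet below is a minimal example based only on the signature above; the URL and config values are placeholders, and `add`/`query` are the pre-existing EmbedChain API.

from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig

app = App(
    config=AppConfig(),
    llm_config=BaseLlmConfig(),  # forwarded to OpenAiLlm; assumed to carry the old ChatConfig-style options
    system_prompt="You are a helpful assistant.",
)
app.add("https://example.com/article")  # placeholder source
print(app.query("What is the article about?"))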
@@ -1,12 +1,11 @@
-import logging
-from typing import List, Optional
+from typing import Optional

-from langchain.schema import BaseMessage
-
-from embedchain.config import ChatConfig, CustomAppConfig
+from embedchain.config import CustomAppConfig
from embedchain.embedchain import EmbedChain
+from embedchain.embedder.base_embedder import BaseEmbedder
from embedchain.helper_classes.json_serializable import register_deserializable
-from embedchain.models import Providers
+from embedchain.llm.base_llm import BaseLlm
+from embedchain.vectordb.base_vector_db import BaseVectorDB


@register_deserializable
@@ -20,143 +19,49 @@ class CustomApp(EmbedChain):
    dry_run(query): test your prompt without consuming tokens.
    """

-    def __init__(self, config: CustomAppConfig = None, system_prompt: Optional[str] = None):
+    def __init__(
+        self,
+        config: CustomAppConfig = None,
+        llm: BaseLlm = None,
+        db: BaseVectorDB = None,
+        embedder: BaseEmbedder = None,
+        system_prompt: Optional[str] = None,
+    ):
        """
        :param config: Optional. `CustomAppConfig` instance to load as configuration.
        :raises ValueError: Config must be provided for custom app
        :param system_prompt: Optional. System prompt string.
        """
+        # Config is not required, it has a default
        if config is None:
-            raise ValueError("Config must be provided for custom app")
+            config = CustomAppConfig()

-        self.provider = config.provider
+        if llm is None:
+            raise ValueError("LLM must be provided for custom app. Please import from `embedchain.llm`.")
+        if db is None:
+            raise ValueError("Database must be provided for custom app. Please import from `embedchain.vectordb`.")
+        if embedder is None:
+            raise ValueError("Embedder must be provided for custom app. Please import from `embedchain.embedder`.")

-        if config.provider == Providers.GPT4ALL:
-            from embedchain import OpenSourceApp
-
-            # Because these models run locally, they should have an instance running when the custom app is created
-            self.open_source_app = OpenSourceApp(config=config.open_source_app_config)
-
-        super().__init__(config, system_prompt)
-
-    def set_llm_model(self, provider: Providers):
-        self.provider = provider
-        if provider == Providers.GPT4ALL:
-            raise ValueError(
-                "GPT4ALL needs to be instantiated with the model known, please create a new app instance instead"
-            )
+        if not isinstance(config, CustomAppConfig):
+            raise TypeError(
+                "Config is not a `CustomAppConfig` instance. "
+                "Please make sure the type is right and that you are passing an instance."
+            )
+        if not isinstance(llm, BaseLlm):
+            raise TypeError(
+                "LLM is not a `BaseLlm` instance. "
+                "Please make sure the type is right and that you are passing an instance."
+            )
+        if not isinstance(db, BaseVectorDB):
+            raise TypeError(
+                "Database is not a `BaseVectorDB` instance. "
+                "Please make sure the type is right and that you are passing an instance."
+            )
+        if not isinstance(embedder, BaseEmbedder):
+            raise TypeError(
+                "Embedder is not a `BaseEmbedder` instance. "
+                "Please make sure the type is right and that you are passing an instance."
+            )

-    def get_llm_model_answer(self, prompt, config: ChatConfig):
-        # TODO: Quitting the streaming response here for now.
-        # Idea: https://gist.github.com/jvelezmagic/03ddf4c452d011aae36b2a0f73d72f68
-        if config.stream:
-            raise NotImplementedError(
-                "Streaming responses have not been implemented for this model yet. Please disable."
-            )
-
-        if config.system_prompt is None and self.system_prompt is not None:
-            config.system_prompt = self.system_prompt
-
-        try:
-            if self.provider == Providers.OPENAI:
-                return CustomApp._get_openai_answer(prompt, config)
-
-            if self.provider == Providers.ANTHROPHIC:
-                return CustomApp._get_athrophic_answer(prompt, config)
-
-            if self.provider == Providers.VERTEX_AI:
-                return CustomApp._get_vertex_answer(prompt, config)
-
-            if self.provider == Providers.GPT4ALL:
-                return self.open_source_app._get_gpt4all_answer(prompt, config)
-
-            if self.provider == Providers.AZURE_OPENAI:
-                return CustomApp._get_azure_openai_answer(prompt, config)
-
-        except ImportError as e:
-            raise ModuleNotFoundError(e.msg) from None
-
-    @staticmethod
-    def _get_openai_answer(prompt: str, config: ChatConfig) -> str:
-        from langchain.chat_models import ChatOpenAI
-
-        chat = ChatOpenAI(
-            temperature=config.temperature,
-            model=config.model or "gpt-3.5-turbo",
-            max_tokens=config.max_tokens,
-            streaming=config.stream,
-        )
-
-        if config.top_p and config.top_p != 1:
-            logging.warning("Config option `top_p` is not supported by this model.")
-
-        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)
-
-        return chat(messages).content
-
-    @staticmethod
-    def _get_athrophic_answer(prompt: str, config: ChatConfig) -> str:
-        from langchain.chat_models import ChatAnthropic
-
-        chat = ChatAnthropic(temperature=config.temperature, model=config.model)
-
-        if config.max_tokens and config.max_tokens != 1000:
-            logging.warning("Config option `max_tokens` is not supported by this model.")
-
-        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)
-
-        return chat(messages).content
-
-    @staticmethod
-    def _get_vertex_answer(prompt: str, config: ChatConfig) -> str:
-        from langchain.chat_models import ChatVertexAI
-
-        chat = ChatVertexAI(temperature=config.temperature, model=config.model, max_output_tokens=config.max_tokens)
-
-        if config.top_p and config.top_p != 1:
-            logging.warning("Config option `top_p` is not supported by this model.")
-
-        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)
-
-        return chat(messages).content
-
-    @staticmethod
-    def _get_azure_openai_answer(prompt: str, config: ChatConfig) -> str:
-        from langchain.chat_models import AzureChatOpenAI
-
-        if not config.deployment_name:
-            raise ValueError("Deployment name must be provided for Azure OpenAI")
-
-        chat = AzureChatOpenAI(
-            deployment_name=config.deployment_name,
-            openai_api_version="2023-05-15",
-            model_name=config.model or "gpt-3.5-turbo",
-            temperature=config.temperature,
-            max_tokens=config.max_tokens,
-            streaming=config.stream,
-        )
-
-        if config.top_p and config.top_p != 1:
-            logging.warning("Config option `top_p` is not supported by this model.")
-
-        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)
-
-        return chat(messages).content
-
-    @staticmethod
-    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
-        from langchain.schema import HumanMessage, SystemMessage
-
-        messages = []
-        if system_prompt:
-            messages.append(SystemMessage(content=system_prompt))
-        messages.append(HumanMessage(content=prompt))
-        return messages
-
-    def _stream_llm_model_response(self, response):
-        """
-        This is a generator for streaming response from the OpenAI completions API
-        """
-        for line in response:
-            chunk = line["choices"][0].get("delta", {}).get("content", "")
-            yield chunk
+        super().__init__(config=config, llm=llm, db=db, embedder=embedder, system_prompt=system_prompt)
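Usage sketch (not part of the diff): `CustomApp` no longer switches on a `Providers` enum; the caller now injects concrete `BaseLlm`, `BaseVectorDB` and `BaseEmbedder` instances. A minimal wiring, assuming the OpenAI/Chroma classes imported in the `App` hunk above accept their default (no-argument) constructors:

from embedchain.apps.CustomApp import CustomApp
from embedchain.config import CustomAppConfig
from embedchain.embedder.openai_embedder import OpenAiEmbedder
from embedchain.llm.openai_llm import OpenAiLlm
from embedchain.vectordb.chroma_db import ChromaDB

# llm, db and embedder are required: None raises ValueError, a wrong type raises TypeError.
app = CustomApp(
    config=CustomAppConfig(),
    llm=OpenAiLlm(),
    db=ChromaDB(),
    embedder=OpenAiEmbedder(),
)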
@@ -1,13 +1,15 @@
import os
from typing import Optional

-from langchain.llms import Replicate
-
-from embedchain.config import AppConfig, ChatConfig
-from embedchain.embedchain import EmbedChain
+from embedchain.apps.CustomApp import CustomApp
+from embedchain.config import CustomAppConfig
+from embedchain.embedder.openai_embedder import OpenAiEmbedder
from embedchain.helper_classes.json_serializable import register_deserializable
+from embedchain.llm.llama2_llm import Llama2Llm
+from embedchain.vectordb.chroma_db import ChromaDB


-class Llama2App(EmbedChain):
+@register_deserializable
+class Llama2App(CustomApp):
    """
    The EmbedChain Llama2App class.
    Has two functions: add and query.
@@ -16,25 +18,15 @@ class Llama2App(EmbedChain):
    query(query): finds answer to the given query using vector database and LLM.
    """

-    def __init__(self, config: AppConfig = None, system_prompt: Optional[str] = None):
+    def __init__(self, config: CustomAppConfig = None, system_prompt: Optional[str] = None):
        """
-        :param config: AppConfig instance to load as configuration. Optional.
+        :param config: CustomAppConfig instance to load as configuration. Optional.
        :param system_prompt: System prompt string. Optional.
        """
-        if "REPLICATE_API_TOKEN" not in os.environ:
-            raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.")
-
        if config is None:
-            config = AppConfig()
+            config = CustomAppConfig()

-        super().__init__(config, system_prompt)
-
-    def get_llm_model_answer(self, prompt, config: ChatConfig = None):
-        # TODO: Move the model and other inputs into config
-        if self.system_prompt or config.system_prompt:
-            raise ValueError("Llama2App does not support `system_prompt`")
-        llm = Replicate(
-            model="a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5",
-            input={"temperature": 0.75, "max_length": 500, "top_p": 1},
-        )
-        return llm(prompt)
+        super().__init__(
+            config=config, llm=Llama2Llm(), db=ChromaDB(), embedder=OpenAiEmbedder(), system_prompt=system_prompt
+        )
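Usage sketch (not part of the diff): `Llama2App` is now just a `CustomApp` pre-wired with `Llama2Llm`, `ChromaDB` and `OpenAiEmbedder`. The import path below is assumed to mirror `embedchain.apps.CustomApp`, the token value is a placeholder, and the `REPLICATE_API_TOKEN` requirement is presumed to move into `Llama2Llm` along with the check removed above.

import os

from embedchain.apps.Llama2App import Llama2App  # assumed path

os.environ["REPLICATE_API_TOKEN"] = "<your-replicate-token>"  # placeholder; required by the Replicate-backed LLM
app = Llama2App()  # config defaults to CustomAppConfig()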
@@ -1,9 +1,13 @@
import logging
-from typing import Iterable, Optional, Union
+from typing import Optional

-from embedchain.config import ChatConfig, OpenSourceAppConfig
+from embedchain.config import (BaseEmbedderConfig, BaseLlmConfig,
+                               ChromaDbConfig, OpenSourceAppConfig)
from embedchain.embedchain import EmbedChain
+from embedchain.embedder.gpt4all_embedder import GPT4AllEmbedder
from embedchain.helper_classes.json_serializable import register_deserializable
+from embedchain.llm.gpt4all_llm import GPT4ALLLlm
+from embedchain.vectordb.chroma_db import ChromaDB

gpt4all_model = None

@@ -20,7 +24,12 @@ class OpenSourceApp(EmbedChain):
    query(query): finds answer to the given query using vector database and LLM.
    """

-    def __init__(self, config: OpenSourceAppConfig = None, system_prompt: Optional[str] = None):
+    def __init__(
+        self,
+        config: OpenSourceAppConfig = None,
+        chromadb_config: Optional[ChromaDbConfig] = None,
+        system_prompt: Optional[str] = None,
+    ):
        """
        :param config: OpenSourceAppConfig instance to load as configuration. Optional.
        `ef` defaults to open source.
@@ -30,42 +39,19 @@ class OpenSourceApp(EmbedChain):
        if not config:
            config = OpenSourceAppConfig()

        if not isinstance(config, OpenSourceAppConfig):
            raise ValueError(
                "OpenSourceApp needs a OpenSourceAppConfig passed to it. "
                "You can import it with `from embedchain.config import OpenSourceAppConfig`"
            )

        if not config.model:
            raise ValueError("OpenSourceApp needs a model to be instantiated. Maybe you passed the wrong config type?")

-        self.instance = OpenSourceApp._get_instance(config.model)

        logging.info("Successfully loaded open source embedding model.")
-        super().__init__(config, system_prompt)
-
-    def get_llm_model_answer(self, prompt, config: ChatConfig):
-        return self._get_gpt4all_answer(prompt=prompt, config=config)
+        llm = GPT4ALLLlm(config=BaseLlmConfig(model="orca-mini-3b.ggmlv3.q4_0.bin"))
+        embedder = GPT4AllEmbedder(config=BaseEmbedderConfig(model="all-MiniLM-L6-v2"))
+        database = ChromaDB(config=chromadb_config)

-    @staticmethod
-    def _get_instance(model):
-        try:
-            from gpt4all import GPT4All
-        except ModuleNotFoundError:
-            raise ModuleNotFoundError(
-                "The GPT4All python package is not installed. Please install it with `pip install embedchain[opensource]`"  # noqa E501
-            ) from None
-
-        return GPT4All(model)
-
-    def _get_gpt4all_answer(self, prompt: str, config: ChatConfig) -> Union[str, Iterable]:
-        if config.model and config.model != self.config.model:
-            raise RuntimeError(
-                "OpenSourceApp does not support switching models at runtime. Please create a new app instance."
-            )
-
-        if self.system_prompt or config.system_prompt:
-            raise ValueError("OpenSourceApp does not support `system_prompt`")
-
-        response = self.instance.generate(
-            prompt=prompt,
-            streaming=config.stream,
-            top_p=config.top_p,
-            max_tokens=config.max_tokens,
-            temp=config.temperature,
-        )
-        return response
+        super().__init__(config, llm=llm, db=database, embedder=embedder, system_prompt=system_prompt)
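Usage sketch (not part of the diff): `OpenSourceApp` now assembles `GPT4ALLLlm`, `GPT4AllEmbedder` and `ChromaDB` itself, pinning the GPT4All model to orca-mini-3b, so callers only pass the optional configs. Assumes the `embedchain[opensource]` extra is installed, as the error message above requires; the URL is a placeholder.

from embedchain import OpenSourceApp
from embedchain.config import OpenSourceAppConfig

app = OpenSourceApp(config=OpenSourceAppConfig())
app.add("https://example.com/article")  # placeholder source; add/query are the existing EmbedChain API
print(app.query("Summarize the article."))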
@@ -2,9 +2,10 @@ from string import Template

from embedchain.apps.App import App
from embedchain.apps.OpenSourceApp import OpenSourceApp
-from embedchain.config import ChatConfig, QueryConfig
+from embedchain.config import BaseLlmConfig
from embedchain.config.apps.BaseAppConfig import BaseAppConfig
-from embedchain.config.QueryConfig import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY
+from embedchain.config.llm.base_llm_config import (DEFAULT_PROMPT,
+                                                   DEFAULT_PROMPT_WITH_HISTORY)
from embedchain.helper_classes.json_serializable import register_deserializable


@@ -23,7 +24,7 @@ class EmbedChainPersonApp:
        self.person_prompt = f"You are {person}. Whatever you say, you will always say in {person} style." # noqa:E501
        super().__init__(config)

-    def add_person_template_to_config(self, default_prompt: str, config: ChatConfig = None):
+    def add_person_template_to_config(self, default_prompt: str, config: BaseLlmConfig = None):
        """
        This method checks if the config object contains a prompt template
        if yes it adds the person prompt to it and return the updated config
@@ -44,7 +45,7 @@ class EmbedChainPersonApp:
            config.template = template
        else:
            # if no config is present at all, initialize the config with person prompt and default template
-            config = QueryConfig(
+            config = BaseLlmConfig(
                template=template,
            )

@@ -58,11 +59,11 @@ class PersonApp(EmbedChainPersonApp, App):
    Extends functionality from EmbedChainPersonApp and App
    """

-    def query(self, input_query, config: QueryConfig = None, dry_run=False):
+    def query(self, input_query, config: BaseLlmConfig = None, dry_run=False):
        config = self.add_person_template_to_config(DEFAULT_PROMPT, config, where=None)
        return super().query(input_query, config, dry_run, where=None)

-    def chat(self, input_query, config: ChatConfig = None, dry_run=False, where=None):
+    def chat(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None):
        config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config)
        return super().chat(input_query, config, dry_run, where)

@@ -74,10 +75,10 @@ class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp):
    Extends functionality from EmbedChainPersonApp and OpenSourceApp
    """

-    def query(self, input_query, config: QueryConfig = None, dry_run=False):
+    def query(self, input_query, config: BaseLlmConfig = None, dry_run=False):
        config = self.add_person_template_to_config(DEFAULT_PROMPT, config)
        return super().query(input_query, config, dry_run)

-    def chat(self, input_query, config: ChatConfig = None, dry_run=False):
+    def chat(self, input_query, config: BaseLlmConfig = None, dry_run=False):
        config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config)
        return super().chat(input_query, config, dry_run)
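Usage sketch (not part of the diff): the person apps keep their behaviour but now take `BaseLlmConfig` wherever they previously took `QueryConfig`/`ChatConfig`. The import path, person name and URL below are assumed placeholders; `BaseLlmConfig` is presumed to accept the old QueryConfig-style options such as `temperature`.

from embedchain.apps.PersonApp import PersonApp  # assumed path; the class is defined alongside EmbedChainPersonApp above
from embedchain.config import BaseLlmConfig

app = PersonApp("Yoda")  # placeholder person
app.add("https://example.com/essay")  # placeholder source
print(app.query("What is patience?", config=BaseLlmConfig(temperature=0.5)))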