feat: add new custom app (#313)
This commit is contained in:
@@ -27,7 +27,7 @@ class App(EmbedChain):
|
||||
messages = []
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
response = openai.ChatCompletion.create(
|
||||
model=config.model,
|
||||
model=config.model or "gpt-3.5-turbo-0613",
|
||||
messages=messages,
|
||||
temperature=config.temperature,
|
||||
max_tokens=config.max_tokens,
|
||||
|
||||
128
embedchain/apps/CustomApp.py
Normal file
128
embedchain/apps/CustomApp.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import logging
|
||||
from typing import Iterable, List, Union
|
||||
|
||||
from langchain.schema import BaseMessage
|
||||
|
||||
from embedchain.config import ChatConfig, CustomAppConfig, OpenSourceAppConfig
|
||||
from embedchain.embedchain import EmbedChain
|
||||
from embedchain.models import Providers
|
||||
|
||||
|
||||
class CustomApp(EmbedChain):
|
||||
"""
|
||||
The custom EmbedChain app.
|
||||
Has two functions: add and query.
|
||||
|
||||
adds(data_type, url): adds the data from the given URL to the vector db.
|
||||
query(query): finds answer to the given query using vector database and LLM.
|
||||
dry_run(query): test your prompt without consuming tokens.
|
||||
"""
|
||||
|
||||
def __init__(self, config: CustomAppConfig = None):
|
||||
"""
|
||||
:param config: Optional. `CustomAppConfig` instance to load as configuration.
|
||||
:raises ValueError: Config must be provided for custom app
|
||||
"""
|
||||
if config is None:
|
||||
raise ValueError("Config must be provided for custom app")
|
||||
|
||||
self.provider = config.provider
|
||||
|
||||
if config.provider == Providers.GPT4ALL:
|
||||
from embedchain import OpenSourceApp
|
||||
|
||||
# Because these models run locally, they should have an instance running when the custom app is created
|
||||
self.open_source_app = OpenSourceApp(config=config.open_source_app_config)
|
||||
|
||||
super().__init__(config)
|
||||
|
||||
def set_llm_model(self, provider: Providers):
|
||||
self.provider = provider
|
||||
if provider == Providers.GPT4ALL:
|
||||
raise ValueError(
|
||||
"GPT4ALL needs to be instantiated with the model known, please create a new app instance instead"
|
||||
)
|
||||
|
||||
def get_llm_model_answer(self, prompt, config: ChatConfig):
|
||||
# TODO: Quitting the streaming response here for now.
|
||||
# Idea: https://gist.github.com/jvelezmagic/03ddf4c452d011aae36b2a0f73d72f68
|
||||
if config.stream:
|
||||
raise NotImplementedError(
|
||||
"Streaming responses have not been implemented for this model yet. Please disable."
|
||||
)
|
||||
|
||||
try:
|
||||
if self.provider == Providers.OPENAI:
|
||||
return CustomApp._get_openai_answer(prompt, config)
|
||||
|
||||
if self.provider == Providers.ANTHROPHIC:
|
||||
return CustomApp._get_athrophic_answer(prompt, config)
|
||||
|
||||
if self.provider == Providers.VERTEX_AI:
|
||||
return CustomApp._get_vertex_answer(prompt, config)
|
||||
|
||||
if self.provider == Providers.GPT4ALL:
|
||||
return self.open_source_app._get_gpt4all_answer(prompt, config)
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(e.msg) from None
|
||||
|
||||
@staticmethod
|
||||
def _get_openai_answer(prompt: str, config: ChatConfig) -> str:
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
|
||||
logging.info(vars(config))
|
||||
|
||||
chat = ChatOpenAI(
|
||||
temperature=config.temperature,
|
||||
model=config.model or "gpt-3.5-turbo",
|
||||
max_tokens=config.max_tokens,
|
||||
streaming=config.stream,
|
||||
)
|
||||
|
||||
if config.top_p and config.top_p != 1:
|
||||
logging.warning("Config option `top_p` is not supported by this model.")
|
||||
|
||||
messages = CustomApp._get_messages(prompt)
|
||||
|
||||
return chat(messages).content
|
||||
|
||||
@staticmethod
|
||||
def _get_athrophic_answer(prompt: str, config: ChatConfig) -> str:
|
||||
from langchain.chat_models import ChatAnthropic
|
||||
|
||||
chat = ChatAnthropic(temperature=config.temperature, model=config.model)
|
||||
|
||||
if config.max_tokens and config.max_tokens != 1000:
|
||||
logging.warning("Config option `max_tokens` is not supported by this model.")
|
||||
|
||||
messages = CustomApp._get_messages(prompt)
|
||||
|
||||
return chat(messages).content
|
||||
|
||||
@staticmethod
|
||||
def _get_vertex_answer(prompt: str, config: ChatConfig) -> str:
|
||||
from langchain.chat_models import ChatVertexAI
|
||||
|
||||
chat = ChatVertexAI(temperature=config.temperature, model=config.model, max_output_tokens=config.max_tokens)
|
||||
|
||||
if config.top_p and config.top_p != 1:
|
||||
logging.warning("Config option `top_p` is not supported by this model.")
|
||||
|
||||
messages = CustomApp._get_messages(prompt)
|
||||
|
||||
return chat(messages).content
|
||||
|
||||
@staticmethod
|
||||
def _get_messages(prompt: str) -> List[BaseMessage]:
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
return [SystemMessage(content="You are a helpful assistant."), HumanMessage(content=prompt)]
|
||||
|
||||
def _stream_llm_model_response(self, response):
|
||||
"""
|
||||
This is a generator for streaming response from the OpenAI completions API
|
||||
"""
|
||||
for line in response:
|
||||
chunk = line["choices"][0].get("delta", {}).get("content", "")
|
||||
yield chunk
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from typing import Iterable, List, Union
|
||||
|
||||
from embedchain.config import ChatConfig, OpenSourceAppConfig
|
||||
from embedchain.embedchain import EmbedChain
|
||||
@@ -26,14 +27,39 @@ class OpenSourceApp(EmbedChain):
|
||||
if not config:
|
||||
config = OpenSourceAppConfig()
|
||||
|
||||
if not config.model:
|
||||
raise ValueError("OpenSourceApp needs a model to be instantiated. Maybe you passed the wrong config type?")
|
||||
|
||||
self.instance = OpenSourceApp._get_instance(config.model)
|
||||
|
||||
logging.info("Successfully loaded open source embedding model.")
|
||||
super().__init__(config)
|
||||
|
||||
def get_llm_model_answer(self, prompt, config: ChatConfig):
|
||||
from gpt4all import GPT4All
|
||||
return self._get_gpt4all_answer(prompt=prompt, config=config)
|
||||
|
||||
global gpt4all_model
|
||||
if gpt4all_model is None:
|
||||
gpt4all_model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
|
||||
response = gpt4all_model.generate(prompt=prompt, streaming=config.stream)
|
||||
@staticmethod
|
||||
def _get_instance(model):
|
||||
try:
|
||||
from gpt4all import GPT4All
|
||||
except ModuleNotFoundError:
|
||||
raise ValueError(
|
||||
"The GPT4All python package is not installed. Please install it with `pip install GPT4All`"
|
||||
) from None
|
||||
|
||||
return GPT4All(model)
|
||||
|
||||
def _get_gpt4all_answer(self, prompt: str, config: ChatConfig) -> Union[str, Iterable]:
|
||||
if config.model and config.model != self.config.model:
|
||||
raise RuntimeError(
|
||||
"OpenSourceApp does not support switching models at runtime. Please create a new app instance."
|
||||
)
|
||||
|
||||
response = self.instance.generate(
|
||||
prompt=prompt,
|
||||
streaming=config.stream,
|
||||
top_p=config.top_p,
|
||||
max_tokens=config.max_tokens,
|
||||
temp=config.temperature,
|
||||
)
|
||||
return response
|
||||
|
||||
@@ -4,8 +4,7 @@ from embedchain.apps.App import App
|
||||
from embedchain.apps.OpenSourceApp import OpenSourceApp
|
||||
from embedchain.config import ChatConfig, QueryConfig
|
||||
from embedchain.config.apps.BaseAppConfig import BaseAppConfig
|
||||
from embedchain.config.QueryConfig import (DEFAULT_PROMPT,
|
||||
DEFAULT_PROMPT_WITH_HISTORY)
|
||||
from embedchain.config.QueryConfig import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY
|
||||
|
||||
|
||||
class EmbedChainPersonApp:
|
||||
|
||||
Reference in New Issue
Block a user