refactor: classes and configs (#528)

cachho
2023-09-05 10:12:58 +02:00
committed by GitHub
parent 387b042a49
commit 344e7470f6
50 changed files with 1221 additions and 997 deletions

@@ -0,0 +1,29 @@
import logging
from typing import Optional
from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class AntrophicLlm(BaseLlm):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config=config)
def get_llm_model_answer(self, prompt):
return AntrophicLlm._get_athrophic_answer(prompt=prompt, config=self.config)
@staticmethod
def _get_athrophic_answer(prompt: str, config: BaseLlmConfig) -> str:
from langchain.chat_models import ChatAnthropic
chat = ChatAnthropic(temperature=config.temperature, model=config.model)
if config.max_tokens and config.max_tokens != 1000:
logging.warning("Config option `max_tokens` is not supported by this model.")
messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
return chat(messages).content
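Usage sketch (not part of the diff): the import path, the BaseLlmConfig keyword arguments, and the model name are assumptions for illustration, and langchain's ChatAnthropic expects an ANTHROPIC_API_KEY environment variable.

# Illustrative only -- import path, config kwargs and model name are assumed.
from embedchain.config import BaseLlmConfig
from embedchain.llm.anthropic import AntrophicLlm  # import path assumed

config = BaseLlmConfig(model="claude-2", temperature=0.1)  # kwargs assumed
llm = AntrophicLlm(config=config)
print(llm.get_llm_model_answer("Summarize retrieval-augmented generation in one sentence."))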


@@ -0,0 +1,39 @@
import logging
from typing import Optional
from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class AzureOpenAiLlm(BaseLlm):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config=config)
def get_llm_model_answer(self, prompt):
return AzureOpenAiLlm._get_azure_openai_answer(prompt=prompt, config=self.config)
@staticmethod
def _get_azure_openai_answer(prompt: str, config: BaseLlmConfig) -> str:
from langchain.chat_models import AzureChatOpenAI
if not config.deployment_name:
raise ValueError("Deployment name must be provided for Azure OpenAI")
chat = AzureChatOpenAI(
deployment_name=config.deployment_name,
openai_api_version="2023-05-15",
model_name=config.model or "gpt-3.5-turbo",
temperature=config.temperature,
max_tokens=config.max_tokens,
streaming=config.stream,
)
if config.top_p and config.top_p != 1:
logging.warning("Config option `top_p` is not supported by this model.")
messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
return chat(messages).content
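Usage sketch (not part of the diff): `deployment_name` is mandatory per the check above; the import path, the config keyword arguments, and the Azure credential environment variables read by langchain's AzureChatOpenAI are assumptions.

# Illustrative only -- import path and config kwargs are assumed.
from embedchain.config import BaseLlmConfig
from embedchain.llm.azure_openai import AzureOpenAiLlm  # import path assumed

config = BaseLlmConfig(deployment_name="my-gpt35-deployment", temperature=0.2)  # kwargs assumed
llm = AzureOpenAiLlm(config=config)
print(llm.get_llm_model_answer("What is a vector database?"))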

embedchain/llm/base_llm.py (new file, 214 lines)

@@ -0,0 +1,214 @@
import logging
from typing import Any, List, Optional
from langchain.memory import ConversationBufferMemory
from langchain.schema import BaseMessage
from embedchain.helper_classes.json_serializable import JSONSerializable
from embedchain.config import BaseLlmConfig
from embedchain.config.llm.base_llm_config import (
DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
DOCS_SITE_PROMPT_TEMPLATE)
class BaseLlm(JSONSerializable):
def __init__(self, config: Optional[BaseLlmConfig] = None):
if config is None:
self.config = BaseLlmConfig()
else:
self.config = config
self.memory = ConversationBufferMemory()
self.is_docs_site_instance = False
self.online = False
        self.history: Any = None
    def get_llm_model_answer(self, prompt):
        """
        Usually implemented by child class
        """
        raise NotImplementedError
    def set_history(self, history: Any):
self.history = history
def update_history(self):
chat_history = self.memory.load_memory_variables({})["history"]
if chat_history:
self.set_history(chat_history)
def generate_prompt(self, input_query, contexts, **kwargs):
"""
Generates a prompt based on the given query and context, ready to be
passed to an LLM
:param input_query: The query to use.
:param contexts: List of similar documents to the query used as context.
        :param web_search_result: Optional keyword argument; when present it is
        appended to the context string before templating.
:return: The prompt
"""
context_string = (" | ").join(contexts)
web_search_result = kwargs.get("web_search_result", "")
if web_search_result:
context_string = self._append_search_and_context(context_string, web_search_result)
if not self.history:
prompt = self.config.template.substitute(context=context_string, query=input_query)
else:
# check if it's the default template without history
if (
not self.config._validate_template_history(self.config.template)
and self.config.template.template == DEFAULT_PROMPT
):
# swap in the template with history
prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
context=context_string, query=input_query, history=self.history
)
elif not self.config._validate_template_history(self.config.template):
logging.warning("Template does not include `$history` key. History is not included in prompt.")
prompt = self.config.template.substitute(context=context_string, query=input_query)
else:
prompt = self.config.template.substitute(
context=context_string, query=input_query, history=self.history
)
return prompt
def _append_search_and_context(self, context, web_search_result):
return f"{context}\nWeb Search Result: {web_search_result}"
def get_answer_from_llm(self, prompt):
"""
Gets an answer based on the given query and context by passing it
to an LLM.
        :param prompt: The prompt (query plus retrieved context) to pass to the LLM.
:return: The answer.
"""
return self.get_llm_model_answer(prompt)
def access_search_and_get_results(self, input_query):
from langchain.tools import DuckDuckGoSearchRun
search = DuckDuckGoSearchRun()
logging.info(f"Access search to get answers for {input_query}")
return search.run(input_query)
def _stream_query_response(self, answer):
streamed_answer = ""
for chunk in answer:
streamed_answer = streamed_answer + chunk
yield chunk
logging.info(f"Answer: {streamed_answer}")
def _stream_chat_response(self, answer):
streamed_answer = ""
for chunk in answer:
streamed_answer = streamed_answer + chunk
yield chunk
self.memory.chat_memory.add_ai_message(streamed_answer)
logging.info(f"Answer: {streamed_answer}")
def query(self, input_query, contexts, config: BaseLlmConfig = None, dry_run=False, where=None):
"""
Queries the vector database based on the given input query.
Gets relevant doc based on the query and then passes it to an
LLM as context to get the answer.
:param input_query: The query to use.
:param config: Optional. The `LlmConfig` instance to use as configuration options.
This is used for one method call. To persistently use a config, declare it during app init.
:param dry_run: Optional. A dry run does everything except send the resulting prompt to
the LLM. The purpose is to test the prompt, not the response.
You can use it to test your prompt, including the context provided
by the vector database's doc retrieval.
The only thing the dry run does not consider is the cut-off due to
the `max_tokens` parameter.
:param where: Optional. A dictionary of key-value pairs to filter the database results.
:return: The answer to the query.
"""
query_config = config or self.config
if self.is_docs_site_instance:
query_config.template = DOCS_SITE_PROMPT_TEMPLATE
query_config.number_documents = 5
k = {}
if self.online:
k["web_search_result"] = self.access_search_and_get_results(input_query)
prompt = self.generate_prompt(input_query, contexts, **k)
logging.info(f"Prompt: {prompt}")
if dry_run:
return prompt
answer = self.get_answer_from_llm(prompt)
if isinstance(answer, str):
logging.info(f"Answer: {answer}")
return answer
else:
return self._stream_query_response(answer)
def chat(self, input_query, contexts, config: BaseLlmConfig = None, dry_run=False, where=None):
"""
Queries the vector database on the given input query.
Gets relevant doc based on the query and then passes it to an
LLM as context to get the answer.
Maintains the whole conversation in memory.
:param input_query: The query to use.
:param config: Optional. The `LlmConfig` instance to use as configuration options.
This is used for one method call. To persistently use a config, declare it during app init.
:param dry_run: Optional. A dry run does everything except send the resulting prompt to
the LLM. The purpose is to test the prompt, not the response.
You can use it to test your prompt, including the context provided
by the vector database's doc retrieval.
The only thing the dry run does not consider is the cut-off due to
the `max_tokens` parameter.
:param where: Optional. A dictionary of key-value pairs to filter the database results.
:return: The answer to the query.
"""
query_config = config or self.config
if self.is_docs_site_instance:
query_config.template = DOCS_SITE_PROMPT_TEMPLATE
query_config.number_documents = 5
k = {}
if self.online:
k["web_search_result"] = self.access_search_and_get_results(input_query)
self.update_history()
prompt = self.generate_prompt(input_query, contexts, **k)
logging.info(f"Prompt: {prompt}")
if dry_run:
return prompt
answer = self.get_answer_from_llm(prompt)
self.memory.chat_memory.add_user_message(input_query)
if isinstance(answer, str):
self.memory.chat_memory.add_ai_message(answer)
logging.info(f"Answer: {answer}")
# NOTE: Adding to history before and after. This could be seen as redundant.
# If we change it, we have to change the tests (no big deal).
self.update_history()
return answer
else:
# this is a streamed response and needs to be handled differently.
return self._stream_chat_response(answer)
@staticmethod
def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
from langchain.schema import HumanMessage, SystemMessage
messages = []
if system_prompt:
messages.append(SystemMessage(content=system_prompt))
messages.append(HumanMessage(content=prompt))
return messages
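To make the base-class contract concrete, here is a toy subclass (not part of the diff; the class name is made up) showing what child classes must implement and how `dry_run` returns the rendered prompt without calling a model.

# Illustrative only -- EchoLlm is a hypothetical subclass.
from embedchain.llm.base_llm import BaseLlm

class EchoLlm(BaseLlm):
    def get_llm_model_answer(self, prompt):
        # A real provider class would call its LLM API here.
        return f"echo: {prompt}"

llm = EchoLlm()
contexts = ["Embedchain is a framework for building RAG applications."]
# dry_run=True stops after prompt generation, so no model call is made.
print(llm.query("What is Embedchain?", contexts, dry_run=True))
# chat() also records the exchange in ConversationBufferMemory.
print(llm.chat("What is Embedchain?", contexts))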


@@ -0,0 +1,47 @@
from typing import Iterable, Optional, Union
from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class GPT4ALLLlm(BaseLlm):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config=config)
if self.config.model is None:
self.config.model = "orca-mini-3b.ggmlv3.q4_0.bin"
self.instance = GPT4ALLLlm._get_instance(self.config.model)
def get_llm_model_answer(self, prompt):
return self._get_gpt4all_answer(prompt=prompt, config=self.config)
@staticmethod
def _get_instance(model):
try:
from gpt4all import GPT4All
except ModuleNotFoundError:
raise ModuleNotFoundError(
"The GPT4All python package is not installed. Please install it with `pip install embedchain[opensource]`" # noqa E501
) from None
return GPT4All(model_name=model)
def _get_gpt4all_answer(self, prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
if config.model and config.model != self.config.model:
raise RuntimeError(
"OpenSourceApp does not support switching models at runtime. Please create a new app instance."
)
if config.system_prompt:
raise ValueError("OpenSourceApp does not support `system_prompt`")
response = self.instance.generate(
prompt=prompt,
streaming=config.stream,
top_p=config.top_p,
max_tokens=config.max_tokens,
temp=config.temperature,
)
return response
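Usage sketch (not part of the diff), assuming the `opensource` extra is installed; instantiation loads (and on first use downloads) the default orca-mini model, and the import path is an assumption.

# Illustrative only -- import path assumed; model download happens on first use.
from embedchain.llm.gpt4all import GPT4ALLLlm  # import path assumed

llm = GPT4ALLLlm()  # falls back to the orca-mini-3b model set above
print(llm.get_llm_model_answer("Say hello in one short sentence."))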


@@ -0,0 +1,27 @@
import os
from typing import Optional
from langchain.llms import Replicate
from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class Llama2Llm(BaseLlm):
def __init__(self, config: Optional[BaseLlmConfig] = None):
if "REPLICATE_API_TOKEN" not in os.environ:
raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.")
super().__init__(config=config)
def get_llm_model_answer(self, prompt):
# TODO: Move the model and other inputs into config
if self.config.system_prompt:
raise ValueError("Llama2App does not support `system_prompt`")
llm = Replicate(
model="a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5",
input={"temperature": self.config.temperature or 0.75, "max_length": 500, "top_p": self.config.top_p},
)
return llm(prompt)
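Usage sketch (not part of the diff): the constructor raises unless REPLICATE_API_TOKEN is set, so export it first; the import path and the placeholder token are assumptions.

# Illustrative only -- import path assumed, token is a placeholder.
import os
from embedchain.llm.llama2 import Llama2Llm  # import path assumed

os.environ.setdefault("REPLICATE_API_TOKEN", "r8_your_token_here")
llm = Llama2Llm()
print(llm.get_llm_model_answer("What is retrieval-augmented generation?"))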


@@ -0,0 +1,43 @@
from typing import Optional
import openai
from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class OpenAiLlm(BaseLlm):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config=config)
# NOTE: This class does not use langchain. One reason is that `top_p` is not supported.
def get_llm_model_answer(self, prompt):
messages = []
if self.config.system_prompt:
messages.append({"role": "system", "content": self.config.system_prompt})
messages.append({"role": "user", "content": prompt})
response = openai.ChatCompletion.create(
model=self.config.model or "gpt-3.5-turbo-0613",
messages=messages,
temperature=self.config.temperature,
max_tokens=self.config.max_tokens,
top_p=self.config.top_p,
stream=self.config.stream,
)
if self.config.stream:
return self._stream_llm_model_response(response)
else:
return response["choices"][0]["message"]["content"]
def _stream_llm_model_response(self, response):
"""
This is a generator for streaming response from the OpenAI completions API
"""
for line in response:
chunk = line["choices"][0].get("delta", {}).get("content", "")
yield chunk
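Usage sketch (not part of the diff): with `stream=True` the method returns the generator above instead of a string, so the caller iterates over chunks. It assumes OPENAI_API_KEY is set, the pre-1.0 `openai` SDK that exposes ChatCompletion, and that `stream` is accepted by BaseLlmConfig.

# Illustrative only -- import path and config kwarg are assumed.
from embedchain.config import BaseLlmConfig
from embedchain.llm.openai import OpenAiLlm  # import path assumed

llm = OpenAiLlm(config=BaseLlmConfig(stream=True))  # kwarg assumed
for chunk in llm.get_llm_model_answer("Explain embeddings in one sentence."):
    print(chunk, end="", flush=True)
print()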


@@ -0,0 +1,29 @@
import logging
from typing import Optional
from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class VertexAiLlm(BaseLlm):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config=config)
def get_llm_model_answer(self, prompt):
        return VertexAiLlm._get_vertex_ai_answer(prompt=prompt, config=self.config)
    @staticmethod
    def _get_vertex_ai_answer(prompt: str, config: BaseLlmConfig) -> str:
from langchain.chat_models import ChatVertexAI
chat = ChatVertexAI(temperature=config.temperature, model=config.model)
if config.top_p and config.top_p != 1:
logging.warning("Config option `top_p` is not supported by this model.")
messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
return chat(messages).content
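Usage sketch (not part of the diff), assuming Google Cloud application-default credentials are configured and the Vertex AI dependencies for langchain are installed; the import path, config keyword arguments, and model name are assumptions.

# Illustrative only -- import path, config kwargs and model name are assumed.
from embedchain.config import BaseLlmConfig
from embedchain.llm.vertex_ai import VertexAiLlm  # import path assumed

llm = VertexAiLlm(config=BaseLlmConfig(model="chat-bison", temperature=0.1))  # kwargs assumed
print(llm.get_llm_model_answer("Name one use case for text embeddings."))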