refactor: classes and configs (#528)
embedchain/llm/__init__.py (new file, 0 lines)
embedchain/llm/antrophic_llm.py (new file, 29 lines)
@@ -0,0 +1,29 @@
import logging
from typing import Optional

from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
from embedchain.helper_classes.json_serializable import register_deserializable


@register_deserializable
class AntrophicLlm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config=config)

    def get_llm_model_answer(self, prompt):
        return AntrophicLlm._get_anthropic_answer(prompt=prompt, config=self.config)

    @staticmethod
    def _get_anthropic_answer(prompt: str, config: BaseLlmConfig) -> str:
        from langchain.chat_models import ChatAnthropic

        chat = ChatAnthropic(temperature=config.temperature, model=config.model)

        if config.max_tokens and config.max_tokens != 1000:
            logging.warning("Config option `max_tokens` is not supported by this model.")

        messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)

        return chat(messages).content
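Usage sketch (not part of the diff): a minimal way to exercise AntrophicLlm on its own, assuming BaseLlmConfig accepts `model` and `temperature` as keyword arguments and that ANTHROPIC_API_KEY is set so langchain's ChatAnthropic can authenticate; the model name is illustrative.

    import os

    from embedchain.config import BaseLlmConfig
    from embedchain.llm.antrophic_llm import AntrophicLlm

    os.environ.setdefault("ANTHROPIC_API_KEY", "sk-ant-...")  # placeholder key
    config = BaseLlmConfig(model="claude-instant-1", temperature=0.2)  # illustrative model name
    llm = AntrophicLlm(config=config)
    print(llm.get_llm_model_answer("Summarize what this class does in one sentence."))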
embedchain/llm/azure_openai_llm.py (new file, 39 lines)
@@ -0,0 +1,39 @@
import logging
from typing import Optional

from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
from embedchain.helper_classes.json_serializable import register_deserializable


@register_deserializable
class AzureOpenAiLlm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config=config)

    def get_llm_model_answer(self, prompt):
        return AzureOpenAiLlm._get_azure_openai_answer(prompt=prompt, config=self.config)

    @staticmethod
    def _get_azure_openai_answer(prompt: str, config: BaseLlmConfig) -> str:
        from langchain.chat_models import AzureChatOpenAI

        if not config.deployment_name:
            raise ValueError("Deployment name must be provided for Azure OpenAI")

        chat = AzureChatOpenAI(
            deployment_name=config.deployment_name,
            openai_api_version="2023-05-15",
            model_name=config.model or "gpt-3.5-turbo",
            temperature=config.temperature,
            max_tokens=config.max_tokens,
            streaming=config.stream,
        )

        if config.top_p and config.top_p != 1:
            logging.warning("Config option `top_p` is not supported by this model.")

        messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)

        return chat(messages).content
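Usage sketch (not part of the diff): exercising AzureOpenAiLlm directly, assuming BaseLlmConfig exposes a `deployment_name` keyword and that the Azure OpenAI endpoint and API key langchain expects are already configured in the environment; the deployment name below is hypothetical.

    from embedchain.config import BaseLlmConfig
    from embedchain.llm.azure_openai_llm import AzureOpenAiLlm

    config = BaseLlmConfig(
        deployment_name="my-gpt35-deployment",  # hypothetical Azure deployment
        model="gpt-3.5-turbo",
        temperature=0.0,
        max_tokens=256,
    )
    llm = AzureOpenAiLlm(config=config)
    print(llm.get_llm_model_answer("Give a one-line summary of embedchain."))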
embedchain/llm/base_llm.py (new file, 214 lines)
@@ -0,0 +1,214 @@
import logging
from typing import Any, List, Optional

from langchain.memory import ConversationBufferMemory
from langchain.schema import BaseMessage
from embedchain.helper_classes.json_serializable import JSONSerializable

from embedchain.config import BaseLlmConfig
from embedchain.config.llm.base_llm_config import (
    DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
    DOCS_SITE_PROMPT_TEMPLATE)


class BaseLlm(JSONSerializable):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        if config is None:
            self.config = BaseLlmConfig()
        else:
            self.config = config

        self.memory = ConversationBufferMemory()
        self.is_docs_site_instance = False
        self.online = False
        self.history: Any = None

    def get_llm_model_answer(self):
        """
        Usually implemented by child class
        """
        raise NotImplementedError

    def set_history(self, history: Any):
        self.history = history

    def update_history(self):
        chat_history = self.memory.load_memory_variables({})["history"]
        if chat_history:
            self.set_history(chat_history)

    def generate_prompt(self, input_query, contexts, **kwargs):
        """
        Generates a prompt based on the given query and context, ready to be
        passed to an LLM.

        :param input_query: The query to use.
        :param contexts: List of similar documents to the query used as context.
        :param kwargs: Optional extra inputs, e.g. `web_search_result`.
        :return: The prompt
        """
        context_string = (" | ").join(contexts)
        web_search_result = kwargs.get("web_search_result", "")
        if web_search_result:
            context_string = self._append_search_and_context(context_string, web_search_result)
        if not self.history:
            prompt = self.config.template.substitute(context=context_string, query=input_query)
        else:
            # check if it's the default template without history
            if (
                not self.config._validate_template_history(self.config.template)
                and self.config.template.template == DEFAULT_PROMPT
            ):
                # swap in the template with history
                prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
                    context=context_string, query=input_query, history=self.history
                )
            elif not self.config._validate_template_history(self.config.template):
                logging.warning("Template does not include `$history` key. History is not included in prompt.")
                prompt = self.config.template.substitute(context=context_string, query=input_query)
            else:
                prompt = self.config.template.substitute(
                    context=context_string, query=input_query, history=self.history
                )
        return prompt

    def _append_search_and_context(self, context, web_search_result):
        return f"{context}\nWeb Search Result: {web_search_result}"

    def get_answer_from_llm(self, prompt):
        """
        Gets an answer based on the given query and context by passing it
        to an LLM.

        :param prompt: The prompt (query plus context) to pass to the LLM.
        :return: The answer.
        """

        return self.get_llm_model_answer(prompt)

    def access_search_and_get_results(self, input_query):
        from langchain.tools import DuckDuckGoSearchRun

        search = DuckDuckGoSearchRun()
        logging.info(f"Access search to get answers for {input_query}")
        return search.run(input_query)

    def _stream_query_response(self, answer):
        streamed_answer = ""
        for chunk in answer:
            streamed_answer = streamed_answer + chunk
            yield chunk
        logging.info(f"Answer: {streamed_answer}")

    def _stream_chat_response(self, answer):
        streamed_answer = ""
        for chunk in answer:
            streamed_answer = streamed_answer + chunk
            yield chunk
        self.memory.chat_memory.add_ai_message(streamed_answer)
        logging.info(f"Answer: {streamed_answer}")

    def query(self, input_query, contexts, config: BaseLlmConfig = None, dry_run=False, where=None):
        """
        Queries the vector database based on the given input query.
        Gets relevant doc based on the query and then passes it to an
        LLM as context to get the answer.

        :param input_query: The query to use.
        :param config: Optional. The `LlmConfig` instance to use as configuration options.
        This is used for one method call. To persistently use a config, declare it during app init.
        :param dry_run: Optional. A dry run does everything except send the resulting prompt to
        the LLM. The purpose is to test the prompt, not the response.
        You can use it to test your prompt, including the context provided
        by the vector database's doc retrieval.
        The only thing the dry run does not consider is the cut-off due to
        the `max_tokens` parameter.
        :param where: Optional. A dictionary of key-value pairs to filter the database results.
        :return: The answer to the query.
        """
        query_config = config or self.config

        if self.is_docs_site_instance:
            query_config.template = DOCS_SITE_PROMPT_TEMPLATE
            query_config.number_documents = 5
        k = {}
        if self.online:
            k["web_search_result"] = self.access_search_and_get_results(input_query)
        prompt = self.generate_prompt(input_query, contexts, **k)
        logging.info(f"Prompt: {prompt}")

        if dry_run:
            return prompt

        answer = self.get_answer_from_llm(prompt)

        if isinstance(answer, str):
            logging.info(f"Answer: {answer}")
            return answer
        else:
            return self._stream_query_response(answer)

    def chat(self, input_query, contexts, config: BaseLlmConfig = None, dry_run=False, where=None):
        """
        Queries the vector database on the given input query.
        Gets relevant doc based on the query and then passes it to an
        LLM as context to get the answer.

        Maintains the whole conversation in memory.

        :param input_query: The query to use.
        :param config: Optional. The `LlmConfig` instance to use as configuration options.
        This is used for one method call. To persistently use a config, declare it during app init.
        :param dry_run: Optional. A dry run does everything except send the resulting prompt to
        the LLM. The purpose is to test the prompt, not the response.
        You can use it to test your prompt, including the context provided
        by the vector database's doc retrieval.
        The only thing the dry run does not consider is the cut-off due to
        the `max_tokens` parameter.
        :param where: Optional. A dictionary of key-value pairs to filter the database results.
        :return: The answer to the query.
        """
        query_config = config or self.config

        if self.is_docs_site_instance:
            query_config.template = DOCS_SITE_PROMPT_TEMPLATE
            query_config.number_documents = 5
        k = {}
        if self.online:
            k["web_search_result"] = self.access_search_and_get_results(input_query)

        self.update_history()

        prompt = self.generate_prompt(input_query, contexts, **k)
        logging.info(f"Prompt: {prompt}")

        if dry_run:
            return prompt

        answer = self.get_answer_from_llm(prompt)

        self.memory.chat_memory.add_user_message(input_query)

        if isinstance(answer, str):
            self.memory.chat_memory.add_ai_message(answer)
            logging.info(f"Answer: {answer}")

            # NOTE: Adding to history before and after. This could be seen as redundant.
            # If we change it, we have to change the tests (no big deal).
            self.update_history()

            return answer
        else:
            # this is a streamed response and needs to be handled differently.
            return self._stream_chat_response(answer)

    @staticmethod
    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
        from langchain.schema import HumanMessage, SystemMessage

        messages = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(HumanMessage(content=prompt))
        return messages
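Usage sketch (not part of the diff): the contract a concrete LLM class has to fulfil is essentially one method, `get_llm_model_answer`. The toy subclass below is hypothetical and only illustrates the flow; it assumes the default BaseLlmConfig ships a prompt template, which `generate_prompt` relies on. With `dry_run=True`, `query` returns the rendered prompt instead of calling the model.

    from embedchain.llm.base_llm import BaseLlm

    class EchoLlm(BaseLlm):  # hypothetical test double, not part of embedchain
        def get_llm_model_answer(self, prompt):
            return f"ECHO: {prompt[:80]}"

    llm = EchoLlm()
    contexts = ["Embedchain is a framework for RAG apps.", "It wraps vector stores and LLMs."]
    print(llm.query("What is embedchain?", contexts, dry_run=True))  # rendered prompt only
    print(llm.query("What is embedchain?", contexts))                # goes through EchoLlm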
embedchain/llm/gpt4all_llm.py (new file, 47 lines)
@@ -0,0 +1,47 @@
from typing import Iterable, Optional, Union

from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
from embedchain.helper_classes.json_serializable import register_deserializable


@register_deserializable
class GPT4ALLLlm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config=config)
        if self.config.model is None:
            self.config.model = "orca-mini-3b.ggmlv3.q4_0.bin"
        self.instance = GPT4ALLLlm._get_instance(self.config.model)

    def get_llm_model_answer(self, prompt):
        return self._get_gpt4all_answer(prompt=prompt, config=self.config)

    @staticmethod
    def _get_instance(model):
        try:
            from gpt4all import GPT4All
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The GPT4All python package is not installed. Please install it with `pip install embedchain[opensource]`"  # noqa E501
            ) from None

        return GPT4All(model_name=model)

    def _get_gpt4all_answer(self, prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
        if config.model and config.model != self.config.model:
            raise RuntimeError(
                "OpenSourceApp does not support switching models at runtime. Please create a new app instance."
            )

        if config.system_prompt:
            raise ValueError("OpenSourceApp does not support `system_prompt`")

        response = self.instance.generate(
            prompt=prompt,
            streaming=config.stream,
            top_p=config.top_p,
            max_tokens=config.max_tokens,
            temp=config.temperature,
        )
        return response
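Usage sketch (not part of the diff): GPT4ALLLlm runs fully locally; this assumes the gpt4all package is installed (`pip install embedchain[opensource]`) and that the default `orca-mini-3b.ggmlv3.q4_0.bin` weights are available locally or can be downloaded by gpt4all on first use.

    from embedchain.llm.gpt4all_llm import GPT4ALLLlm

    llm = GPT4ALLLlm()  # falls back to the default local model; no API key required
    print(llm.get_llm_model_answer("Write one sentence about local language models."))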
embedchain/llm/llama2_llm.py (new file, 27 lines)
@@ -0,0 +1,27 @@
import os
from typing import Optional

from langchain.llms import Replicate

from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
from embedchain.helper_classes.json_serializable import register_deserializable


@register_deserializable
class Llama2Llm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        if "REPLICATE_API_TOKEN" not in os.environ:
            raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.")
        super().__init__(config=config)

    def get_llm_model_answer(self, prompt):
        # TODO: Move the model and other inputs into config
        if self.config.system_prompt:
            raise ValueError("Llama2App does not support `system_prompt`")
        llm = Replicate(
            model="a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5",
            input={"temperature": self.config.temperature or 0.75, "max_length": 500, "top_p": self.config.top_p},
        )
        return llm(prompt)
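Usage sketch (not part of the diff): Llama2Llm raises at construction time if REPLICATE_API_TOKEN is missing, so the sketch sets it first; the token value is a placeholder.

    import os

    from embedchain.llm.llama2_llm import Llama2Llm

    os.environ["REPLICATE_API_TOKEN"] = "r8_..."  # placeholder, use a real Replicate token
    llm = Llama2Llm()
    print(llm.get_llm_model_answer("Explain retrieval-augmented generation in one sentence."))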
embedchain/llm/openai_llm.py (new file, 43 lines)
@@ -0,0 +1,43 @@
from typing import Optional

import openai

from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
from embedchain.helper_classes.json_serializable import register_deserializable


@register_deserializable
class OpenAiLlm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config=config)

    # NOTE: This class does not use langchain. One reason is that `top_p` is not supported.

    def get_llm_model_answer(self, prompt):
        messages = []
        if self.config.system_prompt:
            messages.append({"role": "system", "content": self.config.system_prompt})
        messages.append({"role": "user", "content": prompt})
        response = openai.ChatCompletion.create(
            model=self.config.model or "gpt-3.5-turbo-0613",
            messages=messages,
            temperature=self.config.temperature,
            max_tokens=self.config.max_tokens,
            top_p=self.config.top_p,
            stream=self.config.stream,
        )

        if self.config.stream:
            return self._stream_llm_model_response(response)
        else:
            return response["choices"][0]["message"]["content"]

    def _stream_llm_model_response(self, response):
        """
        This is a generator for streaming response from the OpenAI completions API
        """
        for line in response:
            chunk = line["choices"][0].get("delta", {}).get("content", "")
            yield chunk
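Usage sketch (not part of the diff): with `stream=True` the method returns a generator of text chunks rather than a string, so the caller iterates over it; this assumes OPENAI_API_KEY is set and that BaseLlmConfig accepts `stream`, `temperature`, and `max_tokens` as keyword arguments.

    from embedchain.config import BaseLlmConfig
    from embedchain.llm.openai_llm import OpenAiLlm

    llm = OpenAiLlm(BaseLlmConfig(temperature=0.0, max_tokens=200, stream=True))
    for chunk in llm.get_llm_model_answer("List three uses of embeddings."):
        print(chunk, end="", flush=True)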
embedchain/llm/vertex_ai_llm.py (new file, 29 lines)
@@ -0,0 +1,29 @@
import logging
from typing import Optional

from embedchain.config import BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
from embedchain.helper_classes.json_serializable import register_deserializable


@register_deserializable
class VertexAiLlm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config=config)

    def get_llm_model_answer(self, prompt):
        return VertexAiLlm._get_vertexai_answer(prompt=prompt, config=self.config)

    @staticmethod
    def _get_vertexai_answer(prompt: str, config: BaseLlmConfig) -> str:
        from langchain.chat_models import ChatVertexAI

        chat = ChatVertexAI(temperature=config.temperature, model=config.model)

        if config.top_p and config.top_p != 1:
            logging.warning("Config option `top_p` is not supported by this model.")

        messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)

        return chat(messages).content
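Usage sketch (not part of the diff): ChatVertexAI authenticates through Google application-default credentials, so this assumes GOOGLE_APPLICATION_CREDENTIALS (or `gcloud auth application-default login`) is already set up and that `chat-bison` is an available model name in the project.

    from embedchain.config import BaseLlmConfig
    from embedchain.llm.vertex_ai_llm import VertexAiLlm

    llm = VertexAiLlm(BaseLlmConfig(model="chat-bison", temperature=0.2))
    print(llm.get_llm_model_answer("Summarize this commit in one sentence."))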