refactor: classes and configs (#528)

This commit is contained in:
cachho
2023-09-05 10:12:58 +02:00
committed by GitHub
parent 387b042a49
commit 344e7470f6
50 changed files with 1221 additions and 997 deletions

View File

@@ -1,9 +1,13 @@
import logging
from typing import Iterable, Optional, Union
from typing import Optional
from embedchain.config import ChatConfig, OpenSourceAppConfig
from embedchain.config import (BaseEmbedderConfig, BaseLlmConfig,
ChromaDbConfig, OpenSourceAppConfig)
from embedchain.embedchain import EmbedChain
from embedchain.embedder.gpt4all_embedder import GPT4AllEmbedder
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.llm.gpt4all_llm import GPT4ALLLlm
from embedchain.vectordb.chroma_db import ChromaDB
gpt4all_model = None
@@ -20,7 +24,12 @@ class OpenSourceApp(EmbedChain):
query(query): finds answer to the given query using vector database and LLM.
"""
def __init__(self, config: OpenSourceAppConfig = None, system_prompt: Optional[str] = None):
def __init__(
self,
config: OpenSourceAppConfig = None,
chromadb_config: Optional[ChromaDbConfig] = None,
system_prompt: Optional[str] = None,
):
"""
:param config: OpenSourceAppConfig instance to load as configuration. Optional.
`ef` defaults to open source.
@@ -30,42 +39,19 @@ class OpenSourceApp(EmbedChain):
if not config:
config = OpenSourceAppConfig()
if not isinstance(config, OpenSourceAppConfig):
raise ValueError(
"OpenSourceApp needs a OpenSourceAppConfig passed to it. "
"You can import it with `from embedchain.config import OpenSourceAppConfig`"
)
if not config.model:
raise ValueError("OpenSourceApp needs a model to be instantiated. Maybe you passed the wrong config type?")
self.instance = OpenSourceApp._get_instance(config.model)
logging.info("Successfully loaded open source embedding model.")
super().__init__(config, system_prompt)
def get_llm_model_answer(self, prompt, config: ChatConfig):
return self._get_gpt4all_answer(prompt=prompt, config=config)
llm = GPT4ALLLlm(config=BaseLlmConfig(model="orca-mini-3b.ggmlv3.q4_0.bin"))
embedder = GPT4AllEmbedder(config=BaseEmbedderConfig(model="all-MiniLM-L6-v2"))
database = ChromaDB(config=chromadb_config)
@staticmethod
def _get_instance(model):
try:
from gpt4all import GPT4All
except ModuleNotFoundError:
raise ModuleNotFoundError(
"The GPT4All python package is not installed. Please install it with `pip install embedchain[opensource]`" # noqa E501
) from None
return GPT4All(model)
def _get_gpt4all_answer(self, prompt: str, config: ChatConfig) -> Union[str, Iterable]:
if config.model and config.model != self.config.model:
raise RuntimeError(
"OpenSourceApp does not support switching models at runtime. Please create a new app instance."
)
if self.system_prompt or config.system_prompt:
raise ValueError("OpenSourceApp does not support `system_prompt`")
response = self.instance.generate(
prompt=prompt,
streaming=config.stream,
top_p=config.top_p,
max_tokens=config.max_tokens,
temp=config.temperature,
)
return response
super().__init__(config, llm=llm, db=database, embedder=embedder, system_prompt=system_prompt)