[Bug fix] import App shouldn't throw other llm deps errors (#837)

This commit is contained in:
Sidharth Mohanty
2023-10-26 09:18:53 +05:30
committed by GitHub
parent 413ccb83e6
commit a27eeb3255
2 changed files with 14 additions and 16 deletions

View File

@@ -8,6 +8,10 @@ from embedchain.config import BaseLlmConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helper.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm from embedchain.llm.base import BaseLlm
@register_deserializable
class Llama2Llm(BaseLlm):
def __init__(self, config: Optional[BaseLlmConfig] = None):
try: try:
importlib.import_module("replicate") importlib.import_module("replicate")
except ModuleNotFoundError: except ModuleNotFoundError:
@@ -15,11 +19,6 @@ except ModuleNotFoundError:
"The required dependencies for Llama2 are not installed." "The required dependencies for Llama2 are not installed."
'Please install with `pip install --upgrade "embedchain[llama2]"`' 'Please install with `pip install --upgrade "embedchain[llama2]"`'
) from None ) from None
@register_deserializable
class Llama2Llm(BaseLlm):
def __init__(self, config: Optional[BaseLlmConfig] = None):
if "REPLICATE_API_TOKEN" not in os.environ: if "REPLICATE_API_TOKEN" not in os.environ:
raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.") raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.")

View File

@@ -6,6 +6,10 @@ from embedchain.config import BaseLlmConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helper.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm from embedchain.llm.base import BaseLlm
@register_deserializable
class VertexAILlm(BaseLlm):
def __init__(self, config: Optional[BaseLlmConfig] = None):
try: try:
importlib.import_module("vertexai") importlib.import_module("vertexai")
except ModuleNotFoundError: except ModuleNotFoundError:
@@ -13,11 +17,6 @@ except ModuleNotFoundError:
"The required dependencies for VertexAI are not installed." "The required dependencies for VertexAI are not installed."
'Please install with `pip install --upgrade "embedchain[vertexai]"`' 'Please install with `pip install --upgrade "embedchain[vertexai]"`'
) from None ) from None
@register_deserializable
class VertexAILlm(BaseLlm):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config=config) super().__init__(config=config)
def get_llm_model_answer(self, prompt): def get_llm_model_answer(self, prompt):