diff --git a/embedchain/llm/llama2.py b/embedchain/llm/llama2.py
index 0c65ef53..d4f4aa2f 100644
--- a/embedchain/llm/llama2.py
+++ b/embedchain/llm/llama2.py
@@ -8,18 +8,17 @@ from embedchain.config import BaseLlmConfig
 from embedchain.helper.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
-try:
-    importlib.import_module("replicate")
-except ModuleNotFoundError:
-    raise ModuleNotFoundError(
-        "The required dependencies for Llama2 are not installed."
-        'Please install with `pip install --upgrade "embedchain[llama2]"`'
-    ) from None
-
 
 @register_deserializable
 class Llama2Llm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
+        try:
+            importlib.import_module("replicate")
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError(
+                "The required dependencies for Llama2 are not installed. "
+                'Please install with `pip install --upgrade "embedchain[llama2]"`'
+            ) from None
         if "REPLICATE_API_TOKEN" not in os.environ:
             raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.")
diff --git a/embedchain/llm/vertex_ai.py b/embedchain/llm/vertex_ai.py
index 98e81621..224c65f1 100644
--- a/embedchain/llm/vertex_ai.py
+++ b/embedchain/llm/vertex_ai.py
@@ -6,18 +6,17 @@ from embedchain.config import BaseLlmConfig
 from embedchain.helper.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
-try:
-    importlib.import_module("vertexai")
-except ModuleNotFoundError:
-    raise ModuleNotFoundError(
-        "The required dependencies for VertexAI are not installed."
-        'Please install with `pip install --upgrade "embedchain[vertexai]"`'
-    ) from None
-
 
 @register_deserializable
 class VertexAILlm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
+        try:
+            importlib.import_module("vertexai")
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError(
+                "The required dependencies for VertexAI are not installed. "
+                'Please install with `pip install --upgrade "embedchain[vertexai]"`'
+            ) from None
         super().__init__(config=config)
 
     def get_llm_model_answer(self, prompt):
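
Both hunks apply the same pattern: the optional-dependency check moves from module import time into `__init__`, so merely importing the module no longer raises `ModuleNotFoundError` for users who never instantiate that backend; the error is deferred until the class is actually used. A minimal sketch of the pattern outside the diff context (`OptionalDepLlm`, `some_optional_pkg`, and the `myextras` extra are hypothetical placeholder names, not embedchain identifiers):

```python
import importlib
from typing import Optional


class OptionalDepLlm:
    """Sketch of a lazily checked optional dependency.

    Importing this module never fails; the check runs only when the
    class is instantiated, mirroring the change in the diff above.
    """

    def __init__(self, config: Optional[dict] = None):
        try:
            # Probe for the optional package at instantiation time.
            importlib.import_module("some_optional_pkg")  # hypothetical package
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The required dependencies are not installed. "
                'Please install with `pip install --upgrade "embedchain[myextras]"`'
            ) from None
        self.config = config
```

One trade-off worth noting: with the module-level check, a missing extra failed fast on `import`; with the in-`__init__` check, the failure surfaces later but unrelated imports of the package keep working.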