diff --git a/embedchain/llm/google.py b/embedchain/llm/google.py index 5409b713..c0002fa9 100644 --- a/embedchain/llm/google.py +++ b/embedchain/llm/google.py @@ -1,10 +1,12 @@ -import importlib import logging import os from collections.abc import Generator from typing import Any, Optional, Union -import google.generativeai as genai +try: + import google.generativeai as genai +except ImportError: + raise ImportError("GoogleLlm requires extra dependencies. Install with `pip install google-generativeai`") from None from embedchain.config import BaseLlmConfig from embedchain.helpers.json_serializable import register_deserializable @@ -16,14 +18,6 @@ logger = logging.getLogger(__name__) @register_deserializable class GoogleLlm(BaseLlm): def __init__(self, config: Optional[BaseLlmConfig] = None): - try: - importlib.import_module("google.generativeai") - except ModuleNotFoundError: - raise ModuleNotFoundError( - "The required dependencies for GoogleLlm are not installed." - 'Please install with `pip install --upgrade "embedchain[google]"`' - ) from None - super().__init__(config) if not self.config.api_key and "GOOGLE_API_KEY" not in os.environ: raise ValueError("Please set the GOOGLE_API_KEY environment variable or pass it in the config.") diff --git a/embedchain/llm/ollama.py b/embedchain/llm/ollama.py index c21c522d..e34ff38e 100644 --- a/embedchain/llm/ollama.py +++ b/embedchain/llm/ollama.py @@ -6,7 +6,11 @@ from langchain.callbacks.manager import CallbackManager from langchain.callbacks.stdout import StdOutCallbackHandler from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler from langchain_community.llms.ollama import Ollama -from ollama import Client + +try: + from ollama import Client +except ImportError: + raise ImportError("Ollama requires extra dependencies. Install with `pip install ollama`") from None from embedchain.config import BaseLlmConfig from embedchain.helpers.json_serializable import register_deserializable