From e3e107b31d07d1cddc875ffbadd105d57398931a Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Wed, 19 Jun 2024 10:16:48 +0530 Subject: [PATCH] Raise import error if Ollama and Google not found (#1432) --- embedchain/llm/google.py | 14 ++++---------- embedchain/llm/ollama.py | 6 +++++- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/embedchain/llm/google.py b/embedchain/llm/google.py index 5409b713..c0002fa9 100644 --- a/embedchain/llm/google.py +++ b/embedchain/llm/google.py @@ -1,10 +1,12 @@ -import importlib import logging import os from collections.abc import Generator from typing import Any, Optional, Union -import google.generativeai as genai +try: + import google.generativeai as genai +except ImportError: + raise ImportError("GoogleLlm requires extra dependencies. Install with `pip install google-generativeai`") from None from embedchain.config import BaseLlmConfig from embedchain.helpers.json_serializable import register_deserializable @@ -16,14 +18,6 @@ logger = logging.getLogger(__name__) @register_deserializable class GoogleLlm(BaseLlm): def __init__(self, config: Optional[BaseLlmConfig] = None): - try: - importlib.import_module("google.generativeai") - except ModuleNotFoundError: - raise ModuleNotFoundError( - "The required dependencies for GoogleLlm are not installed." - 'Please install with `pip install --upgrade "embedchain[google]"`' - ) from None - super().__init__(config) if not self.config.api_key and "GOOGLE_API_KEY" not in os.environ: raise ValueError("Please set the GOOGLE_API_KEY environment variable or pass it in the config.") diff --git a/embedchain/llm/ollama.py b/embedchain/llm/ollama.py index c21c522d..e34ff38e 100644 --- a/embedchain/llm/ollama.py +++ b/embedchain/llm/ollama.py @@ -6,7 +6,11 @@ from langchain.callbacks.manager import CallbackManager from langchain.callbacks.stdout import StdOutCallbackHandler from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler from langchain_community.llms.ollama import Ollama -from ollama import Client + +try: + from ollama import Client +except ImportError: + raise ImportError("Ollama requires extra dependencies. Install with `pip install ollama`") from None from embedchain.config import BaseLlmConfig from embedchain.helpers.json_serializable import register_deserializable