Raise import error if Ollama and Google not found (#1432)

This commit is contained in:
Dev Khant
2024-06-19 10:16:48 +05:30
committed by GitHub
parent 21a04541ea
commit e3e107b31d
2 changed files with 9 additions and 11 deletions

View File

@@ -1,10 +1,12 @@
import importlib
import logging import logging
import os import os
from collections.abc import Generator from collections.abc import Generator
from typing import Any, Optional, Union from typing import Any, Optional, Union
try:
import google.generativeai as genai import google.generativeai as genai
except ImportError:
raise ImportError("GoogleLlm requires extra dependencies. Install with `pip install google-generativeai`") from None
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@@ -16,14 +18,6 @@ logger = logging.getLogger(__name__)
@register_deserializable @register_deserializable
class GoogleLlm(BaseLlm): class GoogleLlm(BaseLlm):
def __init__(self, config: Optional[BaseLlmConfig] = None): def __init__(self, config: Optional[BaseLlmConfig] = None):
try:
importlib.import_module("google.generativeai")
except ModuleNotFoundError:
raise ModuleNotFoundError(
"The required dependencies for GoogleLlm are not installed."
'Please install with `pip install --upgrade "embedchain[google]"`'
) from None
super().__init__(config) super().__init__(config)
if not self.config.api_key and "GOOGLE_API_KEY" not in os.environ: if not self.config.api_key and "GOOGLE_API_KEY" not in os.environ:
raise ValueError("Please set the GOOGLE_API_KEY environment variable or pass it in the config.") raise ValueError("Please set the GOOGLE_API_KEY environment variable or pass it in the config.")

View File

@@ -6,7 +6,11 @@ from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_community.llms.ollama import Ollama from langchain_community.llms.ollama import Ollama
try:
from ollama import Client from ollama import Client
except ImportError:
raise ImportError("Ollama requires extra dependencies. Install with `pip install ollama`") from None
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable