diff --git a/embedchain/llm/anthropic.py b/embedchain/llm/anthropic.py index a6874e47..f8de89dd 100644 --- a/embedchain/llm/anthropic.py +++ b/embedchain/llm/anthropic.py @@ -17,18 +17,17 @@ logger = logging.getLogger(__name__) @register_deserializable class AnthropicLlm(BaseLlm): def __init__(self, config: Optional[BaseLlmConfig] = None): - if "ANTHROPIC_API_KEY" not in os.environ: - raise ValueError("Please set the ANTHROPIC_API_KEY environment variable.") super().__init__(config=config) + if not self.config.api_key and "ANTHROPIC_API_KEY" not in os.environ: + raise ValueError("Please set the ANTHROPIC_API_KEY environment variable or pass it in the config.") def get_llm_model_answer(self, prompt): return AnthropicLlm._get_answer(prompt=prompt, config=self.config) @staticmethod def _get_answer(prompt: str, config: BaseLlmConfig) -> str: - chat = ChatAnthropic( - anthropic_api_key=os.environ["ANTHROPIC_API_KEY"], temperature=config.temperature, model_name=config.model - ) + api_key = config.api_key or os.getenv("ANTHROPIC_API_KEY") + chat = ChatAnthropic(anthropic_api_key=api_key, temperature=config.temperature, model_name=config.model) if config.max_tokens and config.max_tokens != 1000: logger.warning("Config option `max_tokens` is not supported by this model.") diff --git a/embedchain/llm/cohere.py b/embedchain/llm/cohere.py index d755db0c..a1ad1c8d 100644 --- a/embedchain/llm/cohere.py +++ b/embedchain/llm/cohere.py @@ -12,9 +12,6 @@ from embedchain.llm.base import BaseLlm @register_deserializable class CohereLlm(BaseLlm): def __init__(self, config: Optional[BaseLlmConfig] = None): - if "COHERE_API_KEY" not in os.environ: - raise ValueError("Please set the COHERE_API_KEY environment variable.") - try: importlib.import_module("cohere") except ModuleNotFoundError: @@ -24,6 +21,8 @@ class CohereLlm(BaseLlm): ) from None super().__init__(config=config) + if not self.config.api_key and "COHERE_API_KEY" not in os.environ: + raise ValueError("Please set the COHERE_API_KEY environment variable or pass it in the config.") def get_llm_model_answer(self, prompt): if self.config.system_prompt: @@ -32,8 +31,9 @@ class CohereLlm(BaseLlm): @staticmethod def _get_answer(prompt: str, config: BaseLlmConfig) -> str: + api_key = config.api_key or os.getenv("COHERE_API_KEY") llm = Cohere( - cohere_api_key=os.environ["COHERE_API_KEY"], + cohere_api_key=api_key, model=config.model, max_tokens=config.max_tokens, temperature=config.temperature, diff --git a/embedchain/llm/google.py b/embedchain/llm/google.py index de15d4c4..5409b713 100644 --- a/embedchain/llm/google.py +++ b/embedchain/llm/google.py @@ -16,9 +16,6 @@ logger = logging.getLogger(__name__) @register_deserializable class GoogleLlm(BaseLlm): def __init__(self, config: Optional[BaseLlmConfig] = None): - if "GOOGLE_API_KEY" not in os.environ: - raise ValueError("Please set the GOOGLE_API_KEY environment variable.") - try: importlib.import_module("google.generativeai") except ModuleNotFoundError: @@ -28,7 +25,11 @@ class GoogleLlm(BaseLlm): ) from None super().__init__(config) - genai.configure(api_key=os.environ["GOOGLE_API_KEY"]) + if not self.config.api_key and "GOOGLE_API_KEY" not in os.environ: + raise ValueError("Please set the GOOGLE_API_KEY environment variable or pass it in the config.") + + api_key = self.config.api_key or os.getenv("GOOGLE_API_KEY") + genai.configure(api_key=api_key) def get_llm_model_answer(self, prompt): if self.config.system_prompt: diff --git a/embedchain/llm/groq.py b/embedchain/llm/groq.py index d658be15..756b1bd5 100644 --- a/embedchain/llm/groq.py +++ b/embedchain/llm/groq.py @@ -19,6 +19,8 @@ from embedchain.llm.base import BaseLlm class GroqLlm(BaseLlm): def __init__(self, config: Optional[BaseLlmConfig] = None): super().__init__(config=config) + if not self.config.api_key and "GROQ_API_KEY" not in os.environ: + raise ValueError("Please set the GROQ_API_KEY environment variable or pass it in the config.") def get_llm_model_answer(self, prompt) -> str: response = self._get_answer(prompt, self.config) diff --git a/embedchain/llm/huggingface.py b/embedchain/llm/huggingface.py index 69f0c463..72d2f4cc 100644 --- a/embedchain/llm/huggingface.py +++ b/embedchain/llm/huggingface.py @@ -17,9 +17,6 @@ logger = logging.getLogger(__name__) @register_deserializable class HuggingFaceLlm(BaseLlm): def __init__(self, config: Optional[BaseLlmConfig] = None): - if "HUGGINGFACE_ACCESS_TOKEN" not in os.environ: - raise ValueError("Please set the HUGGINGFACE_ACCESS_TOKEN environment variable.") - try: importlib.import_module("huggingface_hub") except ModuleNotFoundError: @@ -29,6 +26,8 @@ class HuggingFaceLlm(BaseLlm): ) from None super().__init__(config=config) + if not self.config.api_key and "HUGGINGFACE_ACCESS_TOKEN" not in os.environ: + raise ValueError("Please set the HUGGINGFACE_ACCESS_TOKEN environment variable or pass it in the config.") def get_llm_model_answer(self, prompt): if self.config.system_prompt: @@ -60,9 +59,10 @@ class HuggingFaceLlm(BaseLlm): raise ValueError("`top_p` must be > 0.0 and < 1.0") model = config.model + api_key = config.api_key or os.getenv("HUGGINGFACE_ACCESS_TOKEN") logger.info(f"Using HuggingFaceHub with model {model}") llm = HuggingFaceHub( - huggingfacehub_api_token=os.environ["HUGGINGFACE_ACCESS_TOKEN"], + huggingfacehub_api_token=api_key, repo_id=model, model_kwargs=model_kwargs, ) @@ -70,8 +70,9 @@ class HuggingFaceLlm(BaseLlm): @staticmethod def _from_endpoint(prompt: str, config: BaseLlmConfig) -> str: + api_key = config.api_key or os.getenv("HUGGINGFACE_ACCESS_TOKEN") llm = HuggingFaceEndpoint( - huggingfacehub_api_token=os.environ["HUGGINGFACE_ACCESS_TOKEN"], + huggingfacehub_api_token=api_key, endpoint_url=config.endpoint, task="text-generation", model_kwargs=config.model_kwargs, diff --git a/embedchain/llm/jina.py b/embedchain/llm/jina.py index 4925276e..782742cb 100644 --- a/embedchain/llm/jina.py +++ b/embedchain/llm/jina.py @@ -12,9 +12,9 @@ from embedchain.llm.base import BaseLlm @register_deserializable class JinaLlm(BaseLlm): def __init__(self, config: Optional[BaseLlmConfig] = None): - if "JINACHAT_API_KEY" not in os.environ: - raise ValueError("Please set the JINACHAT_API_KEY environment variable.") super().__init__(config=config) + if not self.config.api_key and "JINACHAT_API_KEY" not in os.environ: + raise ValueError("Please set the JINACHAT_API_KEY environment variable or pass it in the config.") def get_llm_model_answer(self, prompt): response = JinaLlm._get_answer(prompt, self.config) @@ -29,13 +29,13 @@ class JinaLlm(BaseLlm): kwargs = { "temperature": config.temperature, "max_tokens": config.max_tokens, + "jinachat_api_key": config.api_key or os.environ["JINACHAT_API_KEY"], "model_kwargs": {}, } if config.top_p: kwargs["model_kwargs"]["top_p"] = config.top_p if config.stream: - from langchain.callbacks.streaming_stdout import \ - StreamingStdOutCallbackHandler + from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler chat = JinaChat(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()]) else: diff --git a/embedchain/llm/llama2.py b/embedchain/llm/llama2.py index 426239a8..8a82f3f7 100644 --- a/embedchain/llm/llama2.py +++ b/embedchain/llm/llama2.py @@ -19,8 +19,6 @@ class Llama2Llm(BaseLlm): "The required dependencies for Llama2 are not installed." 'Please install with `pip install --upgrade "embedchain[llama2]"`' ) from None - if "REPLICATE_API_TOKEN" not in os.environ: - raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.") # Set default config values specific to this llm if not config: @@ -35,13 +33,17 @@ class Llama2Llm(BaseLlm): ) super().__init__(config=config) + if not self.config.api_key and "REPLICATE_API_TOKEN" not in os.environ: + raise ValueError("Please set the REPLICATE_API_TOKEN environment variable or pass it in the config.") def get_llm_model_answer(self, prompt): # TODO: Move the model and other inputs into config if self.config.system_prompt: raise ValueError("Llama2 does not support `system_prompt`") + api_key = self.config.api_key or os.getenv("REPLICATE_API_TOKEN") llm = Replicate( model=self.config.model, + replicate_api_token=api_key, input={ "temperature": self.config.temperature, "max_length": self.config.max_tokens, diff --git a/embedchain/llm/nvidia.py b/embedchain/llm/nvidia.py index 4a88cb1e..aac4d3b0 100644 --- a/embedchain/llm/nvidia.py +++ b/embedchain/llm/nvidia.py @@ -21,10 +21,9 @@ from embedchain.llm.base import BaseLlm @register_deserializable class NvidiaLlm(BaseLlm): def __init__(self, config: Optional[BaseLlmConfig] = None): - if "NVIDIA_API_KEY" not in os.environ: - raise ValueError("NVIDIA_API_KEY environment variable must be set") - super().__init__(config=config) + if not self.config.api_key and "NVIDIA_API_KEY" not in os.environ: + raise ValueError("Please set the NVIDIA_API_KEY environment variable or pass it in the config.") def get_llm_model_answer(self, prompt): return self._get_answer(prompt=prompt, config=self.config) @@ -34,7 +33,7 @@ class NvidiaLlm(BaseLlm): callback_manager = [StreamingStdOutCallbackHandler()] if config.stream else [StdOutCallbackHandler()] model_kwargs = config.model_kwargs or {} labels = model_kwargs.get("labels", None) - params = {"model": config.model} + params = {"model": config.model, "nvidia_api_key": config.api_key or os.getenv("NVIDIA_API_KEY")} if config.system_prompt: params["system_prompt"] = config.system_prompt if config.temperature: diff --git a/embedchain/llm/together.py b/embedchain/llm/together.py index 17995ca5..9c3045d3 100644 --- a/embedchain/llm/together.py +++ b/embedchain/llm/together.py @@ -12,9 +12,6 @@ from embedchain.llm.base import BaseLlm @register_deserializable class TogetherLlm(BaseLlm): def __init__(self, config: Optional[BaseLlmConfig] = None): - if "TOGETHER_API_KEY" not in os.environ: - raise ValueError("Please set the TOGETHER_API_KEY environment variable.") - try: importlib.import_module("together") except ModuleNotFoundError: @@ -24,6 +21,8 @@ class TogetherLlm(BaseLlm): ) from None super().__init__(config=config) + if not self.config.api_key and "TOGETHER_API_KEY" not in os.environ: + raise ValueError("Please set the TOGETHER_API_KEY environment variable or pass it in the config.") def get_llm_model_answer(self, prompt): if self.config.system_prompt: @@ -32,8 +31,9 @@ class TogetherLlm(BaseLlm): @staticmethod def _get_answer(prompt: str, config: BaseLlmConfig) -> str: + api_key = config.api_key or os.getenv("TOGETHER_API_KEY") llm = Together( - together_api_key=os.environ["TOGETHER_API_KEY"], + together_api_key=api_key, model=config.model, max_tokens=config.max_tokens, temperature=config.temperature, diff --git a/tests/llm/test_jina.py b/tests/llm/test_jina.py index 4639c410..8df93322 100644 --- a/tests/llm/test_jina.py +++ b/tests/llm/test_jina.py @@ -74,5 +74,6 @@ def test_get_llm_model_answer_without_system_prompt(config, mocker): mocked_jinachat.assert_called_once_with( temperature=config.temperature, max_tokens=config.max_tokens, + jinachat_api_key=os.environ["JINACHAT_API_KEY"], model_kwargs={"top_p": config.top_p}, )