From 65a20aa457ab098a7ab3fbfc2167ca456bfd3429 Mon Sep 17 00:00:00 2001
From: Sidharth Mohanty
Date: Wed, 18 Oct 2023 13:40:46 +0530
Subject: [PATCH] [Bug fix] Anthropic, Llama2 and VertexAI LLMs dependencies
 (#820)

---
 embedchain/llm/anthropic.py                         |  7 ++++++-
 embedchain/llm/llama2.py                            |  9 +++++++++
 embedchain/llm/vertex_ai.py                         |  9 +++++++++
 pyproject.toml                                      |  4 ++++
 tests/llm/{test_antrophic.py => test_anthrophic.py} | 10 ++++++++--
 tests/test_factory.py                               |  4 ++++
 6 files changed, 40 insertions(+), 3 deletions(-)
 rename tests/llm/{test_antrophic.py => test_anthrophic.py} (86%)

diff --git a/embedchain/llm/anthropic.py b/embedchain/llm/anthropic.py
index 027da1ce..492854f6 100644
--- a/embedchain/llm/anthropic.py
+++ b/embedchain/llm/anthropic.py
@@ -1,4 +1,5 @@
 import logging
+import os
 from typing import Optional
 
 from embedchain.config import BaseLlmConfig
@@ -9,6 +10,8 @@ from embedchain.llm.base import BaseLlm
 @register_deserializable
 class AnthropicLlm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
+        if "ANTHROPIC_API_KEY" not in os.environ:
+            raise ValueError("Please set the ANTHROPIC_API_KEY environment variable.")
         super().__init__(config=config)
 
     def get_llm_model_answer(self, prompt):
@@ -18,7 +21,9 @@ class AnthropicLlm(BaseLlm):
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
         from langchain.chat_models import ChatAnthropic
 
-        chat = ChatAnthropic(temperature=config.temperature, model=config.model)
+        chat = ChatAnthropic(
+            anthropic_api_key=os.environ["ANTHROPIC_API_KEY"], temperature=config.temperature, model=config.model
+        )
 
         if config.max_tokens and config.max_tokens != 1000:
             logging.warning("Config option `max_tokens` is not supported by this model.")
diff --git a/embedchain/llm/llama2.py b/embedchain/llm/llama2.py
index f59a0b1e..2e5c8c05 100644
--- a/embedchain/llm/llama2.py
+++ b/embedchain/llm/llama2.py
@@ -1,3 +1,4 @@
+import importlib
 import os
 from typing import Optional
 
@@ -7,6 +8,14 @@ from embedchain.config import BaseLlmConfig
 from embedchain.helper.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
+try:
+    importlib.import_module("replicate")
+except ModuleNotFoundError:
+    raise ModuleNotFoundError(
+        "The required dependencies for Llama2 are not installed. "
+        'Please install with `pip install --upgrade "embedchain[llama2]"`'
+    ) from None
+
 
 @register_deserializable
 class Llama2Llm(BaseLlm):
diff --git a/embedchain/llm/vertex_ai.py b/embedchain/llm/vertex_ai.py
index f2adc9ae..98e81621 100644
--- a/embedchain/llm/vertex_ai.py
+++ b/embedchain/llm/vertex_ai.py
@@ -1,3 +1,4 @@
+import importlib
 import logging
 from typing import Optional
 
@@ -5,6 +6,14 @@ from embedchain.config import BaseLlmConfig
 from embedchain.helper.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
+try:
+    importlib.import_module("vertexai")
+except ModuleNotFoundError:
+    raise ModuleNotFoundError(
+        "The required dependencies for VertexAI are not installed. "
+        'Please install with `pip install --upgrade "embedchain[vertexai]"`'
+    ) from None
+
 
 @register_deserializable
 class VertexAILlm(BaseLlm):
diff --git a/pyproject.toml b/pyproject.toml
index b51a690f..d499f00a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -121,6 +121,8 @@ ftfy = { version = "6.1.1", optional = true }
 regex = { version = "2023.8.8", optional = true }
 huggingface_hub = { version = "^0.17.3", optional = true }
 pymilvus = { version="2.3.1", optional = true }
+google-cloud-aiplatform = { version="^1.26.1", optional = true }
+replicate = { version="^0.15.4", optional = true }
 
 [tool.poetry.group.dev.dependencies]
 black = "^23.3.0"
@@ -162,6 +164,8 @@ dataloaders=[
     "unstructured",
     "sentence-transformers",
 ]
+vertexai = ["google-cloud-aiplatform"]
+llama2 = ["replicate"]
 
 
 [tool.poetry.group.docs.dependencies]
diff --git a/tests/llm/test_antrophic.py b/tests/llm/test_anthrophic.py
similarity index 86%
rename from tests/llm/test_antrophic.py
rename to tests/llm/test_anthrophic.py
index 2bb74d35..fa281dbf 100644
--- a/tests/llm/test_antrophic.py
+++ b/tests/llm/test_anthrophic.py
@@ -1,3 +1,4 @@
+import os
 from unittest.mock import MagicMock, patch
 
 import pytest
@@ -9,6 +10,7 @@ from embedchain.llm.anthropic import AnthropicLlm
 
 @pytest.fixture
 def anthropic_llm():
+    os.environ["ANTHROPIC_API_KEY"] = "test_api_key"
     config = BaseLlmConfig(temperature=0.5, model="gpt2")
     return AnthropicLlm(config)
 
@@ -31,7 +33,9 @@ def test_get_answer(anthropic_llm):
         assert response == "Test Response"
         mock_chat.assert_called_once_with(
-            temperature=anthropic_llm.config.temperature, model=anthropic_llm.config.model
+            anthropic_api_key="test_api_key",
+            temperature=anthropic_llm.config.temperature,
+            model=anthropic_llm.config.model,
         )
         mock_chat_instance.assert_called_once_with(
             anthropic_llm._get_messages(prompt, system_prompt=anthropic_llm.config.system_prompt)
         )
@@ -60,6 +64,8 @@ def test_get_answer_max_tokens_is_provided(anthropic_llm, caplog):
         response = anthropic_llm._get_answer(prompt, config)
 
         assert response == "Test Response"
-        mock_chat.assert_called_once_with(temperature=config.temperature, model=config.model)
+        mock_chat.assert_called_once_with(
+            anthropic_api_key="test_api_key", temperature=config.temperature, model=config.model
+        )
 
         assert "Config option `max_tokens` is not supported by this model." in caplog.text
diff --git a/tests/test_factory.py b/tests/test_factory.py
index 1e3bcee5..cc2d0b04 100644
--- a/tests/test_factory.py
+++ b/tests/test_factory.py
@@ -1,3 +1,5 @@
+import os
+
 import pytest
 
 import embedchain
@@ -22,6 +24,8 @@ class TestFactories:
         ],
     )
     def test_llm_factory_create(self, provider_name, config_data, expected_class):
+        os.environ["ANTHROPIC_API_KEY"] = "test_api_key"
+        os.environ["OPENAI_API_KEY"] = "test_api_key"
         llm_instance = LlmFactory.create(provider_name, config_data)
         assert isinstance(llm_instance, expected_class)
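---

Addendum (illustrative, not part of the commit): a minimal sketch of how these guards surface to a caller, assuming only the behavior added above. The key value mirrors the test fixture; the model name is a placeholder, not taken from this change.

    import os

    # With the optional extras missing, these module imports now fail fast with a
    # ModuleNotFoundError that names the install command, e.g.
    #   pip install --upgrade "embedchain[llama2]"
    #   pip install --upgrade "embedchain[vertexai]"
    # from embedchain.llm.llama2 import Llama2Llm
    # from embedchain.llm.vertex_ai import VertexAILlm

    # AnthropicLlm.__init__ now raises
    #   ValueError: Please set the ANTHROPIC_API_KEY environment variable.
    # when the key is absent, so set it before constructing the LLM.
    os.environ["ANTHROPIC_API_KEY"] = "test_api_key"  # placeholder key

    from embedchain.config import BaseLlmConfig
    from embedchain.llm.anthropic import AnthropicLlm

    llm = AnthropicLlm(BaseLlmConfig(temperature=0.5, model="claude-2"))  # model name is illustrative

Failing at construction (or import) time rather than on the first query means a missing key or extra is reported before any documents are embedded.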