From 0cb78b9067485b2a2652ab0048f3841fb4ff347c Mon Sep 17 00:00:00 2001 From: Sidharth Mohanty Date: Tue, 10 Oct 2023 00:45:22 +0530 Subject: [PATCH] Add Hugging Face Hub LLM support (#762) --- embedchain/llm/hugging_face_hub.py | 51 ++++++++++++++++++++++++ pyproject.toml | 2 + tests/llm/test_hugging_face_hub.py | 64 ++++++++++++++++++++++++++++++ 3 files changed, 117 insertions(+) create mode 100644 embedchain/llm/hugging_face_hub.py create mode 100644 tests/llm/test_hugging_face_hub.py diff --git a/embedchain/llm/hugging_face_hub.py b/embedchain/llm/hugging_face_hub.py new file mode 100644 index 00000000..8a03623c --- /dev/null +++ b/embedchain/llm/hugging_face_hub.py @@ -0,0 +1,51 @@ +import importlib +import os +from typing import Optional + +from langchain.llms import HuggingFaceHub + +from embedchain.config import BaseLlmConfig +from embedchain.helper.json_serializable import register_deserializable +from embedchain.llm.base import BaseLlm + + +@register_deserializable +class HuggingFaceHubLlm(BaseLlm): + def __init__(self, config: Optional[BaseLlmConfig] = None): + if "HUGGINGFACEHUB_ACCESS_TOKEN" not in os.environ: + raise ValueError("Please set the HUGGINGFACEHUB_ACCESS_TOKEN environment variable.") + + try: + importlib.import_module("huggingface_hub") + except ModuleNotFoundError: + raise ModuleNotFoundError( + "The required dependencies for HuggingFaceHub are not installed." 
+ ' Please install with `pip install --upgrade "embedchain[huggingface_hub]"`' + ) from None + + super().__init__(config=config) + + def get_llm_model_answer(self, prompt): + if self.config.system_prompt: + raise ValueError("HuggingFaceHubLlm does not support `system_prompt`") + return HuggingFaceHubLlm._get_answer(prompt=prompt, config=self.config) + + @staticmethod + def _get_answer(prompt: str, config: BaseLlmConfig) -> str: + model_kwargs = { + "temperature": config.temperature or 0.1, + "max_new_tokens": config.max_tokens, + } + + if config.top_p > 0.0 and config.top_p < 1.0: + model_kwargs["top_p"] = config.top_p + else: + raise ValueError("`top_p` must be > 0.0 and < 1.0") + + llm = HuggingFaceHub( + huggingfacehub_api_token=os.environ["HUGGINGFACEHUB_ACCESS_TOKEN"], + repo_id=config.model or "google/flan-t5-xxl", + model_kwargs=model_kwargs, + ) + + return llm(prompt) diff --git a/pyproject.toml b/pyproject.toml index bc9c47cf..dc18df19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,6 +112,7 @@ pillow = { version = "10.0.1", optional = true } torchvision = { version = ">=0.15.1, !=0.15.2", optional = true } ftfy = { version = "6.1.1", optional = true } regex = { version = "2023.8.8", optional = true } +huggingface_hub = { version = "^0.17.3", optional = true } [tool.poetry.group.dev.dependencies] black = "^23.3.0" @@ -136,6 +137,7 @@ discord = ["discord"] slack = ["slack-sdk", "flask"] whatsapp = ["twilio", "flask"] images = ["torch", "ftfy", "regex", "pillow", "torchvision"] +huggingface_hub = ["huggingface_hub"] cohere = ["cohere"] [tool.poetry.group.docs.dependencies] diff --git a/tests/llm/test_hugging_face_hub.py b/tests/llm/test_hugging_face_hub.py new file mode 100644 index 00000000..63b1bfed --- /dev/null +++ b/tests/llm/test_hugging_face_hub.py @@ -0,0 +1,64 @@ +import importlib +import os +import unittest +from unittest.mock import patch, MagicMock + +from embedchain.config import BaseLlmConfig +from embedchain.llm.hugging_face_hub import 
HuggingFaceHubLlm + + +class TestHuggingFaceHubLlm(unittest.TestCase): + def setUp(self): + os.environ["HUGGINGFACEHUB_ACCESS_TOKEN"] = "test_access_token" + self.config = BaseLlmConfig(model="google/flan-t5-xxl", max_tokens=50, temperature=0.7, top_p=0.8) + + def test_init_raises_value_error_without_api_key(self): + os.environ.pop("HUGGINGFACEHUB_ACCESS_TOKEN") + with self.assertRaises(ValueError): + HuggingFaceHubLlm() + + def test_get_llm_model_answer_raises_value_error_for_system_prompt(self): + llm = HuggingFaceHubLlm(self.config) + llm.config.system_prompt = "system_prompt" + with self.assertRaises(ValueError): + llm.get_llm_model_answer("prompt") + + def test_top_p_value_within_range(self): + config = BaseLlmConfig(top_p=1.0) + with self.assertRaises(ValueError): + HuggingFaceHubLlm._get_answer("test_prompt", config) + + def test_dependency_is_imported(self): + importlib_installed = True + try: + importlib.import_module("huggingface_hub") + except ImportError: + importlib_installed = False + self.assertTrue(importlib_installed) + + @patch("embedchain.llm.hugging_face_hub.HuggingFaceHubLlm._get_answer") + def test_get_llm_model_answer(self, mock_get_answer): + mock_get_answer.return_value = "Test answer" + + llm = HuggingFaceHubLlm(self.config) + answer = llm.get_llm_model_answer("Test query") + + self.assertEqual(answer, "Test answer") + mock_get_answer.assert_called_once() + + @patch("embedchain.llm.hugging_face_hub.HuggingFaceHub") + def test_hugging_face_mock(self, mock_hugging_face_hub): + mock_llm_instance = MagicMock() + mock_llm_instance.return_value = "Test answer" + mock_hugging_face_hub.return_value = mock_llm_instance + + llm = HuggingFaceHubLlm(self.config) + answer = llm.get_llm_model_answer("Test query") + + self.assertEqual(answer, "Test answer") + mock_hugging_face_hub.assert_called_once_with( + huggingfacehub_api_token="test_access_token", + repo_id="google/flan-t5-xxl", + model_kwargs={"temperature": 0.7, "max_new_tokens": 50, "top_p": 
0.8}, + ) + mock_llm_instance.assert_called_once_with("Test query")