From 62c0c52e31a0e6e2ff4eb388f990d9e956abf5f7 Mon Sep 17 00:00:00 2001
From: Madison Ebersole
Date: Mon, 8 Jan 2024 13:20:04 -0500
Subject: [PATCH] Add support for Hugging Face Inference Endpoint as LLM (#1143)

---
 docs/components/llms.mdx      | 43 +++++++++++++++++++++++++++++++++++
 embedchain/config/llm/base.py | 11 ++++++++-
 embedchain/llm/huggingface.py | 20 ++++++++++++++++
 embedchain/utils/misc.py      |  2 ++
 tests/llm/test_huggingface.py | 19 ++++++++++++++++
 5 files changed, 94 insertions(+), 1 deletion(-)

diff --git a/docs/components/llms.mdx b/docs/components/llms.mdx
index 40e15c04..faf64855 100644
--- a/docs/components/llms.mdx
+++ b/docs/components/llms.mdx
@@ -494,6 +494,49 @@ llm:
 ```
 
 
+### Custom Endpoints
+
+
+You can also use [Hugging Face Inference Endpoints](https://huggingface.co/docs/inference-endpoints/index#-inference-endpoints) to access custom endpoints. First, set the `HUGGINGFACE_ACCESS_TOKEN` environment variable as described above.
+
+Then, load the app using the config yaml file:
+
+<CodeGroup>
+
+```python main.py
+import os
+from embedchain import App
+
+os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "xxx"
+
+# load llm configuration from config.yaml file
+app = App.from_config(config_path="config.yaml")
+```
+
+```yaml config.yaml
+llm:
+  provider: huggingface
+  config:
+    endpoint: https://api-inference.huggingface.co/models/gpt2 # replace with your personal endpoint
+```
+</CodeGroup>
+
+If your endpoint requires additional parameters, you can pass them in the `model_kwargs` field:
+
+```yaml config.yaml
+llm:
+  provider: huggingface
+  config:
+    endpoint:
+    model_kwargs:
+      max_new_tokens: 100
+      temperature: 0.5
+```
+
+Currently, only the `text-generation` and `text2text-generation` tasks are supported [[ref](https://api.python.langchain.com/en/latest/llms/langchain_community.llms.huggingface_endpoint.HuggingFaceEndpoint.html?highlight=huggingfaceendpoint#)].
+
+See LangChain's [HuggingFaceEndpoint](https://python.langchain.com/docs/integrations/chat/huggingface#huggingfaceendpoint) documentation for more information.
+
 ## Llama2
 
 Llama2 is integrated through [Replicate](https://replicate.com/). Set `REPLICATE_API_TOKEN` in environment variable which you can obtain from [their platform](https://replicate.com/account/api-tokens).
diff --git a/embedchain/config/llm/base.py b/embedchain/config/llm/base.py
index 635bfd45..544cbb87 100644
--- a/embedchain/config/llm/base.py
+++ b/embedchain/config/llm/base.py
@@ -72,6 +72,8 @@ class BaseLlmConfig(BaseConfig):
         query_type: Optional[str] = None,
         callbacks: Optional[List] = None,
         api_key: Optional[str] = None,
+        endpoint: Optional[str] = None,
+        model_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
         Initializes a configuration class instance for the LLM.
@@ -105,6 +107,12 @@ class BaseLlmConfig(BaseConfig):
         :type system_prompt: Optional[str], optional
         :param where: A dictionary of key-value pairs to filter the database results., defaults to None
         :type where: Dict[str, Any], optional
+        :param api_key: The API key for the custom endpoint, defaults to None
+        :type api_key: Optional[str], optional
+        :param endpoint: The API URL of the custom endpoint, defaults to None
+        :type endpoint: Optional[str], optional
+        :param model_kwargs: A dictionary of key-value pairs to pass to the model, defaults to None
+        :type model_kwargs: Optional[Dict[str, Any]], optional
         :param callbacks: Langchain callback functions to use, defaults to None
         :type callbacks: Optional[List], optional
         :raises ValueError: If the template is not valid as template should
@@ -132,7 +140,8 @@ class BaseLlmConfig(BaseConfig):
         self.query_type = query_type
         self.callbacks = callbacks
         self.api_key = api_key
-
+        self.endpoint = endpoint
+        self.model_kwargs = model_kwargs or {}
         if type(prompt) is str:
             prompt = Template(prompt)
 
diff --git a/embedchain/llm/huggingface.py b/embedchain/llm/huggingface.py
index 2d7601ee..67e24d6d 100644
--- a/embedchain/llm/huggingface.py
+++ b/embedchain/llm/huggingface.py
@@ -3,6 +3,7 @@ import logging
 import os
 from typing import Optional
 
+from langchain.llms.huggingface_endpoint import HuggingFaceEndpoint
 from langchain.llms.huggingface_hub import HuggingFaceHub
 
 from embedchain.config import BaseLlmConfig
@@ -33,6 +34,15 @@ class HuggingFaceLlm(BaseLlm):
 
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
+        if config.model:
+            return HuggingFaceLlm._from_model(prompt=prompt, config=config)
+        elif config.endpoint:
+            return HuggingFaceLlm._from_endpoint(prompt=prompt, config=config)
+        else:
+            raise ValueError("Either `model` or `endpoint` must be set")
+
+    @staticmethod
+    def _from_model(prompt: str, config: BaseLlmConfig) -> str:
         model_kwargs = {
             "temperature": config.temperature or 0.1,
             "max_new_tokens": config.max_tokens,
@@ -52,3 +62,13 @@
         )
 
         return llm(prompt)
+
+    @staticmethod
+    def _from_endpoint(prompt: str, config: BaseLlmConfig) -> str:
+        llm = HuggingFaceEndpoint(
+            huggingfacehub_api_token=os.environ["HUGGINGFACE_ACCESS_TOKEN"],
+            endpoint_url=config.endpoint,
+            task="text-generation",
+            model_kwargs=config.model_kwargs,
+        )
+        return llm(prompt)
diff --git a/embedchain/utils/misc.py b/embedchain/utils/misc.py
index 03ed5bb3..263c143e 100644
--- a/embedchain/utils/misc.py
+++ b/embedchain/utils/misc.py
@@ -415,6 +415,8 @@ def validate_config(config_data):
                 Optional("where"): dict,
                 Optional("query_type"): str,
                 Optional("api_key"): str,
+                Optional("endpoint"): str,
+                Optional("model_kwargs"): dict,
             },
         },
         Optional("vectordb"): {
diff --git a/tests/llm/test_huggingface.py b/tests/llm/test_huggingface.py
index c43b099e..45a70244 100644
--- a/tests/llm/test_huggingface.py
+++ b/tests/llm/test_huggingface.py
@@ -15,6 +15,14 @@ def huggingface_llm_config():
     os.environ.pop("HUGGINGFACE_ACCESS_TOKEN")
 
 
+@pytest.fixture
+def huggingface_endpoint_config():
+    os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "test_access_token"
+    config = BaseLlmConfig(endpoint="https://api-inference.huggingface.co/models/gpt2", model_kwargs={"device": "cpu"})
+    yield config
+    os.environ.pop("HUGGINGFACE_ACCESS_TOKEN")
+
+
 def test_init_raises_value_error_without_api_key(mocker):
     mocker.patch.dict(os.environ, clear=True)
     with pytest.raises(ValueError):
@@ -61,3 +69,14 @@ def test_hugging_face_mock(huggingface_llm_config, mocker):
 
     assert answer == "Test answer"
     mock_llm_instance.assert_called_once_with("Test query")
+
+
+def test_custom_endpoint(huggingface_endpoint_config, mocker):
+    mock_llm_instance = mocker.Mock(return_value="Test answer")
+    mocker.patch("embedchain.llm.huggingface.HuggingFaceEndpoint", return_value=mock_llm_instance)
+
+    llm = HuggingFaceLlm(huggingface_endpoint_config)
+    answer = llm.get_llm_model_answer("Test query")
+
+    assert answer == "Test answer"
+    mock_llm_instance.assert_called_once_with("Test query")
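
For reviewers, a minimal end-to-end sketch of the endpoint path this patch adds, separate from the diff above. The token and endpoint URL are placeholders, and it assumes `App.from_config` accepts an inline config dict as the embedchain docs describe; swap in `config_path="config.yaml"` if you prefer a yaml file.

```python
import os

from embedchain import App

# Placeholder token; a real Hugging Face access token is required.
os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "hf_xxx"

# Setting `endpoint` (and no `model`) makes HuggingFaceLlm._get_answer
# dispatch to the new _from_endpoint() path, which wraps LangChain's
# HuggingFaceEndpoint with task="text-generation".
app = App.from_config(
    config={
        "llm": {
            "provider": "huggingface",
            "config": {
                # Placeholder URL; replace with your deployed endpoint.
                "endpoint": "https://api-inference.huggingface.co/models/gpt2",
                "model_kwargs": {"max_new_tokens": 100, "temperature": 0.5},
            },
        }
    }
)

app.add("https://en.wikipedia.org/wiki/Hugging_Face")
print(app.query("What is Hugging Face?"))
```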