Add support for Hugging Face Inference Endpoint as LLM (#1143)

This commit is contained in:
Madison Ebersole
2024-01-08 13:20:04 -05:00
committed by GitHub
parent e36198dcc2
commit 62c0c52e31
5 changed files with 93 additions and 1 deletions

View File

@@ -3,6 +3,7 @@ import logging
import os
from typing import Optional
from langchain.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain.llms.huggingface_hub import HuggingFaceHub
from embedchain.config import BaseLlmConfig
@@ -33,6 +34,15 @@ class HuggingFaceLlm(BaseLlm):
@staticmethod
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
    """Produce a completion for ``prompt`` using the backend chosen by ``config``.

    ``config.model`` is checked first; ``config.endpoint`` is only consulted
    when no model is set.

    :param prompt: Text sent to the selected backend.
    :param config: LLM configuration; its ``model`` and ``endpoint``
        attributes select the backend.
    :return: Whatever the selected backend helper returns for ``prompt``.
    :raises ValueError: If neither ``config.model`` nor ``config.endpoint``
        is truthy.
    """
    # Resolve the backend first, then make a single call at the end.
    if config.model:
        backend = HuggingFaceLlm._from_model
    elif config.endpoint:
        backend = HuggingFaceLlm._from_endpoint
    else:
        raise ValueError("Either `model` or `endpoint` must be set")
    return backend(prompt=prompt, config=config)
@staticmethod
def _from_model(prompt: str, config: BaseLlmConfig) -> str:
model_kwargs = {
"temperature": config.temperature or 0.1,
"max_new_tokens": config.max_tokens,
@@ -52,3 +62,13 @@ class HuggingFaceLlm(BaseLlm):
)
return llm(prompt)
@staticmethod
def _from_endpoint(prompt: str, config: BaseLlmConfig) -> str:
    """Send ``prompt`` to the Hugging Face endpoint at ``config.endpoint``.

    The API token is read from the ``HUGGINGFACE_ACCESS_TOKEN`` environment
    variable; indexing ``os.environ`` raises ``KeyError`` when it is unset.

    :param prompt: Text passed to the endpoint client.
    :param config: LLM configuration; ``endpoint`` supplies the URL and
        ``model_kwargs`` is forwarded unchanged to the client.
    :return: The value returned by calling the endpoint client with ``prompt``.
    """
    access_token = os.environ["HUGGINGFACE_ACCESS_TOKEN"]
    endpoint_llm = HuggingFaceEndpoint(
        huggingfacehub_api_token=access_token,
        endpoint_url=config.endpoint,
        task="text-generation",
        model_kwargs=config.model_kwargs,
    )
    answer = endpoint_llm(prompt)
    return answer