[Feature] Add support for NVIDIA AI LLMs and embedding models (#1293)

Author: Deshraj Yadav
Date: 2024-02-29 23:56:25 -08:00
Committed by: GitHub
Parent: 6518c0c06b
Commit: c77a75dfb5
18 changed files with 195 additions and 22 deletions
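As a usage sketch (not part of the diff): selecting the new provider from an app config might look like the following. The "nvidia" provider keys and the model ids are assumptions based on the factory wiring in the other changed files, which are not shown in this excerpt.

from embedchain import App

# Assumed wiring: "nvidia" as the provider key for both the LLM and the
# embedder; the model ids below are hypothetical placeholders.
app = App.from_config(config={
    "llm": {
        "provider": "nvidia",
        "config": {"model": "nemotron_qa_8b"},
    },
    "embedder": {
        "provider": "nvidia",
        "config": {"model": "nvolveqa_40k"},
    },
})
app.add("https://docs.embedchain.ai")
print(app.query("What is Embedchain?"))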

embedchain/llm/nvidia.py (new file, 47 lines)
@@ -0,0 +1,47 @@
import os
from collections.abc import Iterable
from typing import Optional, Union

from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

try:
    from langchain_nvidia_ai_endpoints import ChatNVIDIA
except ImportError:
    raise ImportError(
        "NVIDIA AI endpoints requires extra dependencies. Install with `pip install langchain-nvidia-ai-endpoints`"
    ) from None

from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm


@register_deserializable
class NvidiaLlm(BaseLlm):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        if "NVIDIA_API_KEY" not in os.environ:
            raise ValueError("NVIDIA_API_KEY environment variable must be set")
        super().__init__(config=config)

    def get_llm_model_answer(self, prompt):
        return self._get_answer(prompt=prompt, config=self.config)

    @staticmethod
    def _get_answer(prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
        # Stream tokens to stdout as they arrive, or print the answer once.
        callback_manager = [StreamingStdOutCallbackHandler()] if config.stream else [StdOutCallbackHandler()]
        model_kwargs = config.model_kwargs or {}
        # Optional steering labels (used by SteerLM-style models), forwarded via model_kwargs.
        labels = model_kwargs.get("labels", None)
        params = {"model": config.model}
        if config.system_prompt:
            params["system_prompt"] = config.system_prompt
        if config.temperature:
            params["temperature"] = config.temperature
        if config.top_p:
            params["top_p"] = config.top_p
        if labels:
            params["labels"] = labels
        llm = ChatNVIDIA(**params, callback_manager=CallbackManager(callback_manager))
        return llm.invoke(prompt).content if labels is None else llm.invoke(prompt, labels=labels).content
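For completeness, a minimal sketch of driving the new class directly, using only the constructor and get_llm_model_answer shown above; the API key value and model id are hypothetical placeholders.

import os

from embedchain.config import BaseLlmConfig
from embedchain.llm.nvidia import NvidiaLlm

# Hypothetical placeholder; a real NVIDIA AI endpoints key is required.
os.environ["NVIDIA_API_KEY"] = "nvapi-..."

# "nemotron_qa_8b" is an assumed model id; pick one from the NVIDIA catalog.
config = BaseLlmConfig(model="nemotron_qa_8b", temperature=0.2)
llm = NvidiaLlm(config=config)
print(llm.get_llm_model_answer("What is Embedchain?"))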