[Feature] Add support for NVIDIA AI LLMs and embedding models (#1293)

This commit is contained in:
Deshraj Yadav
2024-02-29 23:56:25 -08:00
committed by GitHub
parent 6518c0c06b
commit c77a75dfb5
18 changed files with 195 additions and 22 deletions

View File

@@ -84,7 +84,7 @@ class App(EmbedChain):
if name and config:
raise Exception("Cannot provide both name and config. Please provide only one of them.")
logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
# logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
self.logger = logging.getLogger(__name__)
# Initialize the metadata db for the app

View File

@@ -36,7 +36,6 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
defaults to None
:type collection_name: Optional[str], optional
"""
self._setup_logging(log_level)
self.id = id
self.collect_metrics = True if (collect_metrics is True or collect_metrics is None) else False
self.collection_name = collection_name
@@ -52,12 +51,6 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
logging.warning("DEPRECATION WARNING: Please supply the collection name to the database config.")
return
def _setup_logging(self, debug_level):
level = logging.WARNING # Default level
if debug_level is not None:
level = getattr(logging, debug_level.upper(), None)
if not isinstance(level, int):
raise ValueError(f"Invalid log level: {debug_level}")
logging.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=level)
def _setup_logging(self, log_level):
logging.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=log_level)
self.logger = logging.getLogger(__name__)

View File

@@ -0,0 +1,26 @@
import logging
import os
from typing import Optional
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions
class NvidiaEmbedder(BaseEmbedder):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
if "NVIDIA_API_KEY" not in os.environ:
raise ValueError("NVIDIA_API_KEY environment variable must be set")
super().__init__(config=config)
model = self.config.model or "nvolveqa_40k"
logging.info(f"Using NVIDIA embedding model: {model}")
embedder = NVIDIAEmbeddings(model=model)
embedding_fn = BaseEmbedder._langchain_default_concept(embedder)
self.set_embedding_fn(embedding_fn=embedding_fn)
vector_dimension = self.config.vector_dimension or VectorDimensions.NVIDIA_AI.value
self.set_vector_dimension(vector_dimension=vector_dimension)

View File

@@ -24,6 +24,7 @@ class LlmFactory:
"aws_bedrock": "embedchain.llm.aws_bedrock.AWSBedrockLlm",
"mistralai": "embedchain.llm.mistralai.MistralAILlm",
"groq": "embedchain.llm.groq.GroqLlm",
"nvidia": "embedchain.llm.nvidia.NvidiaLlm",
}
provider_to_config_class = {
"embedchain": "embedchain.config.llm.base.BaseLlmConfig",
@@ -54,13 +55,14 @@ class EmbedderFactory:
"vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder",
"google": "embedchain.embedder.google.GoogleAIEmbedder",
"mistralai": "embedchain.embedder.mistralai.MistralAIEmbedder",
"nvidia": "embedchain.embedder.nvidia.NvidiaEmbedder",
}
provider_to_config_class = {
"azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
"openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
"gpt4all": "embedchain.config.embedder.base.BaseEmbedderConfig",
"google": "embedchain.config.embedder.google.GoogleAIEmbedderConfig",
"gpt4all": "embedchain.config.embedder.base.BaseEmbedderConfig",
"huggingface": "embedchain.config.embedder.base.BaseEmbedderConfig",
"openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
}
@classmethod

47
embedchain/llm/nvidia.py Normal file
View File

@@ -0,0 +1,47 @@
import os
from collections.abc import Iterable
from typing import Optional, Union
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
try:
from langchain_nvidia_ai_endpoints import ChatNVIDIA
except ImportError:
raise ImportError(
"NVIDIA AI endpoints requires extra dependencies. Install with `pip install langchain-nvidia-ai-endpoints`"
) from None
from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
@register_deserializable
class NvidiaLlm(BaseLlm):
def __init__(self, config: Optional[BaseLlmConfig] = None):
if "NVIDIA_API_KEY" not in os.environ:
raise ValueError("NVIDIA_API_KEY environment variable must be set")
super().__init__(config=config)
def get_llm_model_answer(self, prompt):
return self._get_answer(prompt=prompt, config=self.config)
@staticmethod
def _get_answer(prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
callback_manager = [StreamingStdOutCallbackHandler()] if config.stream else [StdOutCallbackHandler()]
model_kwargs = config.model_kwargs or {}
labels = model_kwargs.get("labels", None)
params = {"model": config.model}
if config.system_prompt:
params["system_prompt"] = config.system_prompt
if config.temperature:
params["temperature"] = config.temperature
if config.top_p:
params["top_p"] = config.top_p
if labels:
params["labels"] = labels
llm = ChatNVIDIA(**params, callback_manager=CallbackManager(callback_manager))
return llm.invoke(prompt).content if labels is None else llm.invoke(prompt, labels=labels).content

View File

@@ -9,3 +9,4 @@ class VectorDimensions(Enum):
HUGGING_FACE = 384
GOOGLE_AI = 768
MISTRAL_AI = 1024
NVIDIA_AI = 1024

View File

@@ -17,8 +17,6 @@ from embedchain.models.data_type import DataType
from embedchain.telemetry.posthog import AnonymousTelemetry
from embedchain.utils.misc import detect_datatype
logging.basicConfig(level=logging.WARN)
# Set up the user directory if it doesn't exist already
Client.setup()
@@ -33,7 +31,7 @@ class OpenAIAssistant:
model="gpt-4-1106-preview",
data_sources=None,
assistant_id=None,
log_level=logging.WARN,
log_level=logging.INFO,
collect_metrics=True,
):
self.name = name or "OpenAI Assistant"
@@ -156,10 +154,9 @@ class AIAssistant:
assistant_id=None,
thread_id=None,
data_sources=None,
log_level=logging.WARN,
log_level=logging.INFO,
collect_metrics=True,
):
logging.basicConfig(level=log_level)
self.name = name or "AI Assistant"
self.data_sources = data_sources or []

View File

@@ -12,8 +12,6 @@ HOME_DIR = str(Path.home())
CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
logger = logging.getLogger(__name__)
class AnonymousTelemetry:
def __init__(self, host="https://app.posthog.com", enabled=True):
@@ -63,4 +61,4 @@ class AnonymousTelemetry:
try:
self.posthog.capture(self.user_id, event_name, properties)
except Exception:
logger.exception(f"Failed to send telemetry {event_name=}")
logging.exception(f"Failed to send telemetry {event_name=}")

View File

@@ -407,6 +407,7 @@ def validate_config(config_data):
"mistralai",
"vllm",
"groq",
"nvidia",
),
Optional("config"): {
Optional("model"): str,
@@ -443,6 +444,7 @@ def validate_config(config_data):
"azure_openai",
"google",
"mistralai",
"nvidia",
),
Optional("config"): {
Optional("model"): Optional(str),
@@ -462,6 +464,7 @@ def validate_config(config_data):
"azure_openai",
"google",
"mistralai",
"nvidia",
),
Optional("config"): {
Optional("model"): str,