[Feature] Add support for NVIDIA AI LLMs and embedding models (#1293)
This commit is contained in:
@@ -13,6 +13,7 @@ Embedchain supports several embedding models from the following providers:
|
|||||||
<Card title="GPT4All" href="#gpt4all"></Card>
|
<Card title="GPT4All" href="#gpt4all"></Card>
|
||||||
<Card title="Hugging Face" href="#hugging-face"></Card>
|
<Card title="Hugging Face" href="#hugging-face"></Card>
|
||||||
<Card title="Vertex AI" href="#vertex-ai"></Card>
|
<Card title="Vertex AI" href="#vertex-ai"></Card>
|
||||||
|
<Card title="NVIDIA AI" href="#nvidia-ai"></Card>
|
||||||
</CardGroup>
|
</CardGroup>
|
||||||
|
|
||||||
## OpenAI
|
## OpenAI
|
||||||
@@ -220,3 +221,55 @@ embedder:
|
|||||||
```
|
```
|
||||||
|
|
||||||
</CodeGroup>
|
</CodeGroup>
|
||||||
|
|
||||||
|
## NVIDIA AI
|
||||||
|
|
||||||
|
[NVIDIA AI Foundation Endpoints](https://www.nvidia.com/en-us/ai-data-science/foundation-models/) let you quickly use NVIDIA's AI models, such as Mixtral 8x7B and Llama 2, through our API. These models are available in the [NVIDIA NGC catalog](https://catalog.ngc.nvidia.com/ai-foundation-models), fully optimized and ready to use on NVIDIA's AI platform. They are designed for high speed and easy customization, ensuring smooth performance on any accelerated setup.
|
||||||
|
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
In order to use embedding models and LLMs from NVIDIA AI, create an account on [NVIDIA NGC Service](https://catalog.ngc.nvidia.com/).
|
||||||
|
|
||||||
|
Generate an API key from their dashboard and set it as the `NVIDIA_API_KEY` environment variable. Note that the API key starts with `nvapi-`.
|
||||||
|
|
||||||
|
Below is an example of how to use an LLM and an embedding model from NVIDIA AI:
|
||||||
|
|
||||||
|
<CodeGroup>
|
||||||
|
|
||||||
|
```python main.py
|
||||||
|
import os
|
||||||
|
from embedchain import App
|
||||||
|
|
||||||
|
os.environ['NVIDIA_API_KEY'] = 'nvapi-xxxx'
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"app": {
|
||||||
|
"config": {
|
||||||
|
"id": "my-app",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"llm": {
|
||||||
|
"provider": "nvidia",
|
||||||
|
"config": {
|
||||||
|
"model": "nemotron_steerlm_8b",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"embedder": {
|
||||||
|
"provider": "nvidia",
|
||||||
|
"config": {
|
||||||
|
"model": "nvolveqa_40k",
|
||||||
|
"vector_dimension": 1024,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
app = App.from_config(config=config)
|
||||||
|
|
||||||
|
app.add("https://www.forbes.com/profile/elon-musk")
|
||||||
|
answer = app.query("What is the net worth of Elon Musk today?")
|
||||||
|
# Answer: The net worth of Elon Musk is subject to fluctuations based on the market value of his holdings in various companies.
|
||||||
|
# As of March 1, 2024, his net worth is estimated to be approximately $210 billion. However, this figure can change rapidly due to stock market fluctuations and other factors.
|
||||||
|
# Additionally, his net worth may include other assets such as real estate and art, which are not reflected in his stock portfolio.
|
||||||
|
```
|
||||||
|
</CodeGroup>
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ Embedchain comes with built-in support for various popular large language models
|
|||||||
<Card title="Mistral AI" href="#mistral-ai"></Card>
|
<Card title="Mistral AI" href="#mistral-ai"></Card>
|
||||||
<Card title="AWS Bedrock" href="#aws-bedrock"></Card>
|
<Card title="AWS Bedrock" href="#aws-bedrock"></Card>
|
||||||
<Card title="Groq" href="#groq"></Card>
|
<Card title="Groq" href="#groq"></Card>
|
||||||
|
<Card title="NVIDIA AI" href="#nvidia-ai"></Card>
|
||||||
</CardGroup>
|
</CardGroup>
|
||||||
|
|
||||||
## OpenAI
|
## OpenAI
|
||||||
@@ -730,6 +731,58 @@ app.query("Write a poem about Embedchain")
|
|||||||
```
|
```
|
||||||
</CodeGroup>
|
</CodeGroup>
|
||||||
|
|
||||||
|
## NVIDIA AI
|
||||||
|
|
||||||
|
[NVIDIA AI Foundation Endpoints](https://www.nvidia.com/en-us/ai-data-science/foundation-models/) let you quickly use NVIDIA's AI models, such as Mixtral 8x7B and Llama 2, through our API. These models are available in the [NVIDIA NGC catalog](https://catalog.ngc.nvidia.com/ai-foundation-models), fully optimized and ready to use on NVIDIA's AI platform. They are designed for high speed and easy customization, ensuring smooth performance on any accelerated setup.
|
||||||
|
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
In order to use LLMs from NVIDIA AI, create an account on [NVIDIA NGC Service](https://catalog.ngc.nvidia.com/).
|
||||||
|
|
||||||
|
Generate an API key from their dashboard and set it as the `NVIDIA_API_KEY` environment variable. Note that the API key starts with `nvapi-`.
|
||||||
|
|
||||||
|
Below is an example of how to use an LLM and an embedding model from NVIDIA AI:
|
||||||
|
|
||||||
|
<CodeGroup>
|
||||||
|
|
||||||
|
```python main.py
|
||||||
|
import os
|
||||||
|
from embedchain import App
|
||||||
|
|
||||||
|
os.environ['NVIDIA_API_KEY'] = 'nvapi-xxxx'
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"app": {
|
||||||
|
"config": {
|
||||||
|
"id": "my-app",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"llm": {
|
||||||
|
"provider": "nvidia",
|
||||||
|
"config": {
|
||||||
|
"model": "nemotron_steerlm_8b",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"embedder": {
|
||||||
|
"provider": "nvidia",
|
||||||
|
"config": {
|
||||||
|
"model": "nvolveqa_40k",
|
||||||
|
"vector_dimension": 1024,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
app = App.from_config(config=config)
|
||||||
|
|
||||||
|
app.add("https://www.forbes.com/profile/elon-musk")
|
||||||
|
answer = app.query("What is the net worth of Elon Musk today?")
|
||||||
|
# Answer: The net worth of Elon Musk is subject to fluctuations based on the market value of his holdings in various companies.
|
||||||
|
# As of March 1, 2024, his net worth is estimated to be approximately $210 billion. However, this figure can change rapidly due to stock market fluctuations and other factors.
|
||||||
|
# Additionally, his net worth may include other assets such as real estate and art, which are not reflected in his stock portfolio.
|
||||||
|
```
|
||||||
|
</CodeGroup>
|
||||||
|
|
||||||
<br/ >
|
<br/ >
|
||||||
|
|
||||||
<Snippet file="missing-llm-tip.mdx" />
|
<Snippet file="missing-llm-tip.mdx" />
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ class App(EmbedChain):
|
|||||||
if name and config:
|
if name and config:
|
||||||
raise Exception("Cannot provide both name and config. Please provide only one of them.")
|
raise Exception("Cannot provide both name and config. Please provide only one of them.")
|
||||||
|
|
||||||
logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
# logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Initialize the metadata db for the app
|
# Initialize the metadata db for the app
|
||||||
|
|||||||
@@ -36,7 +36,6 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
|
|||||||
defaults to None
|
defaults to None
|
||||||
:type collection_name: Optional[str], optional
|
:type collection_name: Optional[str], optional
|
||||||
"""
|
"""
|
||||||
self._setup_logging(log_level)
|
|
||||||
self.id = id
|
self.id = id
|
||||||
self.collect_metrics = True if (collect_metrics is True or collect_metrics is None) else False
|
self.collect_metrics = True if (collect_metrics is True or collect_metrics is None) else False
|
||||||
self.collection_name = collection_name
|
self.collection_name = collection_name
|
||||||
@@ -52,12 +51,6 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
|
|||||||
logging.warning("DEPRECATION WARNING: Please supply the collection name to the database config.")
|
logging.warning("DEPRECATION WARNING: Please supply the collection name to the database config.")
|
||||||
return
|
return
|
||||||
|
|
||||||
def _setup_logging(self, debug_level):
|
def _setup_logging(self, log_level):
|
||||||
level = logging.WARNING # Default level
|
logging.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=log_level)
|
||||||
if debug_level is not None:
|
|
||||||
level = getattr(logging, debug_level.upper(), None)
|
|
||||||
if not isinstance(level, int):
|
|
||||||
raise ValueError(f"Invalid log level: {debug_level}")
|
|
||||||
|
|
||||||
logging.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=level)
|
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|||||||
26
embedchain/embedder/nvidia.py
Normal file
26
embedchain/embedder/nvidia.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
|
||||||
|
|
||||||
|
from embedchain.config import BaseEmbedderConfig
|
||||||
|
from embedchain.embedder.base import BaseEmbedder
|
||||||
|
from embedchain.models import VectorDimensions
|
||||||
|
|
||||||
|
|
||||||
|
class NvidiaEmbedder(BaseEmbedder):
    """Embedder that generates vectors with an NVIDIA AI embedding model.

    The ``NVIDIA_API_KEY`` environment variable must be set before an
    instance is created.
    """

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        """Set up the NVIDIA embedding function and the vector dimension.

        :param config: Embedder configuration, forwarded to ``BaseEmbedder``.
            A falsy ``model`` falls back to ``"nvolveqa_40k"`` and a falsy
            ``vector_dimension`` falls back to ``VectorDimensions.NVIDIA_AI``.
        :type config: Optional[BaseEmbedderConfig], optional
        :raises ValueError: If ``NVIDIA_API_KEY`` is not present in the environment.
        """
        # Fail fast, before any base-class setup, when the API key is missing.
        if os.environ.get("NVIDIA_API_KEY") is None:
            raise ValueError("NVIDIA_API_KEY environment variable must be set")

        super().__init__(config=config)

        model_name = self.config.model or "nvolveqa_40k"
        logging.info(f"Using NVIDIA embedding model: {model_name}")

        # Wrap the langchain embeddings object into the callable the base class expects.
        nvidia_embeddings = NVIDIAEmbeddings(model=model_name)
        self.set_embedding_fn(embedding_fn=BaseEmbedder._langchain_default_concept(nvidia_embeddings))

        dimension = self.config.vector_dimension or VectorDimensions.NVIDIA_AI.value
        self.set_vector_dimension(vector_dimension=dimension)
|
||||||
@@ -24,6 +24,7 @@ class LlmFactory:
|
|||||||
"aws_bedrock": "embedchain.llm.aws_bedrock.AWSBedrockLlm",
|
"aws_bedrock": "embedchain.llm.aws_bedrock.AWSBedrockLlm",
|
||||||
"mistralai": "embedchain.llm.mistralai.MistralAILlm",
|
"mistralai": "embedchain.llm.mistralai.MistralAILlm",
|
||||||
"groq": "embedchain.llm.groq.GroqLlm",
|
"groq": "embedchain.llm.groq.GroqLlm",
|
||||||
|
"nvidia": "embedchain.llm.nvidia.NvidiaLlm",
|
||||||
}
|
}
|
||||||
provider_to_config_class = {
|
provider_to_config_class = {
|
||||||
"embedchain": "embedchain.config.llm.base.BaseLlmConfig",
|
"embedchain": "embedchain.config.llm.base.BaseLlmConfig",
|
||||||
@@ -54,13 +55,14 @@ class EmbedderFactory:
|
|||||||
"vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder",
|
"vertexai": "embedchain.embedder.vertexai.VertexAIEmbedder",
|
||||||
"google": "embedchain.embedder.google.GoogleAIEmbedder",
|
"google": "embedchain.embedder.google.GoogleAIEmbedder",
|
||||||
"mistralai": "embedchain.embedder.mistralai.MistralAIEmbedder",
|
"mistralai": "embedchain.embedder.mistralai.MistralAIEmbedder",
|
||||||
|
"nvidia": "embedchain.embedder.nvidia.NvidiaEmbedder",
|
||||||
}
|
}
|
||||||
provider_to_config_class = {
|
provider_to_config_class = {
|
||||||
"azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
"azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
||||||
"openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
|
||||||
"gpt4all": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
|
||||||
"google": "embedchain.config.embedder.google.GoogleAIEmbedderConfig",
|
"google": "embedchain.config.embedder.google.GoogleAIEmbedderConfig",
|
||||||
|
"gpt4all": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
||||||
"huggingface": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
"huggingface": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
||||||
|
"openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
47
embedchain/llm/nvidia.py
Normal file
47
embedchain/llm/nvidia.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
import os
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import CallbackManager
|
||||||
|
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||||
|
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"NVIDIA AI endpoints requires extra dependencies. Install with `pip install langchain-nvidia-ai-endpoints`"
|
||||||
|
) from None
|
||||||
|
|
||||||
|
from embedchain.config import BaseLlmConfig
|
||||||
|
from embedchain.helpers.json_serializable import register_deserializable
|
||||||
|
from embedchain.llm.base import BaseLlm
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
class NvidiaLlm(BaseLlm):
    """LLM implementation that answers prompts through ``ChatNVIDIA``.

    The ``NVIDIA_API_KEY`` environment variable must be set before an
    instance is created.
    """

    def __init__(self, config: Optional[BaseLlmConfig] = None):
        """Validate the environment and initialise the base LLM.

        :param config: LLM configuration, forwarded to ``BaseLlm``.
        :type config: Optional[BaseLlmConfig], optional
        :raises ValueError: If ``NVIDIA_API_KEY`` is not present in the environment.
        """
        if os.environ.get("NVIDIA_API_KEY") is None:
            raise ValueError("NVIDIA_API_KEY environment variable must be set")

        super().__init__(config=config)

    def get_llm_model_answer(self, prompt):
        """Return the model's answer to ``prompt`` using this instance's config."""
        return self._get_answer(prompt=prompt, config=self.config)

    @staticmethod
    def _get_answer(prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
        """Build a ``ChatNVIDIA`` client from ``config`` and invoke it once.

        :param prompt: Prompt text sent to the model.
        :type prompt: str
        :param config: Source of ``model``, ``stream``, ``system_prompt``,
            ``temperature``, ``top_p`` and ``model_kwargs["labels"]``.
        :type config: BaseLlmConfig
        :return: The ``content`` attribute of the invocation result.
        :rtype: Union[str, Iterable]
        """
        # Streaming configs echo tokens as they arrive; otherwise use the plain stdout handler.
        if config.stream:
            handlers = [StreamingStdOutCallbackHandler()]
        else:
            handlers = [StdOutCallbackHandler()]

        labels = (config.model_kwargs or {}).get("labels", None)

        # Optional settings are forwarded only when truthy.
        # NOTE(review): this means an explicit 0 for temperature/top_p is not
        # forwarded, and an empty (non-None) ``labels`` is left out of the
        # constructor but still passed to ``invoke`` below — confirm intent.
        optional_settings = (
            ("system_prompt", config.system_prompt),
            ("temperature", config.temperature),
            ("top_p", config.top_p),
            ("labels", labels),
        )
        params = {"model": config.model}
        params.update({key: value for key, value in optional_settings if value})

        llm = ChatNVIDIA(**params, callback_manager=CallbackManager(handlers))

        if labels is None:
            response = llm.invoke(prompt)
        else:
            response = llm.invoke(prompt, labels=labels)
        return response.content
|
||||||
@@ -9,3 +9,4 @@ class VectorDimensions(Enum):
|
|||||||
HUGGING_FACE = 384
|
HUGGING_FACE = 384
|
||||||
GOOGLE_AI = 768
|
GOOGLE_AI = 768
|
||||||
MISTRAL_AI = 1024
|
MISTRAL_AI = 1024
|
||||||
|
NVIDIA_AI = 1024
|
||||||
|
|||||||
@@ -17,8 +17,6 @@ from embedchain.models.data_type import DataType
|
|||||||
from embedchain.telemetry.posthog import AnonymousTelemetry
|
from embedchain.telemetry.posthog import AnonymousTelemetry
|
||||||
from embedchain.utils.misc import detect_datatype
|
from embedchain.utils.misc import detect_datatype
|
||||||
|
|
||||||
logging.basicConfig(level=logging.WARN)
|
|
||||||
|
|
||||||
# Set up the user directory if it doesn't exist already
|
# Set up the user directory if it doesn't exist already
|
||||||
Client.setup()
|
Client.setup()
|
||||||
|
|
||||||
@@ -33,7 +31,7 @@ class OpenAIAssistant:
|
|||||||
model="gpt-4-1106-preview",
|
model="gpt-4-1106-preview",
|
||||||
data_sources=None,
|
data_sources=None,
|
||||||
assistant_id=None,
|
assistant_id=None,
|
||||||
log_level=logging.WARN,
|
log_level=logging.INFO,
|
||||||
collect_metrics=True,
|
collect_metrics=True,
|
||||||
):
|
):
|
||||||
self.name = name or "OpenAI Assistant"
|
self.name = name or "OpenAI Assistant"
|
||||||
@@ -156,10 +154,9 @@ class AIAssistant:
|
|||||||
assistant_id=None,
|
assistant_id=None,
|
||||||
thread_id=None,
|
thread_id=None,
|
||||||
data_sources=None,
|
data_sources=None,
|
||||||
log_level=logging.WARN,
|
log_level=logging.INFO,
|
||||||
collect_metrics=True,
|
collect_metrics=True,
|
||||||
):
|
):
|
||||||
logging.basicConfig(level=log_level)
|
|
||||||
|
|
||||||
self.name = name or "AI Assistant"
|
self.name = name or "AI Assistant"
|
||||||
self.data_sources = data_sources or []
|
self.data_sources = data_sources or []
|
||||||
|
|||||||
@@ -12,8 +12,6 @@ HOME_DIR = str(Path.home())
|
|||||||
CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
|
CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
|
||||||
CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
|
CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class AnonymousTelemetry:
|
class AnonymousTelemetry:
|
||||||
def __init__(self, host="https://app.posthog.com", enabled=True):
|
def __init__(self, host="https://app.posthog.com", enabled=True):
|
||||||
@@ -63,4 +61,4 @@ class AnonymousTelemetry:
|
|||||||
try:
|
try:
|
||||||
self.posthog.capture(self.user_id, event_name, properties)
|
self.posthog.capture(self.user_id, event_name, properties)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(f"Failed to send telemetry {event_name=}")
|
logging.exception(f"Failed to send telemetry {event_name=}")
|
||||||
|
|||||||
@@ -407,6 +407,7 @@ def validate_config(config_data):
|
|||||||
"mistralai",
|
"mistralai",
|
||||||
"vllm",
|
"vllm",
|
||||||
"groq",
|
"groq",
|
||||||
|
"nvidia",
|
||||||
),
|
),
|
||||||
Optional("config"): {
|
Optional("config"): {
|
||||||
Optional("model"): str,
|
Optional("model"): str,
|
||||||
@@ -443,6 +444,7 @@ def validate_config(config_data):
|
|||||||
"azure_openai",
|
"azure_openai",
|
||||||
"google",
|
"google",
|
||||||
"mistralai",
|
"mistralai",
|
||||||
|
"nvidia",
|
||||||
),
|
),
|
||||||
Optional("config"): {
|
Optional("config"): {
|
||||||
Optional("model"): Optional(str),
|
Optional("model"): Optional(str),
|
||||||
@@ -462,6 +464,7 @@ def validate_config(config_data):
|
|||||||
"azure_openai",
|
"azure_openai",
|
||||||
"google",
|
"google",
|
||||||
"mistralai",
|
"mistralai",
|
||||||
|
"nvidia",
|
||||||
),
|
),
|
||||||
Optional("config"): {
|
Optional("config"): {
|
||||||
Optional("model"): str,
|
Optional("model"): str,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "embedchain"
|
name = "embedchain"
|
||||||
version = "0.1.88"
|
version = "0.1.89"
|
||||||
description = "Simplest open source retrieval(RAG) framework"
|
description = "Simplest open source retrieval(RAG) framework"
|
||||||
authors = [
|
authors = [
|
||||||
"Taranjeet Singh <taranjeet@embedchain.ai>",
|
"Taranjeet Singh <taranjeet@embedchain.ai>",
|
||||||
|
|||||||
Reference in New Issue
Block a user