[Feature] Add support for AWS Bedrock LLM (#1189)

Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
Deven Patel
2024-01-21 14:09:08 +05:30
committed by GitHub
parent 751a3a4bd1
commit 069d265338
8 changed files with 226 additions and 8 deletions

View File

@@ -200,9 +200,10 @@ Alright, let's dive into what each key means in the yaml config above:
- `stream` (Boolean): Controls if the response is streamed back to the user (set to false).
- `prompt` (String): A prompt for the model to follow when generating responses, requires `$context` and `$query` variables.
- `system_prompt` (String): A system prompt for the model to follow when generating responses, in this case, it's set to the style of William Shakespeare.
- `stream` (Boolean): Controls if the response is streamed back to the user (set to false).
- `stream` (Boolean): Controls if the response is streamed back to the user (set to false).
- `number_documents` (Integer): Number of documents to pull from the vectordb as context, defaults to 1
- `api_key` (String): The API key for the language model.
- `model_kwargs` (Dict): Keyword arguments to pass to the language model. Used for the `aws_bedrock` provider, since each Bedrock model requires different arguments.
3. `vectordb` Section:
- `provider` (String): The provider for the vector database, set to 'chroma'. You can find the full list of vector database providers in [our docs](/components/vector-databases).
- `config`:

View File

@@ -21,6 +21,7 @@ Embedchain comes with built-in support for various popular large language models
<Card title="Llama2" href="#llama2"></Card>
<Card title="Vertex AI" href="#vertex-ai"></Card>
<Card title="Mistral AI" href="#mistral-ai"></Card>
<Card title="AWS Bedrock" href="#aws-bedrock"></Card>
</CardGroup>
## OpenAI
@@ -628,10 +629,7 @@ Obtain the Mistral AI api key from their [console](https://console.mistral.ai/).
<CodeGroup>
```python main.py
import os
from embedchain import App
```python main.py
os.environ["MISTRAL_API_KEY"] = "xxx"
app = App.from_config(config_path="config.yaml")
@@ -663,5 +661,45 @@ embedder:
```
</CodeGroup>
## AWS Bedrock
### Setup
- Before using the AWS Bedrock LLM, make sure you have been granted access to the appropriate model in the [Bedrock Console](https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/modelaccess).
- You will also need `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` to authenticate the API with AWS. You can find these in your [AWS Console](https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/users).
### Usage
<CodeGroup>
```python main.py
import os
from embedchain import App
os.environ["AWS_ACCESS_KEY_ID"] = "xxx"
os.environ["AWS_SECRET_ACCESS_KEY"] = "xxx"
app = App.from_config(config_path="config.yaml")
```
```yaml config.yaml
llm:
provider: aws_bedrock
config:
model: amazon.titan-text-express-v1
# check notes below for model_kwargs
model_kwargs:
temperature: 0.5
topP: 1
maxTokenCount: 1000
```
</CodeGroup>
<br />
<Note>
The model arguments are different for each provider. Please refer to the [AWS Bedrock Documentation](https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers) to find the appropriate arguments for your model.
</Note>
<br />
<Snippet file="missing-llm-tip.mdx" />

View File

@@ -21,6 +21,7 @@ class LlmFactory:
"openai": "embedchain.llm.openai.OpenAILlm",
"vertexai": "embedchain.llm.vertex_ai.VertexAILlm",
"google": "embedchain.llm.google.GoogleLlm",
"aws_bedrock": "embedchain.llm.aws_bedrock.AWSBedrockLlm",
"mistralai": "embedchain.llm.mistralai.MistralAILlm",
}
provider_to_config_class = {

View File

@@ -0,0 +1,48 @@
from typing import Optional
from langchain.llms import Bedrock
from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
@register_deserializable
class AWSBedrockLlm(BaseLlm):
    """LLM backend that sends prompts to AWS Bedrock via LangChain's ``Bedrock`` wrapper."""

    def __init__(self, config: Optional[BaseLlmConfig] = None):
        """Initialize the LLM.

        :param config: Optional LLM configuration, forwarded unchanged to ``BaseLlm``.
        """
        super().__init__(config)

    def get_llm_model_answer(self, prompt) -> str:
        """Return the model's answer for ``prompt`` using this instance's config.

        :param prompt: Prompt text to send to the model.
        :return: Whatever ``_get_answer`` returns for the prompt.
        """
        return self._get_answer(prompt, self.config)

    def _get_answer(self, prompt: str, config: BaseLlmConfig) -> str:
        """Build a ``Bedrock`` LLM from ``config`` and run ``prompt`` through it.

        The AWS region is taken from the ``AWS_REGION`` environment variable,
        then ``AWS_DEFAULT_REGION``, and falls back to ``us-west-2`` when
        neither is set.

        :param prompt: Prompt text to send to the model.
        :param config: Configuration supplying ``model``, ``model_kwargs``,
            ``temperature`` and ``stream``.
        :return: The result of calling the ``Bedrock`` LLM with ``prompt``.
        :raises ModuleNotFoundError: If ``boto3`` is not installed.
        """
        import os

        try:
            import boto3
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "The required dependencies for AWSBedrock are not installed. "
                'Please install with `pip install --upgrade "embedchain[aws-bedrock]"`'
            ) from None

        # Honour the standard AWS region variables instead of pinning a single
        # region; keep the previous value as the fallback for compatibility.
        region_name = os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION") or "us-west-2"
        self.boto_client = boto3.client("bedrock-runtime", region_name)

        kwargs = {
            "model_id": config.model or "amazon.titan-text-express-v1",
            "client": self.boto_client,
            # Model arguments differ per Bedrock model; when none are given,
            # only the temperature from the config is passed.
            "model_kwargs": config.model_kwargs
            or {
                "temperature": config.temperature,
            },
        }

        if config.stream:
            from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

            callbacks = [StreamingStdOutCallbackHandler()]
            llm = Bedrock(**kwargs, streaming=config.stream, callbacks=callbacks)
        else:
            llm = Bedrock(**kwargs)

        return llm(prompt)

View File

@@ -406,6 +406,7 @@ def validate_config(config_data):
"llama2",
"vertexai",
"google",
"aws_bedrock",
"mistralai",
),
Optional("config"): {
@@ -423,6 +424,7 @@ def validate_config(config_data):
Optional("query_type"): str,
Optional("api_key"): str,
Optional("endpoint"): str,
Optional("model_kwargs"): dict,
},
},
Optional("vectordb"): {

74
poetry.lock generated
View File

@@ -383,6 +383,47 @@ files = [
{file = "blinker-1.6.3.tar.gz", hash = "sha256:152090d27c1c5c722ee7e48504b02d76502811ce02e1523553b4cf8c8b3d3a8d"},
]
[[package]]
name = "boto3"
version = "1.34.22"
description = "The AWS SDK for Python"
optional = true
python-versions = ">= 3.8"
files = [
{file = "boto3-1.34.22-py3-none-any.whl", hash = "sha256:5909cd1393143576265c692e908a9ae495492c04a0ffd4bae8578adc2e44729e"},
{file = "boto3-1.34.22.tar.gz", hash = "sha256:a98c0b86f6044ff8314cc2361e1ef574d674318313ab5606ccb4a6651c7a3f8c"},
]
[package.dependencies]
botocore = ">=1.34.22,<1.35.0"
jmespath = ">=0.7.1,<2.0.0"
s3transfer = ">=0.10.0,<0.11.0"
[package.extras]
crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
[[package]]
name = "botocore"
version = "1.34.22"
description = "Low-level, data-driven core of boto 3."
optional = true
python-versions = ">= 3.8"
files = [
{file = "botocore-1.34.22-py3-none-any.whl", hash = "sha256:e5f7775975b9213507fbcf846a96b7a2aec2a44fc12a44585197b014a4ab0889"},
{file = "botocore-1.34.22.tar.gz", hash = "sha256:c47ba4286c576150d1b6ca6df69a87b5deff3d23bd84da8bcf8431ebac3c40ba"},
]
[package.dependencies]
jmespath = ">=0.7.1,<2.0.0"
python-dateutil = ">=2.1,<3.0.0"
urllib3 = [
{version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""},
{version = ">=1.25.4,<2.1", markers = "python_version >= \"3.10\""},
]
[package.extras]
crt = ["awscrt (==0.19.19)"]
[[package]]
name = "brotli"
version = "1.1.0"
@@ -2810,6 +2851,17 @@ MarkupSafe = ">=2.0"
[package.extras]
i18n = ["Babel (>=2.7)"]
[[package]]
name = "jmespath"
version = "1.0.1"
description = "JSON Matching Expressions"
optional = true
python-versions = ">=3.7"
files = [
{file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"},
{file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"},
]
[[package]]
name = "joblib"
version = "1.3.2"
@@ -4211,9 +4263,9 @@ files = [
[package.dependencies]
numpy = [
{version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
{version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""},
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
{version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""},
{version = ">=1.23.5", markers = "python_version >= \"3.11\""},
]
@@ -6091,6 +6143,23 @@ files = [
{file = "ruff-0.1.11.tar.gz", hash = "sha256:f9d4d88cb6eeb4dfe20f9f0519bd2eaba8119bde87c3d5065c541dbae2b5a2cb"},
]
[[package]]
name = "s3transfer"
version = "0.10.0"
description = "An Amazon S3 Transfer Manager"
optional = true
python-versions = ">= 3.8"
files = [
{file = "s3transfer-0.10.0-py3-none-any.whl", hash = "sha256:3cdb40f5cfa6966e812209d0994f2a4709b561c88e90cf00c2696d2df4e56b2e"},
{file = "s3transfer-0.10.0.tar.gz", hash = "sha256:d0c8bbf672d5eebbe4e57945e23b972d963f07d82f661cabf678a5c88831595b"},
]
[package.dependencies]
botocore = ">=1.33.2,<2.0a.0"
[package.extras]
crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"]
[[package]]
name = "safetensors"
version = "0.4.0"
@@ -8213,6 +8282,7 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link
testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"]
[extras]
aws-bedrock = ["boto3"]
cohere = ["cohere"]
dataloaders = ["docx2txt", "duckduckgo-search", "pytube", "sentence-transformers", "unstructured", "youtube-transcript-api"]
discord = ["discord"]
@@ -8246,4 +8316,4 @@ youtube = ["youtube-transcript-api", "yt_dlp"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<3.12"
content-hash = "cb0da55af7c61300bb321770ed319c900b6b3ba3865421d63eb9120beb73d06c"
content-hash = "bbcf32e87c0784d031fb6cf9bd89655375839da0660b8feb2026ffdd971623d7"

View File

@@ -149,6 +149,7 @@ google-auth-oauthlib = { version = "^1.2.0", optional = true }
google-auth = { version = "^2.25.2", optional = true }
google-auth-httplib2 = { version = "^0.2.0", optional = true }
google-api-core = { version = "^2.15.0", optional = true }
boto3 = { version = "^1.34.20", optional = true }
langchain-mistralai = { version = "^0.0.3", optional = true }
[tool.poetry.group.dev.dependencies]
@@ -215,6 +216,7 @@ rss_feed = [
google = ["google-generativeai"]
modal = ["modal"]
dropbox = ["dropbox"]
aws_bedrock = ["boto3"]
mistralai = ["langchain-mistralai"]
[tool.poetry.group.docs.dependencies]

View File

@@ -0,0 +1,56 @@
import pytest
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from embedchain.config import BaseLlmConfig
from embedchain.llm.aws_bedrock import AWSBedrockLlm
@pytest.fixture
def config(monkeypatch):
    """Provide a Bedrock ``BaseLlmConfig`` with dummy credentials in the environment.

    No explicit cleanup is performed: ``monkeypatch`` undoes every ``setenv``
    call itself when the fixture is torn down.
    """
    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test_access_key_id")
    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test_secret_access_key")
    monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
    return BaseLlmConfig(
        model="amazon.titan-text-express-v1",
        model_kwargs={
            "temperature": 0.5,
            "topP": 1,
            "maxTokenCount": 1000,
        },
    )
def test_get_llm_model_answer(config, mocker):
    """``get_llm_model_answer`` forwards the prompt and config to ``_get_answer``."""
    query = "Test query"
    get_answer_stub = mocker.patch.object(AWSBedrockLlm, "_get_answer", return_value="Test answer")

    result = AWSBedrockLlm(config).get_llm_model_answer(query)

    assert result == "Test answer"
    get_answer_stub.assert_called_once_with(query, config)
def test_get_llm_model_answer_empty_prompt(config, mocker):
    """An empty prompt is still forwarded to ``_get_answer`` unchanged."""
    empty_query = ""
    get_answer_stub = mocker.patch.object(AWSBedrockLlm, "_get_answer", return_value="Test answer")

    result = AWSBedrockLlm(config).get_llm_model_answer(empty_query)

    assert result == "Test answer"
    get_answer_stub.assert_called_once_with(empty_query, config)
def test_get_llm_model_answer_with_streaming(config, mocker):
    """With ``stream`` enabled, ``Bedrock`` is built once with a stdout streaming callback."""
    config.stream = True
    bedrock_cls = mocker.patch("embedchain.llm.aws_bedrock.Bedrock")

    AWSBedrockLlm(config).get_llm_model_answer("Test query")

    bedrock_cls.assert_called_once()
    # Collect the first callback handed to each recorded constructor call.
    first_handlers = [recorded.kwargs["callbacks"][0] for recorded in bedrock_cls.call_args_list]
    assert any(isinstance(handler, StreamingStdOutCallbackHandler) for handler in first_handlers)