diff --git a/docs/api-reference/advanced/configuration.mdx b/docs/api-reference/advanced/configuration.mdx
index 365a223e..6c0cae3d 100644
--- a/docs/api-reference/advanced/configuration.mdx
+++ b/docs/api-reference/advanced/configuration.mdx
@@ -200,9 +200,10 @@ Alright, let's dive into what each key means in the yaml config above:
- `stream` (Boolean): Controls if the response is streamed back to the user (set to false).
- `prompt` (String): A prompt for the model to follow when generating responses, requires `$context` and `$query` variables.
- `system_prompt` (String): A system prompt for the model to follow when generating responses, in this case, it's set to the style of William Shakespeare.
- - `stream` (Boolean): Controls if the response is streamed back to the user (set to false).
+ - `stream` (Boolean): Controls if the response is streamed back to the user (set to false).
- `number_documents` (Integer): Number of documents to pull from the vectordb as context, defaults to 1
- `api_key` (String): The API key for the language model.
+   - `model_kwargs` (Dict): Keyword arguments to pass to the language model. Used by the `aws_bedrock` provider, since it requires different arguments for each model.
3. `vectordb` Section:
- `provider` (String): The provider for the vector database, set to 'chroma'. You can find the full list of vector database providers in [our docs](/components/vector-databases).
- `config`:
diff --git a/docs/components/llms.mdx b/docs/components/llms.mdx
index b9b4b138..9adbc6ab 100644
--- a/docs/components/llms.mdx
+++ b/docs/components/llms.mdx
@@ -21,6 +21,7 @@ Embedchain comes with built-in support for various popular large language models
+
## OpenAI
@@ -627,11 +628,8 @@ llm:
Obtain the Mistral AI api key from their [console](https://console.mistral.ai/).
-
-```python main.py
-import os
-from embedchain import App
-
+
+ ```python main.py
os.environ["MISTRAL_API_KEY"] = "xxx"
app = App.from_config(config_path="config.yaml")
@@ -663,5 +661,45 @@ embedder:
```
+
+## AWS Bedrock
+
+### Setup
+- Before using the AWS Bedrock LLM, make sure you have the appropriate model access from [Bedrock Console](https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/modelaccess).
+- You will also need `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` to authenticate with AWS. You can find these in your [AWS Console](https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/users).
+
+
+### Usage
+
+
+
+```python main.py
+import os
+from embedchain import App
+
+os.environ["AWS_ACCESS_KEY_ID"] = "xxx"
+os.environ["AWS_SECRET_ACCESS_KEY"] = "xxx"
+
+app = App.from_config(config_path="config.yaml")
+```
+
+```yaml config.yaml
+llm:
+ provider: aws_bedrock
+ config:
+ model: amazon.titan-text-express-v1
+ # check notes below for model_kwargs
+ model_kwargs:
+ temperature: 0.5
+ topP: 1
+ maxTokenCount: 1000
+```
+
+
+
+
+    The model arguments are different for each provider. Please refer to the [AWS Bedrock Documentation](https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers) to find the appropriate arguments for your model.
+
+
diff --git a/embedchain/factory.py b/embedchain/factory.py
index 04b20c56..59570d3a 100644
--- a/embedchain/factory.py
+++ b/embedchain/factory.py
@@ -21,6 +21,7 @@ class LlmFactory:
"openai": "embedchain.llm.openai.OpenAILlm",
"vertexai": "embedchain.llm.vertex_ai.VertexAILlm",
"google": "embedchain.llm.google.GoogleLlm",
+ "aws_bedrock": "embedchain.llm.aws_bedrock.AWSBedrockLlm",
"mistralai": "embedchain.llm.mistralai.MistralAILlm",
}
provider_to_config_class = {
diff --git a/embedchain/llm/aws_bedrock.py b/embedchain/llm/aws_bedrock.py
new file mode 100644
index 00000000..e1f2521a
--- /dev/null
+++ b/embedchain/llm/aws_bedrock.py
@@ -0,0 +1,58 @@
+from typing import Optional
+
+from langchain.llms import Bedrock
+
+from embedchain.config import BaseLlmConfig
+from embedchain.helpers.json_serializable import register_deserializable
+from embedchain.llm.base import BaseLlm
+
+
+@register_deserializable
+class AWSBedrockLlm(BaseLlm):
+    """LLM provider backed by AWS Bedrock via LangChain's ``Bedrock`` wrapper.
+
+    Authentication relies on boto3's standard credential resolution
+    (``AWS_ACCESS_KEY_ID`` / ``AWS_SECRET_ACCESS_KEY`` in the environment).
+    """
+
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+
+    def get_llm_model_answer(self, prompt) -> str:
+        """Return the Bedrock model's answer for ``prompt``."""
+        response = self._get_answer(prompt, self.config)
+        return response
+
+    def _get_answer(self, prompt: str, config: BaseLlmConfig) -> str:
+        try:
+            import boto3
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError(
+                "The required dependencies for AWSBedrock are not installed. "
+                'Please install with `pip install --upgrade "embedchain[aws-bedrock]"`'
+            ) from None
+
+        # NOTE(review): region is hardcoded to us-west-2; consider honoring
+        # AWS_REGION / AWS_DEFAULT_REGION so users in other regions can connect.
+        self.boto_client = boto3.client("bedrock-runtime", "us-west-2")
+
+        kwargs = {
+            "model_id": config.model or "amazon.titan-text-express-v1",
+            "client": self.boto_client,
+            # model_kwargs differ per Bedrock model family; see AWS Bedrock docs.
+            "model_kwargs": config.model_kwargs
+            or {
+                "temperature": config.temperature,
+            },
+        }
+
+        if config.stream:
+            from langchain.callbacks.streaming_stdout import \
+                StreamingStdOutCallbackHandler
+
+            callbacks = [StreamingStdOutCallbackHandler()]
+            llm = Bedrock(**kwargs, streaming=config.stream, callbacks=callbacks)
+        else:
+            llm = Bedrock(**kwargs)
+
+        return llm(prompt)
diff --git a/embedchain/utils/misc.py b/embedchain/utils/misc.py
index 6cc13ade..cb7ff090 100644
--- a/embedchain/utils/misc.py
+++ b/embedchain/utils/misc.py
@@ -406,6 +406,7 @@ def validate_config(config_data):
"llama2",
"vertexai",
"google",
+ "aws_bedrock",
"mistralai",
),
Optional("config"): {
@@ -423,6 +424,7 @@ def validate_config(config_data):
Optional("query_type"): str,
Optional("api_key"): str,
Optional("endpoint"): str,
+ Optional("model_kwargs"): dict,
},
},
Optional("vectordb"): {
diff --git a/poetry.lock b/poetry.lock
index b2599536..8935ff69 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -383,6 +383,47 @@ files = [
{file = "blinker-1.6.3.tar.gz", hash = "sha256:152090d27c1c5c722ee7e48504b02d76502811ce02e1523553b4cf8c8b3d3a8d"},
]
+[[package]]
+name = "boto3"
+version = "1.34.22"
+description = "The AWS SDK for Python"
+optional = true
+python-versions = ">= 3.8"
+files = [
+ {file = "boto3-1.34.22-py3-none-any.whl", hash = "sha256:5909cd1393143576265c692e908a9ae495492c04a0ffd4bae8578adc2e44729e"},
+ {file = "boto3-1.34.22.tar.gz", hash = "sha256:a98c0b86f6044ff8314cc2361e1ef574d674318313ab5606ccb4a6651c7a3f8c"},
+]
+
+[package.dependencies]
+botocore = ">=1.34.22,<1.35.0"
+jmespath = ">=0.7.1,<2.0.0"
+s3transfer = ">=0.10.0,<0.11.0"
+
+[package.extras]
+crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
+
+[[package]]
+name = "botocore"
+version = "1.34.22"
+description = "Low-level, data-driven core of boto 3."
+optional = true
+python-versions = ">= 3.8"
+files = [
+ {file = "botocore-1.34.22-py3-none-any.whl", hash = "sha256:e5f7775975b9213507fbcf846a96b7a2aec2a44fc12a44585197b014a4ab0889"},
+ {file = "botocore-1.34.22.tar.gz", hash = "sha256:c47ba4286c576150d1b6ca6df69a87b5deff3d23bd84da8bcf8431ebac3c40ba"},
+]
+
+[package.dependencies]
+jmespath = ">=0.7.1,<2.0.0"
+python-dateutil = ">=2.1,<3.0.0"
+urllib3 = [
+ {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""},
+ {version = ">=1.25.4,<2.1", markers = "python_version >= \"3.10\""},
+]
+
+[package.extras]
+crt = ["awscrt (==0.19.19)"]
+
[[package]]
name = "brotli"
version = "1.1.0"
@@ -2810,6 +2851,17 @@ MarkupSafe = ">=2.0"
[package.extras]
i18n = ["Babel (>=2.7)"]
+[[package]]
+name = "jmespath"
+version = "1.0.1"
+description = "JSON Matching Expressions"
+optional = true
+python-versions = ">=3.7"
+files = [
+ {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"},
+ {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"},
+]
+
[[package]]
name = "joblib"
version = "1.3.2"
@@ -4211,9 +4263,9 @@ files = [
[package.dependencies]
numpy = [
{version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
+ {version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""},
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
- {version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""},
{version = ">=1.23.5", markers = "python_version >= \"3.11\""},
]
@@ -6091,6 +6143,23 @@ files = [
{file = "ruff-0.1.11.tar.gz", hash = "sha256:f9d4d88cb6eeb4dfe20f9f0519bd2eaba8119bde87c3d5065c541dbae2b5a2cb"},
]
+[[package]]
+name = "s3transfer"
+version = "0.10.0"
+description = "An Amazon S3 Transfer Manager"
+optional = true
+python-versions = ">= 3.8"
+files = [
+ {file = "s3transfer-0.10.0-py3-none-any.whl", hash = "sha256:3cdb40f5cfa6966e812209d0994f2a4709b561c88e90cf00c2696d2df4e56b2e"},
+ {file = "s3transfer-0.10.0.tar.gz", hash = "sha256:d0c8bbf672d5eebbe4e57945e23b972d963f07d82f661cabf678a5c88831595b"},
+]
+
+[package.dependencies]
+botocore = ">=1.33.2,<2.0a.0"
+
+[package.extras]
+crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"]
+
[[package]]
name = "safetensors"
version = "0.4.0"
@@ -8213,6 +8282,7 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link
testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"]
[extras]
+aws-bedrock = ["boto3"]
cohere = ["cohere"]
dataloaders = ["docx2txt", "duckduckgo-search", "pytube", "sentence-transformers", "unstructured", "youtube-transcript-api"]
discord = ["discord"]
@@ -8246,4 +8316,4 @@ youtube = ["youtube-transcript-api", "yt_dlp"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<3.12"
-content-hash = "cb0da55af7c61300bb321770ed319c900b6b3ba3865421d63eb9120beb73d06c"
+content-hash = "bbcf32e87c0784d031fb6cf9bd89655375839da0660b8feb2026ffdd971623d7"
diff --git a/pyproject.toml b/pyproject.toml
index 6613f080..8b81e5ac 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -149,6 +149,7 @@ google-auth-oauthlib = { version = "^1.2.0", optional = true }
google-auth = { version = "^2.25.2", optional = true }
google-auth-httplib2 = { version = "^0.2.0", optional = true }
google-api-core = { version = "^2.15.0", optional = true }
+boto3 = { version = "^1.34.20", optional = true }
langchain-mistralai = { version = "^0.0.3", optional = true }
[tool.poetry.group.dev.dependencies]
@@ -215,6 +216,7 @@ rss_feed = [
google = ["google-generativeai"]
modal = ["modal"]
dropbox = ["dropbox"]
+aws_bedrock = ["boto3"]
mistralai = ["langchain-mistralai"]
[tool.poetry.group.docs.dependencies]
diff --git a/tests/llm/test_aws_bedrock.py b/tests/llm/test_aws_bedrock.py
new file mode 100644
index 00000000..42cde396
--- /dev/null
+++ b/tests/llm/test_aws_bedrock.py
@@ -0,0 +1,57 @@
+import pytest
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+
+from embedchain.config import BaseLlmConfig
+from embedchain.llm.aws_bedrock import AWSBedrockLlm
+
+
+@pytest.fixture
+def config(monkeypatch):
+    """Provide a Bedrock LLM config with fake AWS credentials in the env.
+
+    monkeypatch restores the environment automatically at teardown, so no
+    explicit delenv cleanup is needed.
+    """
+    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test_access_key_id")
+    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test_secret_access_key")
+    monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
+    return BaseLlmConfig(
+        model="amazon.titan-text-express-v1",
+        model_kwargs={
+            "temperature": 0.5,
+            "topP": 1,
+            "maxTokenCount": 1000,
+        },
+    )
+
+
+def test_get_llm_model_answer(config, mocker):
+    mocked_get_answer = mocker.patch("embedchain.llm.aws_bedrock.AWSBedrockLlm._get_answer", return_value="Test answer")
+
+    llm = AWSBedrockLlm(config)
+    answer = llm.get_llm_model_answer("Test query")
+
+    assert answer == "Test answer"
+    mocked_get_answer.assert_called_once_with("Test query", config)
+
+
+def test_get_llm_model_answer_empty_prompt(config, mocker):
+    mocked_get_answer = mocker.patch("embedchain.llm.aws_bedrock.AWSBedrockLlm._get_answer", return_value="Test answer")
+
+    llm = AWSBedrockLlm(config)
+    answer = llm.get_llm_model_answer("")
+
+    assert answer == "Test answer"
+    mocked_get_answer.assert_called_once_with("", config)
+
+
+def test_get_llm_model_answer_with_streaming(config, mocker):
+    config.stream = True
+    mocked_bedrock_chat = mocker.patch("embedchain.llm.aws_bedrock.Bedrock")
+
+    llm = AWSBedrockLlm(config)
+    llm.get_llm_model_answer("Test query")
+
+    mocked_bedrock_chat.assert_called_once()
+    callbacks = [callback[1]["callbacks"] for callback in mocked_bedrock_chat.call_args_list]
+    assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)