From 069d2653380bbfb1a1a7675962f796318151ed80 Mon Sep 17 00:00:00 2001 From: Deven Patel Date: Sun, 21 Jan 2024 14:09:08 +0530 Subject: [PATCH] [Feature] Add support for AWS Bedrock LLM (#1189) Co-authored-by: Deven Patel --- docs/api-reference/advanced/configuration.mdx | 3 +- docs/components/llms.mdx | 48 ++++++++++-- embedchain/factory.py | 1 + embedchain/llm/aws_bedrock.py | 48 ++++++++++++ embedchain/utils/misc.py | 2 + poetry.lock | 74 ++++++++++++++++++- pyproject.toml | 2 + tests/llm/test_aws_bedrock.py | 56 ++++++++++++++ 8 files changed, 226 insertions(+), 8 deletions(-) create mode 100644 embedchain/llm/aws_bedrock.py create mode 100644 tests/llm/test_aws_bedrock.py diff --git a/docs/api-reference/advanced/configuration.mdx b/docs/api-reference/advanced/configuration.mdx index 365a223e..6c0cae3d 100644 --- a/docs/api-reference/advanced/configuration.mdx +++ b/docs/api-reference/advanced/configuration.mdx @@ -200,9 +200,10 @@ Alright, let's dive into what each key means in the yaml config above: - `stream` (Boolean): Controls if the response is streamed back to the user (set to false). - `prompt` (String): A prompt for the model to follow when generating responses, requires `$context` and `$query` variables. - `system_prompt` (String): A system prompt for the model to follow when generating responses, in this case, it's set to the style of William Shakespeare. - - `stream` (Boolean): Controls if the response is streamed back to the user (set to false). + - `stream` (Boolean): Controls if the response is streamed back to the user (set to false). - `number_documents` (Integer): Number of documents to pull from the vectordb as context, defaults to 1 - `api_key` (String): The API key for the language model. + - `model_kwargs` (Dict): Keyword arguments to pass to the language model. Used for `aws_bedrock` provider, since it requires different arguments for each model. 3. 
`vectordb` Section: - `provider` (String): The provider for the vector database, set to 'chroma'. You can find the full list of vector database providers in [our docs](/components/vector-databases). - `config`: diff --git a/docs/components/llms.mdx b/docs/components/llms.mdx index b9b4b138..9adbc6ab 100644 --- a/docs/components/llms.mdx +++ b/docs/components/llms.mdx @@ -21,6 +21,7 @@ Embedchain comes with built-in support for various popular large language models + ## OpenAI @@ -627,11 +628,8 @@ llm: Obtain the Mistral AI api key from their [console](https://console.mistral.ai/). - -```python main.py -import os -from embedchain import App - + + ```python main.py os.environ["MISTRAL_API_KEY"] = "xxx" app = App.from_config(config_path="config.yaml") @@ -663,5 +661,45 @@ embedder: ``` + +## AWS Bedrock + +### Setup +- Before using the AWS Bedrock LLM, make sure you have the appropriate model access from [Bedrock Console](https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/modelaccess). +- You will also need `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` to authenticate the API with AWS. You can find these in your [AWS Console](https://us-east-1.console.aws.amazon.com/iam/home?region=us-east-1#/users). + + +### Usage + + + +```python main.py +import os +from embedchain import App + +os.environ["AWS_ACCESS_KEY_ID"] = "xxx" +os.environ["AWS_SECRET_ACCESS_KEY"] = "xxx" + +app = App.from_config(config_path="config.yaml") +``` + +```yaml config.yaml +llm: + provider: aws_bedrock + config: + model: amazon.titan-text-express-v1 + # check notes below for model_kwargs + model_kwargs: + temperature: 0.5 + topP: 1 + maxTokenCount: 1000 +``` + + +
+ + The model arguments are different for each provider. Please refer to the [AWS Bedrock Documentation](https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers) to find the appropriate arguments for your model. + 
diff --git a/embedchain/factory.py b/embedchain/factory.py index 04b20c56..59570d3a 100644 --- a/embedchain/factory.py +++ b/embedchain/factory.py @@ -21,6 +21,7 @@ class LlmFactory: "openai": "embedchain.llm.openai.OpenAILlm", "vertexai": "embedchain.llm.vertex_ai.VertexAILlm", "google": "embedchain.llm.google.GoogleLlm", + "aws_bedrock": "embedchain.llm.aws_bedrock.AWSBedrockLlm", "mistralai": "embedchain.llm.mistralai.MistralAILlm", } provider_to_config_class = { diff --git a/embedchain/llm/aws_bedrock.py b/embedchain/llm/aws_bedrock.py new file mode 100644 index 00000000..e1f2521a --- /dev/null +++ b/embedchain/llm/aws_bedrock.py @@ -0,0 +1,48 @@ +from typing import Optional + +from langchain.llms import Bedrock + +from embedchain.config import BaseLlmConfig +from embedchain.helpers.json_serializable import register_deserializable +from embedchain.llm.base import BaseLlm + + +@register_deserializable +class AWSBedrockLlm(BaseLlm): + def __init__(self, config: Optional[BaseLlmConfig] = None): + super().__init__(config) + + def get_llm_model_answer(self, prompt) -> str: + response = self._get_answer(prompt, self.config) + return response + + def _get_answer(self, prompt: str, config: BaseLlmConfig) -> str: + try: + import boto3 + except ModuleNotFoundError: + raise ModuleNotFoundError( + "The required dependencies for AWSBedrock are not installed." 
+ 'Please install with `pip install --upgrade "embedchain[aws-bedrock]"`' + ) from None + + self.boto_client = boto3.client("bedrock-runtime", "us-west-2") + + kwargs = { + "model_id": config.model or "amazon.titan-text-express-v1", + "client": self.boto_client, + "model_kwargs": config.model_kwargs + or { + "temperature": config.temperature, + }, + } + + if config.stream: + from langchain.callbacks.streaming_stdout import \ + StreamingStdOutCallbackHandler + + callbacks = [StreamingStdOutCallbackHandler()] + llm = Bedrock(**kwargs, streaming=config.stream, callbacks=callbacks) + else: + llm = Bedrock(**kwargs) + + return llm(prompt) diff --git a/embedchain/utils/misc.py b/embedchain/utils/misc.py index 6cc13ade..cb7ff090 100644 --- a/embedchain/utils/misc.py +++ b/embedchain/utils/misc.py @@ -406,6 +406,7 @@ def validate_config(config_data): "llama2", "vertexai", "google", + "aws_bedrock", "mistralai", ), Optional("config"): { @@ -423,6 +424,7 @@ def validate_config(config_data): Optional("query_type"): str, Optional("api_key"): str, Optional("endpoint"): str, + Optional("model_kwargs"): dict, }, }, Optional("vectordb"): { diff --git a/poetry.lock b/poetry.lock index b2599536..8935ff69 100644 --- a/poetry.lock +++ b/poetry.lock @@ -383,6 +383,47 @@ files = [ {file = "blinker-1.6.3.tar.gz", hash = "sha256:152090d27c1c5c722ee7e48504b02d76502811ce02e1523553b4cf8c8b3d3a8d"}, ] +[[package]] +name = "boto3" +version = "1.34.22" +description = "The AWS SDK for Python" +optional = true +python-versions = ">= 3.8" +files = [ + {file = "boto3-1.34.22-py3-none-any.whl", hash = "sha256:5909cd1393143576265c692e908a9ae495492c04a0ffd4bae8578adc2e44729e"}, + {file = "boto3-1.34.22.tar.gz", hash = "sha256:a98c0b86f6044ff8314cc2361e1ef574d674318313ab5606ccb4a6651c7a3f8c"}, +] + +[package.dependencies] +botocore = ">=1.34.22,<1.35.0" +jmespath = ">=0.7.1,<2.0.0" +s3transfer = ">=0.10.0,<0.11.0" + +[package.extras] +crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] + +[[package]] +name = 
"botocore" +version = "1.34.22" +description = "Low-level, data-driven core of boto 3." +optional = true +python-versions = ">= 3.8" +files = [ + {file = "botocore-1.34.22-py3-none-any.whl", hash = "sha256:e5f7775975b9213507fbcf846a96b7a2aec2a44fc12a44585197b014a4ab0889"}, + {file = "botocore-1.34.22.tar.gz", hash = "sha256:c47ba4286c576150d1b6ca6df69a87b5deff3d23bd84da8bcf8431ebac3c40ba"}, +] + +[package.dependencies] +jmespath = ">=0.7.1,<2.0.0" +python-dateutil = ">=2.1,<3.0.0" +urllib3 = [ + {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""}, + {version = ">=1.25.4,<2.1", markers = "python_version >= \"3.10\""}, +] + +[package.extras] +crt = ["awscrt (==0.19.19)"] + [[package]] name = "brotli" version = "1.1.0" @@ -2810,6 +2851,17 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "jmespath" +version = "1.0.1" +description = "JSON Matching Expressions" +optional = true +python-versions = ">=3.7" +files = [ + {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"}, + {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, +] + [[package]] name = "joblib" version = "1.3.2" @@ -4211,9 +4263,9 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""}, + {version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < 
\"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, - {version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""}, {version = ">=1.23.5", markers = "python_version >= \"3.11\""}, ] @@ -6091,6 +6143,23 @@ files = [ {file = "ruff-0.1.11.tar.gz", hash = "sha256:f9d4d88cb6eeb4dfe20f9f0519bd2eaba8119bde87c3d5065c541dbae2b5a2cb"}, ] +[[package]] +name = "s3transfer" +version = "0.10.0" +description = "An Amazon S3 Transfer Manager" +optional = true +python-versions = ">= 3.8" +files = [ + {file = "s3transfer-0.10.0-py3-none-any.whl", hash = "sha256:3cdb40f5cfa6966e812209d0994f2a4709b561c88e90cf00c2696d2df4e56b2e"}, + {file = "s3transfer-0.10.0.tar.gz", hash = "sha256:d0c8bbf672d5eebbe4e57945e23b972d963f07d82f661cabf678a5c88831595b"}, +] + +[package.dependencies] +botocore = ">=1.33.2,<2.0a.0" + +[package.extras] +crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"] + [[package]] name = "safetensors" version = "0.4.0" @@ -8213,6 +8282,7 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"] [extras] +aws-bedrock = ["boto3"] cohere = ["cohere"] dataloaders = ["docx2txt", "duckduckgo-search", "pytube", "sentence-transformers", "unstructured", "youtube-transcript-api"] discord = ["discord"] @@ -8246,4 +8316,4 @@ youtube = ["youtube-transcript-api", "yt_dlp"] [metadata] lock-version = 
"2.0" python-versions = ">=3.9,<3.12" -content-hash = "cb0da55af7c61300bb321770ed319c900b6b3ba3865421d63eb9120beb73d06c" +content-hash = "bbcf32e87c0784d031fb6cf9bd89655375839da0660b8feb2026ffdd971623d7" diff --git a/pyproject.toml b/pyproject.toml index 6613f080..8b81e5ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -149,6 +149,7 @@ google-auth-oauthlib = { version = "^1.2.0", optional = true } google-auth = { version = "^2.25.2", optional = true } google-auth-httplib2 = { version = "^0.2.0", optional = true } google-api-core = { version = "^2.15.0", optional = true } +boto3 = { version = "^1.34.20", optional = true } langchain-mistralai = { version = "^0.0.3", optional = true } [tool.poetry.group.dev.dependencies] @@ -215,6 +216,7 @@ rss_feed = [ google = ["google-generativeai"] modal = ["modal"] dropbox = ["dropbox"] +aws_bedrock = ["boto3"] mistralai = ["langchain-mistralai"] [tool.poetry.group.docs.dependencies] diff --git a/tests/llm/test_aws_bedrock.py b/tests/llm/test_aws_bedrock.py new file mode 100644 index 00000000..42cde396 --- /dev/null +++ b/tests/llm/test_aws_bedrock.py @@ -0,0 +1,56 @@ +import pytest +from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler + +from embedchain.config import BaseLlmConfig +from embedchain.llm.aws_bedrock import AWSBedrockLlm + + +@pytest.fixture +def config(monkeypatch): + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test_access_key_id") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test_secret_access_key") + monkeypatch.setenv("OPENAI_API_KEY", "test_api_key") + config = BaseLlmConfig( + model="amazon.titan-text-express-v1", + model_kwargs={ + "temperature": 0.5, + "topP": 1, + "maxTokenCount": 1000, + }, + ) + yield config + monkeypatch.delenv("AWS_ACCESS_KEY_ID") + monkeypatch.delenv("AWS_SECRET_ACCESS_KEY") + monkeypatch.delenv("OPENAI_API_KEY") + + +def test_get_llm_model_answer(config, mocker): + mocked_get_answer = mocker.patch("embedchain.llm.aws_bedrock.AWSBedrockLlm._get_answer", 
return_value="Test answer") + + llm = AWSBedrockLlm(config) + answer = llm.get_llm_model_answer("Test query") + + assert answer == "Test answer" + mocked_get_answer.assert_called_once_with("Test query", config) + + +def test_get_llm_model_answer_empty_prompt(config, mocker): + mocked_get_answer = mocker.patch("embedchain.llm.aws_bedrock.AWSBedrockLlm._get_answer", return_value="Test answer") + + llm = AWSBedrockLlm(config) + answer = llm.get_llm_model_answer("") + + assert answer == "Test answer" + mocked_get_answer.assert_called_once_with("", config) + + +def test_get_llm_model_answer_with_streaming(config, mocker): + config.stream = True + mocked_bedrock_chat = mocker.patch("embedchain.llm.aws_bedrock.Bedrock") + + llm = AWSBedrockLlm(config) + llm.get_llm_model_answer("Test query") + + mocked_bedrock_chat.assert_called_once() + callbacks = [callback[1]["callbacks"] for callback in mocked_bedrock_chat.call_args_list] + assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)