Feature/bedrock embedder (#1470)

This commit is contained in:
andrewghlee
2024-08-01 13:55:28 -04:00
committed by GitHub
parent 80945df4ca
commit 563a130141
15 changed files with 390 additions and 26 deletions

View File

@@ -24,9 +24,7 @@ We use `poetry` as our package manager. You can install poetry by following the
Please DO NOT use pip or conda to install the dependencies. Instead, use poetry:
```bash
poetry install --all-extras
or
poetry install --with dev
make install_all
#activate

View File

@@ -0,0 +1,15 @@
llm:
provider: aws_bedrock
config:
model: amazon.titan-text-express-v1
deployment_name: your_llm_deployment_name
temperature: 0.5
max_tokens: 8192
top_p: 1
stream: false
embedder:
provider: aws_bedrock
config:
model: amazon.titan-embed-text-v2:0
deployment_name: your_embedding_model_deployment_name

View File

@@ -10,6 +10,7 @@ Embedchain supports several embedding models from the following providers:
<Card title="OpenAI" href="#openai"></Card>
<Card title="GoogleAI" href="#google-ai"></Card>
<Card title="Azure OpenAI" href="#azure-openai"></Card>
<Card title="AWS Bedrock" href="#aws-bedrock"></Card>
<Card title="GPT4All" href="#gpt4all"></Card>
<Card title="Hugging Face" href="#hugging-face"></Card>
<Card title="Vertex AI" href="#vertex-ai"></Card>
@@ -97,6 +98,37 @@ embedder:
For more details regarding the Google AI embedding model, please refer to the [Google AI documentation](https://ai.google.dev/tutorials/python_quickstart#use_embeddings).
</Note>
## AWS Bedrock
To use the AWS Bedrock embedding function, you have to set the AWS environment variables shown below.
<CodeGroup>
```python main.py
import os
from embedchain import App
os.environ["AWS_ACCESS_KEY_ID"] = "xxx"
os.environ["AWS_SECRET_ACCESS_KEY"] = "xxx"
os.environ["AWS_REGION"] = "us-west-2"
app = App.from_config(config_path="config.yaml")
```
```yaml config.yaml
embedder:
provider: aws_bedrock
config:
model: 'amazon.titan-embed-text-v2:0'
vector_dimension: 1024
task_type: "retrieval_document"
title: "Embeddings for Embedchain"
```
</CodeGroup>
<br/>
<Note>
For more details regarding the AWS Bedrock embedding model, please refer to the [AWS Bedrock documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/titan-embedding-models.html).
</Note>
## Azure OpenAI
To use Azure OpenAI embedding model, you have to set some of the azure openai related environment variables as given in the code block below:

View File

@@ -0,0 +1,21 @@
from typing import Any, Dict, Optional
from embedchain.config.embedder.base import BaseEmbedderConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class AWSBedrockEmbedderConfig(BaseEmbedderConfig):
    """Configuration for the AWS Bedrock embedder.

    Extends :class:`BaseEmbedderConfig` with Bedrock-specific options:
    the embedding task type, a title attached to the embeddings, and
    arbitrary keyword arguments forwarded to the underlying model.
    """

    def __init__(
        self,
        model: Optional[str] = None,
        deployment_name: Optional[str] = None,
        vector_dimension: Optional[int] = None,
        task_type: Optional[str] = None,
        title: Optional[str] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(model, deployment_name, vector_dimension)
        # Fall back to defaults when the optional fields are omitted or empty.
        self.task_type = task_type if task_type else "retrieval_document"
        self.title = title if title else "Embeddings for Embedchain"
        self.model_kwargs = model_kwargs if model_kwargs else {}

View File

@@ -0,0 +1,31 @@
from typing import Optional
try:
    from langchain_aws import BedrockEmbeddings
except ModuleNotFoundError:
    # Fix: the two adjacent string literals previously concatenated without a
    # space, producing "...installed.Please install...". A trailing space on
    # the first literal keeps the message readable.
    raise ModuleNotFoundError(
        "The required dependencies for AWSBedrock are not installed. "
        "Please install with `pip install langchain_aws`"
    ) from None
from embedchain.config.embedder.aws_bedrock import AWSBedrockEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions
class AWSBedrockEmbedder(BaseEmbedder):
    """Embedder backed by AWS Bedrock via langchain-aws' ``BedrockEmbeddings``.

    Defaults to the Titan v2 text-embedding model when no model is configured,
    and resolves the vector dimension from the known Titan model dimensions
    unless the config provides an explicit override.
    """

    def __init__(self, config: Optional[AWSBedrockEmbedderConfig] = None):
        super().__init__(config)

        if self.config.model is None:
            # Default model when none is specified.
            self.config.model = "amazon.titan-embed-text-v2:0"

        if self.config.model == "amazon.titan-embed-text-v2:0":
            vector_dimension = self.config.vector_dimension or VectorDimensions.AMAZON_TITAN_V2.value
        elif self.config.model == "amazon.titan-embed-text-v1":
            # Fix: honor an explicit vector_dimension override here too,
            # mirroring the v2 and unknown-model branches (previously a
            # configured dimension was silently ignored for Titan v1).
            vector_dimension = self.config.vector_dimension or VectorDimensions.AMAZON_TITAN_V1.value
        else:
            # Unknown model: the caller must supply the dimension.
            # NOTE(review): this may be None when unconfigured —
            # set_vector_dimension presumably expects an int; confirm upstream.
            vector_dimension = self.config.vector_dimension

        embeddings = BedrockEmbeddings(model_id=self.config.model, model_kwargs=self.config.model_kwargs)
        embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
        self.set_embedding_fn(embedding_fn=embedding_fn)
        self.set_vector_dimension(vector_dimension=vector_dimension)

View File

@@ -61,6 +61,7 @@ class EmbedderFactory:
"nvidia": "embedchain.embedder.nvidia.NvidiaEmbedder",
"cohere": "embedchain.embedder.cohere.CohereEmbedder",
"ollama": "embedchain.embedder.ollama.OllamaEmbedder",
"aws_bedrock": "embedchain.embedder.aws_bedrock.AWSBedrockEmbedder",
}
provider_to_config_class = {
"azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
@@ -70,6 +71,7 @@ class EmbedderFactory:
"clarifai": "embedchain.config.embedder.base.BaseEmbedderConfig",
"openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
"ollama": "embedchain.config.embedder.ollama.OllamaEmbedderConfig",
"aws_bedrock": "embedchain.config.embedder.aws_bedrock.AWSBedrockEmbedderConfig",
}
@classmethod

View File

@@ -1,7 +1,12 @@
import os
from typing import Optional
from langchain_community.llms import Bedrock
try:
from langchain_aws import BedrockLLM
except ModuleNotFoundError:
raise ModuleNotFoundError(
"The required dependencies for AWSBedrock are not installed." "Please install with `pip install langchain_aws`"
) from None
from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
@@ -26,7 +31,9 @@ class AWSBedrockLlm(BaseLlm):
"Please install with `pip install boto3==1.34.20`."
) from None
self.boto_client = boto3.client("bedrock-runtime", "us-west-2" or os.environ.get("AWS_REGION"))
self.boto_client = boto3.client(
"bedrock-runtime", os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION", "us-east-1"))
)
kwargs = {
"model_id": config.model or "amazon.titan-text-express-v1",
@@ -38,11 +45,12 @@ class AWSBedrockLlm(BaseLlm):
}
if config.stream:
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.streaming_stdout import \
StreamingStdOutCallbackHandler
callbacks = [StreamingStdOutCallbackHandler()]
llm = Bedrock(**kwargs, streaming=config.stream, callbacks=callbacks)
else:
llm = Bedrock(**kwargs)
kwargs["streaming"] = True
kwargs["callbacks"] = [StreamingStdOutCallbackHandler()]
llm = BedrockLLM(**kwargs)
return llm.invoke(prompt)

View File

@@ -5,5 +5,6 @@ class EmbeddingFunctions(Enum):
OPENAI = "OPENAI"
HUGGING_FACE = "HUGGING_FACE"
VERTEX_AI = "VERTEX_AI"
AWS_BEDROCK = "AWS_BEDROCK"
GPT4ALL = "GPT4ALL"
OLLAMA = "OLLAMA"

View File

@@ -12,3 +12,5 @@ class VectorDimensions(Enum):
NVIDIA_AI = 1024
COHERE = 384
OLLAMA = 384
AMAZON_TITAN_V1 = 1536
AMAZON_TITAN_V2 = 1024

View File

@@ -466,6 +466,7 @@ def validate_config(config_data):
"nvidia",
"ollama",
"cohere",
"aws_bedrock",
),
Optional("config"): {
Optional("model"): Optional(str),
@@ -492,6 +493,7 @@ def validate_config(config_data):
"clarifai",
"nvidia",
"ollama",
"aws_bedrock",
),
Optional("config"): {
Optional("model"): str,

File diff suppressed because one or more lines are too long

31
embedchain/poetry.lock generated
View File

@@ -2389,6 +2389,22 @@ requests = ">=2,<3"
SQLAlchemy = ">=1.4,<3"
tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0"
[[package]]
name = "langchain-aws"
version = "0.1.10"
description = "An integration package connecting AWS and LangChain"
optional = true
python-versions = "<4.0,>=3.8.1"
files = [
{file = "langchain_aws-0.1.10-py3-none-any.whl", hash = "sha256:2cba72efaa9f0dc406d8e06a1fbaa3762678d489cbc5147cf64a7012189c161c"},
{file = "langchain_aws-0.1.10.tar.gz", hash = "sha256:7f01dacbf8345a28192cec4ef31018cc33a91de0b82122f913eec09a76d64fd5"},
]
[package.dependencies]
boto3 = ">=1.34.131,<1.35.0"
langchain-core = ">=0.2.6,<0.3"
numpy = ">=1,<2"
[[package]]
name = "langchain-cohere"
version = "0.1.9"
@@ -3239,7 +3255,6 @@ description = "Nvidia JIT LTO Library"
optional = true
python-versions = ">=3"
files = [
{file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_aarch64.whl", hash = "sha256:98103729cc5226e13ca319a10bbf9433bbbd44ef64fe72f45f067cacc14b8d27"},
{file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f9b37bc5c8cf7509665cb6ada5aaa0ce65618f2332b7d3e78e9790511f111212"},
{file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-win_amd64.whl", hash = "sha256:e782564d705ff0bf61ac3e1bf730166da66dd2fe9012f111ede5fc49b64ae697"},
]
@@ -4599,7 +4614,6 @@ files = [
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
{file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"},
{file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
{file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
{file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
@@ -4607,16 +4621,8 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
{file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"},
{file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
{file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"},
{file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"},
{file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"},
{file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"},
@@ -4633,7 +4639,6 @@ files = [
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"},
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"},
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"},
{file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"},
{file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"},
{file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"},
{file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"},
@@ -4641,7 +4646,6 @@ files = [
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"},
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"},
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"},
{file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"},
{file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"},
{file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"},
{file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
@@ -6565,6 +6569,7 @@ doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linke
test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"]
[extras]
aws = ["langchain-aws"]
elasticsearch = ["elasticsearch"]
gmail = ["google-api-core", "google-api-python-client", "google-auth", "google-auth-httplib2", "google-auth-oauthlib", "requests"]
google = ["google-generativeai"]
@@ -6585,4 +6590,4 @@ weaviate = ["weaviate-client"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<=3.13"
content-hash = "8197f676b36fed2bf02f33cd15e83c3e6640ae5ba216210af5777ab2dc139480"
content-hash = "9a57a07f9a60c51a2f8765f31d5d768f7d6ff67b63c0d419648a2b8d668527ad"

View File

@@ -138,6 +138,7 @@ sqlalchemy = "^2.0.27"
alembic = "^1.13.1"
langchain-cohere = "^0.1.4"
langchain-community = "^0.2.6"
langchain-aws = {version = "^0.1.10", optional = true}
[tool.poetry.group.dev.dependencies]
black = "^23.3.0"
@@ -177,6 +178,7 @@ postgres = ["psycopg", "psycopg-binary", "psycopg-pool"]
mysql = ["mysql-connector-python"]
google = ["google-generativeai"]
mistralai = ["langchain-mistralai"]
aws = ["langchain-aws"]
[tool.poetry.group.docs.dependencies]

View File

@@ -0,0 +1,21 @@
from unittest.mock import patch
from embedchain.config.embedder.aws_bedrock import AWSBedrockEmbedderConfig
from embedchain.embedder.aws_bedrock import AWSBedrockEmbedder
def test_aws_bedrock_embedder_with_model():
    """An explicit model / kwargs / dimension config is passed through unchanged."""
    embedder_config = AWSBedrockEmbedderConfig(
        model="test-model",
        model_kwargs={"param": "value"},
        vector_dimension=1536,
    )

    # Patch the langchain client so no real AWS call is made.
    with patch("embedchain.embedder.aws_bedrock.BedrockEmbeddings") as bedrock_mock:
        embedder = AWSBedrockEmbedder(config=embedder_config)

    assert embedder.config.model == "test-model"
    assert embedder.config.model_kwargs == {"param": "value"}
    assert embedder.config.vector_dimension == 1536
    bedrock_mock.assert_called_once_with(
        model_id="test-model",
        model_kwargs={"param": "value"},
    )

View File

@@ -9,7 +9,6 @@ from embedchain.llm.aws_bedrock import AWSBedrockLlm
def config(monkeypatch):
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test_access_key_id")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test_secret_access_key")
monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
config = BaseLlmConfig(
model="amazon.titan-text-express-v1",
model_kwargs={
@@ -21,7 +20,6 @@ def config(monkeypatch):
yield config
monkeypatch.delenv("AWS_ACCESS_KEY_ID")
monkeypatch.delenv("AWS_SECRET_ACCESS_KEY")
monkeypatch.delenv("OPENAI_API_KEY")
def test_get_llm_model_answer(config, mocker):
@@ -46,7 +44,7 @@ def test_get_llm_model_answer_empty_prompt(config, mocker):
def test_get_llm_model_answer_with_streaming(config, mocker):
config.stream = True
mocked_bedrock_chat = mocker.patch("embedchain.llm.aws_bedrock.Bedrock")
mocked_bedrock_chat = mocker.patch("embedchain.llm.aws_bedrock.BedrockLLM")
llm = AWSBedrockLlm(config)
llm.get_llm_model_answer("Test query")