Feature/bedrock embedder (#1470)

This commit is contained in:
andrewghlee
2024-08-01 13:55:28 -04:00
committed by GitHub
parent 80945df4ca
commit 563a130141
15 changed files with 390 additions and 26 deletions

View File

@@ -0,0 +1,21 @@
from typing import Any, Dict, Optional
from embedchain.config.embedder.base import BaseEmbedderConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class AWSBedrockEmbedderConfig(BaseEmbedderConfig):
def __init__(
self,
model: Optional[str] = None,
deployment_name: Optional[str] = None,
vector_dimension: Optional[int] = None,
task_type: Optional[str] = None,
title: Optional[str] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(model, deployment_name, vector_dimension)
self.task_type = task_type or "retrieval_document"
self.title = title or "Embeddings for Embedchain"
self.model_kwargs = model_kwargs or {}

View File

@@ -0,0 +1,31 @@
from typing import Optional
try:
from langchain_aws import BedrockEmbeddings
except ModuleNotFoundError:
raise ModuleNotFoundError(
"The required dependencies for AWSBedrock are not installed." "Please install with `pip install langchain_aws`"
) from None
from embedchain.config.embedder.aws_bedrock import AWSBedrockEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions
class AWSBedrockEmbedder(BaseEmbedder):
def __init__(self, config: Optional[AWSBedrockEmbedderConfig] = None):
super().__init__(config)
if self.config.model is None or self.config.model == "amazon.titan-embed-text-v2:0":
self.config.model = "amazon.titan-embed-text-v2:0" # Default model if not specified
vector_dimension = self.config.vector_dimension or VectorDimensions.AMAZON_TITAN_V2.value
elif self.config.model == "amazon.titan-embed-text-v1":
vector_dimension = VectorDimensions.AMAZON_TITAN_V1.value
else:
vector_dimension = self.config.vector_dimension
embeddings = BedrockEmbeddings(model_id=self.config.model, model_kwargs=self.config.model_kwargs)
embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
self.set_embedding_fn(embedding_fn=embedding_fn)
self.set_vector_dimension(vector_dimension=vector_dimension)

View File

@@ -61,6 +61,7 @@ class EmbedderFactory:
"nvidia": "embedchain.embedder.nvidia.NvidiaEmbedder",
"cohere": "embedchain.embedder.cohere.CohereEmbedder",
"ollama": "embedchain.embedder.ollama.OllamaEmbedder",
"aws_bedrock": "embedchain.embedder.aws_bedrock.AWSBedrockEmbedder",
}
provider_to_config_class = {
"azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
@@ -70,6 +71,7 @@ class EmbedderFactory:
"clarifai": "embedchain.config.embedder.base.BaseEmbedderConfig",
"openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
"ollama": "embedchain.config.embedder.ollama.OllamaEmbedderConfig",
"aws_bedrock": "embedchain.config.embedder.aws_bedrock.AWSBedrockEmbedderConfig",
}
@classmethod

View File

@@ -1,7 +1,12 @@
import os
from typing import Optional
from langchain_community.llms import Bedrock
try:
from langchain_aws import BedrockLLM
except ModuleNotFoundError:
raise ModuleNotFoundError(
"The required dependencies for AWSBedrock are not installed." "Please install with `pip install langchain_aws`"
) from None
from embedchain.config import BaseLlmConfig
from embedchain.helpers.json_serializable import register_deserializable
@@ -26,7 +31,9 @@ class AWSBedrockLlm(BaseLlm):
"Please install with `pip install boto3==1.34.20`."
) from None
self.boto_client = boto3.client("bedrock-runtime", "us-west-2" or os.environ.get("AWS_REGION"))
self.boto_client = boto3.client(
"bedrock-runtime", os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION", "us-east-1"))
)
kwargs = {
"model_id": config.model or "amazon.titan-text-express-v1",
@@ -38,11 +45,12 @@ class AWSBedrockLlm(BaseLlm):
}
if config.stream:
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.streaming_stdout import \
StreamingStdOutCallbackHandler
callbacks = [StreamingStdOutCallbackHandler()]
llm = Bedrock(**kwargs, streaming=config.stream, callbacks=callbacks)
else:
llm = Bedrock(**kwargs)
kwargs["streaming"] = True
kwargs["callbacks"] = [StreamingStdOutCallbackHandler()]
llm = BedrockLLM(**kwargs)
return llm.invoke(prompt)

View File

@@ -5,5 +5,6 @@ class EmbeddingFunctions(Enum):
OPENAI = "OPENAI"
HUGGING_FACE = "HUGGING_FACE"
VERTEX_AI = "VERTEX_AI"
AWS_BEDROCK = "AWS_BEDROCK"
GPT4ALL = "GPT4ALL"
OLLAMA = "OLLAMA"

View File

@@ -12,3 +12,5 @@ class VectorDimensions(Enum):
NVIDIA_AI = 1024
COHERE = 384
OLLAMA = 384
AMAZON_TITAN_V1 = 1536
AMAZON_TITAN_V2 = 1024

View File

@@ -466,6 +466,7 @@ def validate_config(config_data):
"nvidia",
"ollama",
"cohere",
"aws_bedrock",
),
Optional("config"): {
Optional("model"): Optional(str),
@@ -492,6 +493,7 @@ def validate_config(config_data):
"clarifai",
"nvidia",
"ollama",
"aws_bedrock",
),
Optional("config"): {
Optional("model"): str,