Feature/bedrock embedder (#1470)
embedchain/embedchain/config/embedder/aws_bedrock.py  (new file, 21 additions)
@@ -0,0 +1,21 @@
from typing import Any, Dict, Optional

from embedchain.config.embedder.base import BaseEmbedderConfig
from embedchain.helpers.json_serializable import register_deserializable


@register_deserializable
class AWSBedrockEmbedderConfig(BaseEmbedderConfig):
    def __init__(
        self,
        model: Optional[str] = None,
        deployment_name: Optional[str] = None,
        vector_dimension: Optional[int] = None,
        task_type: Optional[str] = None,
        title: Optional[str] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(model, deployment_name, vector_dimension)
        self.task_type = task_type or "retrieval_document"
        self.title = title or "Embeddings for Embedchain"
        self.model_kwargs = model_kwargs or {}
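A quick sketch of the defaults this config applies when constructed with no arguments (the model is left as None here and is filled in by the embedder that follows):

from embedchain.config.embedder.aws_bedrock import AWSBedrockEmbedderConfig

config = AWSBedrockEmbedderConfig()
print(config.task_type)     # "retrieval_document"
print(config.title)         # "Embeddings for Embedchain"
print(config.model_kwargs)  # {} unless explicit kwargs are passed through to Bedrock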
embedchain/embedchain/embedder/aws_bedrock.py  (new file, 31 additions)
@@ -0,0 +1,31 @@
from typing import Optional

try:
    from langchain_aws import BedrockEmbeddings
except ModuleNotFoundError:
    raise ModuleNotFoundError(
        "The required dependencies for AWSBedrock are not installed. " "Please install with `pip install langchain_aws`"
    ) from None

from embedchain.config.embedder.aws_bedrock import AWSBedrockEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.models import VectorDimensions


class AWSBedrockEmbedder(BaseEmbedder):
    def __init__(self, config: Optional[AWSBedrockEmbedderConfig] = None):
        super().__init__(config)

        if self.config.model is None or self.config.model == "amazon.titan-embed-text-v2:0":
            self.config.model = "amazon.titan-embed-text-v2:0"  # Default model if not specified
            vector_dimension = self.config.vector_dimension or VectorDimensions.AMAZON_TITAN_V2.value
        elif self.config.model == "amazon.titan-embed-text-v1":
            vector_dimension = VectorDimensions.AMAZON_TITAN_V1.value
        else:
            vector_dimension = self.config.vector_dimension

        embeddings = BedrockEmbeddings(model_id=self.config.model, model_kwargs=self.config.model_kwargs)
        embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)

        self.set_embedding_fn(embedding_fn=embedding_fn)
        self.set_vector_dimension(vector_dimension=vector_dimension)
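For reference, a minimal sketch of exercising the new embedder on its own. It assumes AWS credentials and a region are already available to boto3 in the environment, and that BaseEmbedder exposes embedding_fn and vector_dimension attributes via the set_* calls above:

from embedchain.config.embedder.aws_bedrock import AWSBedrockEmbedderConfig
from embedchain.embedder.aws_bedrock import AWSBedrockEmbedder

embedder = AWSBedrockEmbedder(config=AWSBedrockEmbedderConfig())
print(embedder.config.model)      # "amazon.titan-embed-text-v2:0" (default picked above)
print(embedder.vector_dimension)  # 1024, i.e. VectorDimensions.AMAZON_TITAN_V2
vectors = embedder.embedding_fn(["hello bedrock"])  # one 1024-dimensional vector per input string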
@@ -61,6 +61,7 @@ class EmbedderFactory:
         "nvidia": "embedchain.embedder.nvidia.NvidiaEmbedder",
         "cohere": "embedchain.embedder.cohere.CohereEmbedder",
         "ollama": "embedchain.embedder.ollama.OllamaEmbedder",
+        "aws_bedrock": "embedchain.embedder.aws_bedrock.AWSBedrockEmbedder",
     }
     provider_to_config_class = {
         "azure_openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
@@ -70,6 +71,7 @@ class EmbedderFactory:
         "clarifai": "embedchain.config.embedder.base.BaseEmbedderConfig",
         "openai": "embedchain.config.embedder.base.BaseEmbedderConfig",
         "ollama": "embedchain.config.embedder.ollama.OllamaEmbedderConfig",
+        "aws_bedrock": "embedchain.config.embedder.aws_bedrock.AWSBedrockEmbedderConfig",
     }

     @classmethod
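With both maps updated, the provider key "aws_bedrock" resolves to the classes added above. A sketch of the direct factory call (the create signature is assumed from how the other providers are registered; most applications go through App.from_config instead, see the config example at the end of this diff):

from embedchain.factory import EmbedderFactory

# Assumed call shape: provider name plus the raw config dict,
# which the factory unpacks into AWSBedrockEmbedderConfig(**config_data).
embedder = EmbedderFactory.create("aws_bedrock", {"model": "amazon.titan-embed-text-v2:0"})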
@@ -1,7 +1,12 @@
 import os
 from typing import Optional

-from langchain_community.llms import Bedrock
+try:
+    from langchain_aws import BedrockLLM
+except ModuleNotFoundError:
+    raise ModuleNotFoundError(
+        "The required dependencies for AWSBedrock are not installed. " "Please install with `pip install langchain_aws`"
+    ) from None

 from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -26,7 +31,9 @@ class AWSBedrockLlm(BaseLlm):
                 "Please install with `pip install boto3==1.34.20`."
             ) from None

-        self.boto_client = boto3.client("bedrock-runtime", "us-west-2" or os.environ.get("AWS_REGION"))
+        self.boto_client = boto3.client(
+            "bedrock-runtime", os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION", "us-east-1"))
+        )

         kwargs = {
             "model_id": config.model or "amazon.titan-text-express-v1",
@@ -38,11 +45,12 @@ class AWSBedrockLlm(BaseLlm):
         }

         if config.stream:
-            from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+            from langchain.callbacks.streaming_stdout import \
+                StreamingStdOutCallbackHandler

-            callbacks = [StreamingStdOutCallbackHandler()]
-            llm = Bedrock(**kwargs, streaming=config.stream, callbacks=callbacks)
-        else:
-            llm = Bedrock(**kwargs)
+            kwargs["streaming"] = True
+            kwargs["callbacks"] = [StreamingStdOutCallbackHandler()]
+
+        llm = BedrockLLM(**kwargs)

         return llm.invoke(prompt)
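The region lookup fix matters because the old expression "us-west-2" or os.environ.get("AWS_REGION") always evaluated to "us-west-2" (a non-empty string is truthy), so AWS_REGION was silently ignored. The new code prefers AWS_REGION, then AWS_DEFAULT_REGION, then falls back to us-east-1. A small sketch of the resolution order with illustrative values:

import os

os.environ.pop("AWS_REGION", None)
os.environ["AWS_DEFAULT_REGION"] = "eu-central-1"  # illustrative value
region = os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION", "us-east-1"))
print(region)  # "eu-central-1"; with neither variable set, this prints "us-east-1"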
@@ -5,5 +5,6 @@ class EmbeddingFunctions(Enum):
     OPENAI = "OPENAI"
     HUGGING_FACE = "HUGGING_FACE"
     VERTEX_AI = "VERTEX_AI"
+    AWS_BEDROCK = "AWS_BEDROCK"
     GPT4ALL = "GPT4ALL"
     OLLAMA = "OLLAMA"
@@ -12,3 +12,5 @@ class VectorDimensions(Enum):
     NVIDIA_AI = 1024
     COHERE = 384
     OLLAMA = 384
+    AMAZON_TITAN_V1 = 1536
+    AMAZON_TITAN_V2 = 1024
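AMAZON_TITAN_V2 = 1024 matches Titan Text Embeddings v2's default output size; the model also supports smaller outputs, which is why AWSBedrockEmbedder lets an explicit vector_dimension override the enum value. A hedged sketch of requesting a reduced embedding size (the "dimensions" key is assumed to be the Titan v2 request parameter forwarded through model_kwargs; verify against the current Bedrock documentation):

from embedchain.config.embedder.aws_bedrock import AWSBedrockEmbedderConfig
from embedchain.embedder.aws_bedrock import AWSBedrockEmbedder

config = AWSBedrockEmbedderConfig(
    model="amazon.titan-embed-text-v2:0",
    vector_dimension=512,              # must match the size the model is asked to return
    model_kwargs={"dimensions": 512},  # assumed Titan v2 request parameter
)
embedder = AWSBedrockEmbedder(config=config)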
@@ -466,6 +466,7 @@ def validate_config(config_data):
                 "nvidia",
                 "ollama",
                 "cohere",
+                "aws_bedrock",
             ),
             Optional("config"): {
                 Optional("model"): Optional(str),
@@ -492,6 +493,7 @@ def validate_config(config_data):
                 "clarifai",
                 "nvidia",
                 "ollama",
+                "aws_bedrock",
             ),
             Optional("config"): {
                 Optional("model"): str,
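Taken together with the factory registrations, these schema additions mean a config like the following now validates and wires both the Bedrock LLM and the Bedrock embedder (a sketch; it assumes AWS credentials in the environment, and the model ids are the defaults used elsewhere in this diff):

from embedchain import App

app = App.from_config(config={
    "llm": {
        "provider": "aws_bedrock",
        "config": {"model": "amazon.titan-text-express-v1"},
    },
    "embedder": {
        "provider": "aws_bedrock",
        "config": {"model": "amazon.titan-embed-text-v2:0"},
    },
})
app.add("https://en.wikipedia.org/wiki/Amazon_Web_Services")  # illustrative source
print(app.query("What services does AWS offer?"))             # illustrative query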