Formatting (#2750)
This commit is contained in:
@@ -20,18 +20,19 @@ def extract_provider(model: str) -> str:
|
||||
return provider
|
||||
raise ValueError(f"Unknown provider in model: {model}")
|
||||
|
||||
|
||||
class AWSBedrockLLM(LLMBase):
|
||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||
super().__init__(config)
|
||||
|
||||
if not self.config.model:
|
||||
self.config.model = "anthropic.claude-3-5-sonnet-20240620-v1:0"
|
||||
|
||||
|
||||
# Get AWS config from environment variables or use defaults
|
||||
aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
|
||||
aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
|
||||
aws_region = os.environ.get("AWS_REGION", "us-west-2")
|
||||
|
||||
|
||||
# Check if AWS config is provided in the config
|
||||
if hasattr(self.config, "aws_access_key_id"):
|
||||
aws_access_key = self.config.aws_access_key_id
|
||||
@@ -39,14 +40,14 @@ class AWSBedrockLLM(LLMBase):
|
||||
aws_secret_key = self.config.aws_secret_access_key
|
||||
if hasattr(self.config, "aws_region"):
|
||||
aws_region = self.config.aws_region
|
||||
|
||||
|
||||
self.client = boto3.client(
|
||||
"bedrock-runtime",
|
||||
region_name=aws_region,
|
||||
aws_access_key_id=aws_access_key if aws_access_key else None,
|
||||
aws_secret_access_key=aws_secret_key if aws_secret_key else None,
|
||||
)
|
||||
|
||||
|
||||
self.model_kwargs = {
|
||||
"temperature": self.config.temperature,
|
||||
"max_tokens_to_sample": self.config.max_tokens,
|
||||
@@ -145,7 +146,9 @@ class AWSBedrockLLM(LLMBase):
|
||||
input_body = {
|
||||
"inputText": prompt,
|
||||
"textGenerationConfig": {
|
||||
"maxTokenCount": self.model_kwargs["max_tokens_to_sample"] or self.model_kwargs["max_tokens"] or 5000,
|
||||
"maxTokenCount": self.model_kwargs["max_tokens_to_sample"]
|
||||
or self.model_kwargs["max_tokens"]
|
||||
or 5000,
|
||||
"topP": self.model_kwargs["top_p"] or 0.9,
|
||||
"temperature": self.model_kwargs["temperature"] or 0.1,
|
||||
},
|
||||
@@ -243,22 +246,15 @@ class AWSBedrockLLM(LLMBase):
|
||||
body = json.dumps(input_body)
|
||||
|
||||
if provider == "anthropic" or provider == "deepseek":
|
||||
|
||||
input_body = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": prompt}]
|
||||
}
|
||||
],
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": prompt}]}],
|
||||
"max_tokens": self.model_kwargs["max_tokens_to_sample"] or self.model_kwargs["max_tokens"] or 5000,
|
||||
"temperature": self.model_kwargs["temperature"] or 0.1,
|
||||
"top_p": self.model_kwargs["top_p"] or 0.9,
|
||||
"anthropic_version": "bedrock-2023-05-31",
|
||||
}
|
||||
|
||||
|
||||
body = json.dumps(input_body)
|
||||
|
||||
|
||||
response = self.client.invoke_model(
|
||||
body=body,
|
||||
@@ -272,6 +268,6 @@ class AWSBedrockLLM(LLMBase):
|
||||
modelId=self.config.model,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)
|
||||
)
|
||||
|
||||
return self._parse_response(response, tools)
|
||||
|
||||
Reference in New Issue
Block a user