t6_mem0/mem0/llms/aws_bedrock.py

import json
from typing import Any, Dict, List, Optional

try:
    import boto3
except ImportError:
    raise ImportError(
        "The 'boto3' library is required. Please install it using 'pip install boto3'."
    )

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase


class AWSBedrockLLM(LLMBase):
    """
    A wrapper for AWS Bedrock's language models, integrating them with the LLMBase class.
    """

    def __init__(self, config: Optional[BaseLlmConfig] = None):
        """
        Initializes the AWS Bedrock LLM with the provided configuration.

        Args:
            config (Optional[BaseLlmConfig]): Configuration object for the model.
        """
        super().__init__(config)

        if not self.config.model:
            self.config.model = "anthropic.claude-3-5-sonnet-20240620-v1:0"

        self.client = boto3.client("bedrock-runtime")
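        # Default generation parameters, using Anthropic-style names;
        # _prepare_input() remaps them to each provider's equivalents.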
        self.model_kwargs = {
            "temperature": self.config.temperature,
            "max_tokens_to_sample": self.config.max_tokens,
            "top_p": self.config.top_p,
        }

    def _format_messages(self, messages: List[Dict[str, str]]) -> str:
        """
        Formats a list of messages into a structured prompt for the model.

        Args:
            messages (List[Dict[str, str]]): A list of dictionaries containing 'role' and 'content'.

        Returns:
            str: A formatted string combining all messages, structured with roles capitalized and separated by newlines.
        """
        formatted_messages = [
            f"\n\n{msg['role'].capitalize()}: {msg['content']}" for msg in messages
        ]
        return "".join(formatted_messages) + "\n\nAssistant:"
def _parse_response(self, response) -> str:
"""
Extracts the generated response from the API response.
Args:
response: The raw response from the AWS Bedrock API.
Returns:
str: The generated response text.
"""
response_body = json.loads(response["body"].read().decode())
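        # "completion" is the key used by Anthropic's legacy text-completions
        # response format; a differently shaped response body falls back to
        # the empty-string default.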
        return response_body.get("completion", "")

    def _prepare_input(
        self,
        provider: str,
        model: str,
        prompt: str,
        model_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Prepares the input dictionary for the specified provider's model.

        Args:
            provider (str): The model provider (e.g., "meta", "ai21", "mistral", "cohere", "amazon").
            model (str): The model identifier.
            prompt (str): The input prompt.
            model_kwargs (Optional[Dict[str, Any]]): Additional model parameters.

        Returns:
            Dict[str, Any]: The prepared input dictionary.
        """
        model_kwargs = model_kwargs or {}
        input_body = {"prompt": prompt, **model_kwargs}
        provider_mappings = {
            "meta": {"max_tokens_to_sample": "max_gen_len"},
            "ai21": {"max_tokens_to_sample": "maxTokens", "top_p": "topP"},
            "mistral": {"max_tokens_to_sample": "max_tokens"},
            "cohere": {"max_tokens_to_sample": "max_tokens", "top_p": "p"},
        }

        if provider in provider_mappings:
            for old_key, new_key in provider_mappings[provider].items():
                if old_key in input_body:
                    input_body[new_key] = input_body.pop(old_key)
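        # e.g. for "cohere": {"prompt": ..., "max_tokens_to_sample": 2000,
        # "top_p": 0.9} becomes {"prompt": ..., "max_tokens": 2000, "p": 0.9};
        # "temperature" passes through unchanged.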
if provider == "cohere" and "cohere.command-r" in model:
input_body["message"] = input_body.pop("prompt")
if provider == "amazon":
input_body = {
"inputText": prompt,
"textGenerationConfig": {
"maxTokenCount": model_kwargs.get("max_tokens_to_sample"),
"topP": model_kwargs.get("top_p"),
"temperature": model_kwargs.get("temperature"),
},
}
input_body["textGenerationConfig"] = {
k: v
for k, v in input_body["textGenerationConfig"].items()
if v is not None
}
return input_body

    def generate_response(self, messages: List[Dict[str, str]]) -> str:
        """
        Generates a response using AWS Bedrock based on the provided messages.

        Args:
            messages (List[Dict[str, str]]): List of message dictionaries containing 'role' and 'content'.

        Returns:
            str: The generated response text.
        """
        prompt = self._format_messages(messages)
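        # Bedrock model IDs take the form "<provider>.<model...>", e.g.
        # "anthropic.claude-3-5-sonnet-20240620-v1:0" -> provider "anthropic".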
        provider = self.config.model.split(".")[0]
        input_body = self._prepare_input(
            provider, self.config.model, prompt, self.model_kwargs
        )
        body = json.dumps(input_body)

        response = self.client.invoke_model(
            body=body,
            modelId=self.config.model,
            accept="application/json",
            contentType="application/json",
        )

        return self._parse_response(response)
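

# A minimal usage sketch, not part of the original module. It assumes valid AWS
# credentials with Bedrock model access in the default region, and that
# BaseLlmConfig accepts these keyword arguments; the values are illustrative.
if __name__ == "__main__":
    config = BaseLlmConfig(
        model="anthropic.claude-3-5-sonnet-20240620-v1:0",
        temperature=0.1,
        max_tokens=256,
        top_p=0.9,
    )
    llm = AWSBedrockLLM(config)
    print(llm.generate_response([{"role": "user", "content": "Say hello."}]))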