improvement(OSS): Fix AOSS and AWS BedRock LLM (#2697)

Co-authored-by: Prateek Chhikara <prateekchhikara24@gmail.com>
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
Author: Saket Aryan
Date:   2025-05-16 04:49:29 +05:30
Committed by: GitHub
Parent: 267e5b13ea
Commit: 5c67a5e6bc
14 changed files with 502 additions and 127 deletions


@@ -1,4 +1,6 @@
 import json
+import os
+import re
 from typing import Any, Dict, List, Optional
 
 try:
@@ -9,6 +11,14 @@ except ImportError:
 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase
 
+PROVIDERS = ["ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer"]
+
+def extract_provider(model: str) -> str:
+    for provider in PROVIDERS:
+        if re.search(rf"\b{re.escape(provider)}\b", model):
+            return provider
+    raise ValueError(f"Unknown provider in model: {model}")
+
 
 class AWSBedrockLLM(LLMBase):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
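
Note: the word-boundary regex is what makes this more robust than the old `self.model.split(".")[0]` lookup (removed further down), which mis-detects the provider for model IDs carrying a region prefix. A quick sketch, using a hypothetical cross-region inference profile ID:

    extract_provider("anthropic.claude-3-5-sonnet-20240620-v1:0")     # -> "anthropic"
    extract_provider("us.anthropic.claude-3-5-sonnet-20240620-v1:0")  # -> "anthropic"
    "us.anthropic.claude-3-5-sonnet-20240620-v1:0".split(".")[0]      # -> "us"

One caveat: `deepseek` is checked in generate_response below but is not in PROVIDERS, so extract_provider can never return it as written.
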
@@ -16,7 +26,27 @@ class AWSBedrockLLM(LLMBase):
         if not self.config.model:
             self.config.model = "anthropic.claude-3-5-sonnet-20240620-v1:0"
-        self.client = boto3.client("bedrock-runtime")
+
+        # Get AWS config from environment variables or use defaults
+        aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
+        aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
+        aws_region = os.environ.get("AWS_REGION", "us-west-2")
+
+        # Check if AWS config is provided in the config
+        if hasattr(self.config, "aws_access_key_id"):
+            aws_access_key = self.config.aws_access_key_id
+        if hasattr(self.config, "aws_secret_access_key"):
+            aws_secret_key = self.config.aws_secret_access_key
+        if hasattr(self.config, "aws_region"):
+            aws_region = self.config.aws_region
+
+        self.client = boto3.client(
+            "bedrock-runtime",
+            region_name=aws_region,
+            aws_access_key_id=aws_access_key if aws_access_key else None,
+            aws_secret_access_key=aws_secret_key if aws_secret_key else None,
+        )
+
         self.model_kwargs = {
             "temperature": self.config.temperature,
             "max_tokens_to_sample": self.config.max_tokens,
@@ -34,13 +64,14 @@
         Returns:
             str: A formatted string combining all messages, structured with roles capitalized and separated by newlines.
         """
+
         formatted_messages = []
         for message in messages:
             role = message["role"].capitalize()
             content = message["content"]
             formatted_messages.append(f"\n\n{role}: {content}")
-        return "".join(formatted_messages) + "\n\nAssistant:"
+        return "\n\nHuman: " + "".join(formatted_messages) + "\n\nAssistant:"
 
     def _parse_response(self, response, tools) -> str:
         """
@@ -68,8 +99,9 @@
             return processed_response
 
-        response_body = json.loads(response["body"].read().decode())
-        return response_body.get("completion", "")
+        response_body = response.get("body").read().decode()
+        response_json = json.loads(response_body)
+        return response_json.get("content", [{"text": ""}])[0].get("text", "")
 
     def _prepare_input(
         self,
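
Note: the parsing change follows the Anthropic Messages API response shape, where text lives in a list of content blocks instead of a top-level "completion" field. A self-contained sketch with a canned body:

    import json

    raw = '{"content": [{"type": "text", "text": "Hello!"}]}'
    response_json = json.loads(raw)
    text = response_json.get("content", [{"text": ""}])[0].get("text", "")  # "Hello!"
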
@@ -113,9 +145,9 @@
         input_body = {
             "inputText": prompt,
             "textGenerationConfig": {
-                "maxTokenCount": model_kwargs.get("max_tokens_to_sample"),
-                "topP": model_kwargs.get("top_p"),
-                "temperature": model_kwargs.get("temperature"),
+                "maxTokenCount": self.model_kwargs["max_tokens_to_sample"] or self.model_kwargs["max_tokens"] or 5000,
+                "topP": self.model_kwargs["top_p"] or 0.9,
+                "temperature": self.model_kwargs["temperature"] or 0.1,
             },
         }
         input_body["textGenerationConfig"] = {
@@ -206,15 +238,40 @@
         else:
             # Use invoke_model method when no tools are provided
             prompt = self._format_messages(messages)
-            provider = self.model.split(".")[0]
-            input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs)
+            provider = extract_provider(self.config.model)
+            input_body = self._prepare_input(provider, self.config.model, prompt, model_kwargs=self.model_kwargs)
             body = json.dumps(input_body)
-            response = self.client.invoke_model(
-                body=body,
-                modelId=self.model,
-                accept="application/json",
-                contentType="application/json",
-            )
+
+            if provider == "anthropic" or provider == "deepseek":
+                input_body = {
+                    "messages": [
+                        {
+                            "role": "user",
+                            "content": [{"type": "text", "text": prompt}],
+                        }
+                    ],
+                    "max_tokens": self.model_kwargs["max_tokens_to_sample"] or self.model_kwargs["max_tokens"] or 5000,
+                    "temperature": self.model_kwargs["temperature"] or 0.1,
+                    "top_p": self.model_kwargs["top_p"] or 0.9,
+                    "anthropic_version": "bedrock-2023-05-31",
+                }
+                body = json.dumps(input_body)
+                response = self.client.invoke_model(
+                    body=body,
+                    modelId=self.config.model,
+                    accept="application/json",
+                    contentType="application/json",
+                )
+            else:
+                response = self.client.invoke_model(
+                    body=body,
+                    modelId=self.config.model,
+                    accept="application/json",
+                    contentType="application/json",
+                )
         return self._parse_response(response, tools)
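
For reference, a hypothetical end-to-end call after this change (the class and config types appear in the diff; the generate_response entry point and module path are assumed from the surrounding mem0 LLM interface and may differ):

    from mem0.configs.llms.base import BaseLlmConfig
    from mem0.llms.aws_bedrock import AWSBedrockLLM  # assumed module path

    config = BaseLlmConfig(model="anthropic.claude-3-5-sonnet-20240620-v1:0")
    llm = AWSBedrockLLM(config)
    # Takes the anthropic branch above: a Messages-API body is built,
    # invoke_model is called, and _parse_response extracts the first text block.
    print(llm.generate_response([{"role": "user", "content": "Hello"}]))
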