improvement(OSS): Fix AOSS and AWS Bedrock LLM (#2697)
Co-authored-by: Prateek Chhikara <prateekchhikara24@gmail.com>
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
@@ -1,4 +1,6 @@
 import json
+import os
+import re
 from typing import Any, Dict, List, Optional
 
 try:
@@ -9,6 +11,14 @@ except ImportError:
 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase
 
+PROVIDERS = ["ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer"]
+
+
+def extract_provider(model: str) -> str:
+    for provider in PROVIDERS:
+        if re.search(rf"\b{re.escape(provider)}\b", model):
+            return provider
+    raise ValueError(f"Unknown provider in model: {model}")
 
 class AWSBedrockLLM(LLMBase):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
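The word-boundary regex replaces the old self.model.split(".")[0] lookup (removed in the final hunk below), which returned the wrong provider for any model ID that does not begin with the provider name, e.g. cross-region inference profile IDs. A minimal sketch of the difference, reusing the helper as defined above:

import re

PROVIDERS = ["ai21", "amazon", "anthropic", "cohere", "meta", "mistral", "stability", "writer"]

def extract_provider(model: str) -> str:
    # Same logic as the helper added above: match the provider name
    # anywhere in the model ID, not just as its first dotted segment.
    for provider in PROVIDERS:
        if re.search(rf"\b{re.escape(provider)}\b", model):
            return provider
    raise ValueError(f"Unknown provider in model: {model}")

assert extract_provider("anthropic.claude-3-5-sonnet-20240620-v1:0") == "anthropic"
# Cross-region profile ID: split(".")[0] yields "us"; the regex still finds "anthropic".
assert extract_provider("us.anthropic.claude-3-5-sonnet-20240620-v1:0") == "anthropic"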
@@ -16,7 +26,27 @@
         if not self.config.model:
             self.config.model = "anthropic.claude-3-5-sonnet-20240620-v1:0"
-        self.client = boto3.client("bedrock-runtime")
+
+        # Get AWS config from environment variables or use defaults
+        aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", "")
+        aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
+        aws_region = os.environ.get("AWS_REGION", "us-west-2")
+
+        # Check if AWS config is provided in the config
+        if hasattr(self.config, "aws_access_key_id"):
+            aws_access_key = self.config.aws_access_key_id
+        if hasattr(self.config, "aws_secret_access_key"):
+            aws_secret_key = self.config.aws_secret_access_key
+        if hasattr(self.config, "aws_region"):
+            aws_region = self.config.aws_region
+
+        self.client = boto3.client(
+            "bedrock-runtime",
+            region_name=aws_region,
+            aws_access_key_id=aws_access_key if aws_access_key else None,
+            aws_secret_access_key=aws_secret_key if aws_secret_key else None,
+        )
+
 
         self.model_kwargs = {
             "temperature": self.config.temperature,
             "max_tokens_to_sample": self.config.max_tokens,
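The new constructor resolves credentials with a clear precedence: attributes on the config object win over environment variables, and empty strings are passed as None so boto3 can fall back to its default credential chain (shared credentials file, instance profile, and so on). A small stand-alone sketch of that precedence, using a hypothetical stand-in for BaseLlmConfig:

import os

class FakeConfig:  # hypothetical stand-in for BaseLlmConfig
    aws_region = "eu-central-1"

config = FakeConfig()

aws_region = os.environ.get("AWS_REGION", "us-west-2")  # env var or default
if hasattr(config, "aws_region"):  # a config attribute wins when present
    aws_region = config.aws_region

print(aws_region)  # -> "eu-central-1"

Note that hasattr is True even when the attribute is set to None or "", so a config object that merely defines the field always overrides the environment.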
@@ -34,13 +64,14 @@
         Returns:
             str: A formatted string combining all messages, structured with roles capitalized and separated by newlines.
         """
+
         formatted_messages = []
         for message in messages:
             role = message["role"].capitalize()
             content = message["content"]
             formatted_messages.append(f"\n\n{role}: {content}")
 
-        return "".join(formatted_messages) + "\n\nAssistant:"
+        return "\n\nHuman: " + "".join(formatted_messages) + "\n\nAssistant:"
 
     def _parse_response(self, response, tools) -> str:
         """
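Anthropic's legacy text-completions format expects prompts that begin with "\n\nHuman:" and end with "\n\nAssistant:"; the old code only appended the trailing "\n\nAssistant:", producing prompts the format treats as malformed. A sketch that mirrors the patched method:

def format_messages(messages):
    # Mirrors the patched _format_messages: each message becomes
    # "\n\n{Role}: {content}" and the whole prompt is wrapped in the
    # Human/Assistant frame the completions format expects.
    parts = [f"\n\n{m['role'].capitalize()}: {m['content']}" for m in messages]
    return "\n\nHuman: " + "".join(parts) + "\n\nAssistant:"

print(repr(format_messages([{"role": "user", "content": "Hello"}])))
# '\n\nHuman: \n\nUser: Hello\n\nAssistant:'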
@@ -68,8 +99,9 @@
 
             return processed_response
 
-        response_body = json.loads(response["body"].read().decode())
-        return response_body.get("completion", "")
+        response_body = response.get("body").read().decode()
+        response_json = json.loads(response_body)
+        return response_json.get("content", [{"text": ""}])[0].get("text", "")
 
     def _prepare_input(
         self,
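The parser previously assumed the legacy completions response shape ({"completion": "..."}); Anthropic's Messages API on Bedrock returns a list of content blocks instead. A sketch of the new shape, assuming the model answered with a single text block:

import json

# What an Anthropic Messages-API response body looks like (abridged):
raw_body = json.dumps({"content": [{"type": "text", "text": "Hello!"}]})

response_json = json.loads(raw_body)
text = response_json.get("content", [{"text": ""}])[0].get("text", "")
print(text)  # -> "Hello!"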
@@ -113,9 +145,9 @@
         input_body = {
             "inputText": prompt,
             "textGenerationConfig": {
-                "maxTokenCount": model_kwargs.get("max_tokens_to_sample"),
-                "topP": model_kwargs.get("top_p"),
-                "temperature": model_kwargs.get("temperature"),
+                "maxTokenCount": self.model_kwargs["max_tokens_to_sample"] or self.model_kwargs["max_tokens"] or 5000,
+                "topP": self.model_kwargs["top_p"] or 0.9,
+                "temperature": self.model_kwargs["temperature"] or 0.1,
             },
         }
         input_body["textGenerationConfig"] = {
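The `or` chains supply defaults when a value is None, but `x or default` also replaces falsy-but-valid settings: a deliberate temperature of 0.0 silently becomes 0.1. A two-line illustration:

temperature = 0.0          # a valid, deliberate setting
print(temperature or 0.1)  # -> 0.1: the 0.0 is silently overridden

# An explicit None check would preserve zero values:
print(temperature if temperature is not None else 0.1)  # -> 0.0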
@@ -206,15 +238,40 @@
         else:
             # Use invoke_model method when no tools are provided
             prompt = self._format_messages(messages)
-            provider = self.model.split(".")[0]
-            input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs)
+            provider = extract_provider(self.config.model)
+            input_body = self._prepare_input(provider, self.config.model, prompt, model_kwargs=self.model_kwargs)
             body = json.dumps(input_body)
 
-            response = self.client.invoke_model(
-                body=body,
-                modelId=self.model,
-                accept="application/json",
-                contentType="application/json",
-            )
+            if provider == "anthropic" or provider == "deepseek":
+
+                input_body = {
+                    "messages": [
+                        {
+                            "role": "user",
+                            "content": [{"type": "text", "text": prompt}]
+                        }
+                    ],
+                    "max_tokens": self.model_kwargs["max_tokens_to_sample"] or self.model_kwargs["max_tokens"] or 5000,
+                    "temperature": self.model_kwargs["temperature"] or 0.1,
+                    "top_p": self.model_kwargs["top_p"] or 0.9,
+                    "anthropic_version": "bedrock-2023-05-31",
+                }
+
+                body = json.dumps(input_body)
+
+
+                response = self.client.invoke_model(
+                    body=body,
+                    modelId=self.config.model,
+                    accept="application/json",
+                    contentType="application/json",
+                )
+            else:
+                response = self.client.invoke_model(
+                    body=body,
+                    modelId=self.config.model,
+                    accept="application/json",
+                    contentType="application/json",
+                )
 
         return self._parse_response(response, tools)
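End to end, the patched path now handles Anthropic models invoked without tools. A hypothetical usage sketch, assuming the module lives at mem0.llms.aws_bedrock, that LLMBase exposes generate_response, and that AWS credentials and Bedrock model access are configured as described in the __init__ changes above:

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.aws_bedrock import AWSBedrockLLM  # module path assumed

# Credentials come from AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY /
# AWS_REGION, or from boto3's default chain when those are unset.
config = BaseLlmConfig(model="anthropic.claude-3-5-sonnet-20240620-v1:0")
llm = AWSBedrockLLM(config)

answer = llm.generate_response([{"role": "user", "content": "Hello"}])
print(answer)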