Reverting the tools commit (#2404)

Author: Parshva Daftari
Date:   2025-03-20 00:09:00 +05:30 (committed by GitHub)
Parent: 1aed611539
Commit: ee66e0c954

21 changed files with 990 additions and 475 deletions


@@ -4,26 +4,14 @@ from typing import Any, Dict, List, Optional

 try:
     import boto3
 except ImportError:
-    raise ImportError(
-        "The 'boto3' library is required. Please install it using 'pip install boto3'."
-    )
+    raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.")

 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase


 class AWSBedrockLLM(LLMBase):
-    """
-    A wrapper for AWS Bedrock's language models, integrating them with the LLMBase class.
-    """
-
     def __init__(self, config: Optional[BaseLlmConfig] = None):
-        """
-        Initializes the AWS Bedrock LLM with the provided configuration.
-
-        Args:
-            config (Optional[BaseLlmConfig]): Configuration object for the model.
-        """
         super().__init__(config)

         if not self.config.model:
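For orientation between hunks: a minimal construction sketch for the restored class. The BaseLlmConfig field names below are assumptions inferred from the attributes this diff reads (config.model here, plus the model_kwargs keys used in generate_response further down); they are not spelled out in this commit.

# Construction sketch, not part of the diff; the config fields are assumed.
from mem0.configs.llms.base import BaseLlmConfig

config = BaseLlmConfig(
    model="anthropic.claude-v2",  # hypothetical Bedrock model id; the provider is the prefix before the first "."
    temperature=0.1,
    max_tokens=2000,
    top_p=0.9,
)
# How these map onto self.model_kwargs (temperature, max_tokens_to_sample,
# top_p below) is handled by the base class and is assumed here.
llm = AWSBedrockLLM(config)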
@@ -37,29 +25,49 @@ class AWSBedrockLLM(LLMBase):
     def _format_messages(self, messages: List[Dict[str, str]]) -> str:
         """
-        Formats a list of messages into a structured prompt for the model.
+        Formats a list of messages into the required prompt structure for the model.

         Args:
-            messages (List[Dict[str, str]]): A list of dictionaries containing 'role' and 'content'.
+            messages (List[Dict[str, str]]): A list of dictionaries where each dictionary represents a message.
+                Each dictionary contains 'role' and 'content' keys.

         Returns:
             str: A formatted string combining all messages, structured with roles capitalized and separated by newlines.
         """
-        formatted_messages = [
-            f"\n\n{msg['role'].capitalize()}: {msg['content']}" for msg in messages
-        ]
+        formatted_messages = []
+        for message in messages:
+            role = message["role"].capitalize()
+            content = message["content"]
+            formatted_messages.append(f"\n\n{role}: {content}")
+
         return "".join(formatted_messages) + "\n\nAssistant:"

-    def _parse_response(self, response) -> str:
+    def _parse_response(self, response, tools) -> str:
         """
-        Extracts the generated response from the API response.
+        Process the response based on whether tools are used or not.

         Args:
-            response: The raw response from the AWS Bedrock API.
+            response: The raw response from API.
+            tools: The list of tools provided in the request.

         Returns:
-            str: The generated response text.
+            str or dict: The processed response.
         """
+        if tools:
+            processed_response = {"tool_calls": []}
+
+            if response["output"]["message"]["content"]:
+                for item in response["output"]["message"]["content"]:
+                    if "toolUse" in item:
+                        processed_response["tool_calls"].append(
+                            {
+                                "name": item["toolUse"]["name"],
+                                "arguments": item["toolUse"]["input"],
+                            }
+                        )
+
+            return processed_response
+
         response_body = json.loads(response["body"].read().decode())
         return response_body.get("completion", "")
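To make the restored behavior concrete, here is what the two methods above produce for a hypothetical input; everything below is derived from reading the code in this hunk, not taken from the commit itself.

# Illustration only, not part of the diff; values are hypothetical.
messages = [{"role": "user", "content": "What is the capital of France?"}]

# _format_messages(messages) returns:
# "\n\nUser: What is the capital of France?\n\nAssistant:"

# When tools were passed, _parse_response expects a Converse-style payload
# and reduces any toolUse blocks to a {"tool_calls": [...]} dict
# (the get_weather tool is hypothetical):
response = {
    "output": {
        "message": {
            "content": [{"toolUse": {"name": "get_weather", "input": {"city": "Paris"}}}]
        }
    }
}
# _parse_response(response, tools=[...]) returns:
# {"tool_calls": [{"name": "get_weather", "arguments": {"city": "Paris"}}]}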
@@ -68,21 +76,22 @@ class AWSBedrockLLM(LLMBase):
         provider: str,
         model: str,
         prompt: str,
-        model_kwargs: Optional[Dict[str, Any]] = None,
+        model_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Dict[str, Any]:
         """
-        Prepares the input dictionary for the specified provider's model.
+        Prepares the input dictionary for the specified provider's model by mapping and renaming
+        keys in the input based on the provider's requirements.

         Args:
-            provider (str): The model provider (e.g., "meta", "ai21", "mistral", "cohere", "amazon").
-            model (str): The model identifier.
-            prompt (str): The input prompt.
-            model_kwargs (Optional[Dict[str, Any]]): Additional model parameters.
+            provider (str): The name of the service provider (e.g., "meta", "ai21", "mistral", "cohere", "amazon").
+            model (str): The name or identifier of the model being used.
+            prompt (str): The text prompt to be processed by the model.
+            model_kwargs (Dict[str, Any]): Additional keyword arguments specific to the model's requirements.

         Returns:
-            Dict[str, Any]: The prepared input dictionary.
+            Dict[str, Any]: The prepared input dictionary with the correct keys and values for the specified provider.
         """
         model_kwargs = model_kwargs or {}
         input_body = {"prompt": prompt, **model_kwargs}

         provider_mappings = {
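The provider_mappings table itself is truncated in this diff, so the sketch below only shows the shape of a _prepare_input call; the provider and model values are hypothetical.

# Shape of the non-tool input path (illustrative, not part of the diff).
body = llm._prepare_input(
    provider="anthropic",         # hypothetical; taken from the model id prefix
    model="anthropic.claude-v2",  # hypothetical
    prompt="\n\nUser: Hello\n\nAssistant:",
    model_kwargs={"temperature": 0.1, "max_tokens_to_sample": 2000},
)
# Starts from {"prompt": prompt, **model_kwargs}, then renames keys per
# provider; for "amazon" models the kwargs end up nested under
# "textGenerationConfig", whose None values the next hunk filters out.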
@@ -110,35 +119,102 @@ class AWSBedrockLLM(LLMBase):
             },
         }

         input_body["textGenerationConfig"] = {
-            k: v
-            for k, v in input_body["textGenerationConfig"].items()
-            if v is not None
+            k: v for k, v in input_body["textGenerationConfig"].items() if v is not None
         }

         return input_body

-    def generate_response(self, messages: List[Dict[str, str]]) -> str:
+    def _convert_tool_format(self, original_tools):
         """
-        Generates a response using AWS Bedrock based on the provided messages.
+        Converts a list of tools from their original format to a new standardized format.

         Args:
-            messages (List[Dict[str, str]]): List of message dictionaries containing 'role' and 'content'.
+            original_tools (list): A list of dictionaries representing the original tools, each containing a 'type' key and corresponding details.

         Returns:
-            str: The generated response text.
+            list: A list of dictionaries representing the tools in the new standardized format.
         """
-        prompt = self._format_messages(messages)
-        provider = self.config.model.split(".")[0]
-        input_body = self._prepare_input(
-            provider, self.config.model, prompt, self.model_kwargs
-        )
-        body = json.dumps(input_body)
-
-        response = self.client.invoke_model(
-            body=body,
-            modelId=self.config.model,
-            accept="application/json",
-            contentType="application/json",
-        )
-
-        return self._parse_response(response)
+        new_tools = []
+
+        for tool in original_tools:
+            if tool["type"] == "function":
+                function = tool["function"]
+                new_tool = {
+                    "toolSpec": {
+                        "name": function["name"],
+                        "description": function["description"],
+                        "inputSchema": {
+                            "json": {
+                                "type": "object",
+                                "properties": {},
+                                "required": function["parameters"].get("required", []),
+                            }
+                        },
+                    }
+                }
+
+                for prop, details in function["parameters"].get("properties", {}).items():
+                    new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = {
+                        "type": details.get("type", "string"),
+                        "description": details.get("description", ""),
+                    }
+
+                new_tools.append(new_tool)
+
+        return new_tools
+
+    def generate_response(
+        self,
+        messages: List[Dict[str, str]],
+        response_format=None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
+    ):
+        """
+        Generate a response based on the given messages using AWS Bedrock.
+
+        Args:
+            messages (list): List of message dicts containing 'role' and 'content'.
+            tools (list, optional): List of tools that the model can call. Defaults to None.
+            tool_choice (str, optional): Tool choice method. Defaults to "auto".
+
+        Returns:
+            str: The generated response.
+        """
+        if tools:
+            # Use converse method when tools are provided
+            messages = [
+                {
+                    "role": "user",
+                    "content": [{"text": message["content"]} for message in messages],
+                }
+            ]
+            inference_config = {
+                "temperature": self.model_kwargs["temperature"],
+                "maxTokens": self.model_kwargs["max_tokens_to_sample"],
+                "topP": self.model_kwargs["top_p"],
+            }
+            tools_config = {"tools": self._convert_tool_format(tools)}
+
+            response = self.client.converse(
+                modelId=self.config.model,
+                messages=messages,
+                inferenceConfig=inference_config,
+                toolConfig=tools_config,
+            )
+        else:
+            # Use invoke_model method when no tools are provided
+            prompt = self._format_messages(messages)
+            provider = self.model.split(".")[0]
+            input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs)
+            body = json.dumps(input_body)
+
+            response = self.client.invoke_model(
+                body=body,
+                modelId=self.model,
+                accept="application/json",
+                contentType="application/json",
+            )
+
+        return self._parse_response(response, tools)
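Finally, a usage sketch of the restored tool-calling path. The tool definition below follows the OpenAI-style function schema that _convert_tool_format expects; the tool itself is hypothetical.

# Hypothetical tool in the format _convert_tool_format consumes.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {"type": "string", "description": "City name"}
                },
                "required": ["city"],
            },
        },
    }
]

# With tools set, generate_response routes through client.converse() and
# _parse_response returns the collected tool calls:
result = llm.generate_response(
    messages=[{"role": "user", "content": "Weather in Paris?"}],
    tools=tools,
)
# result == {"tool_calls": [{"name": "get_weather", "arguments": {"city": "Paris"}}]}

Before the request is sent, each entry is rewritten into Bedrock's toolSpec/inputSchema layout shown in _convert_tool_format above and passed as toolConfig.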