Add AWS Bedrock support (#1482)

Author: Dev Khant
Date: 2024-07-18 03:08:10 +05:30
Committed by: GitHub
Parent: 4e5d34103f
Commit: 1e7618dfa4

12 changed files with 454 additions and 19 deletions

mem0/llms/aws_bedrock.py (new file, 196 lines)

@@ -0,0 +1,196 @@
import os
import json
from typing import Dict, List, Optional, Any

import boto3

from mem0.llms.base import LLMBase


class AWSBedrockLLM(LLMBase):
    def __init__(self, model="cohere.command-r-v1:0"):
        self.client = boto3.client(
            "bedrock-runtime",
            region_name=os.environ.get("AWS_REGION"),
            aws_access_key_id=os.environ.get("AWS_ACCESS_KEY"),
            aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY")
        )
        self.model = model

    def _format_messages(self, messages: List[Dict[str, str]]) -> str:
        """
        Formats a list of messages into the required prompt structure for the model.

        Args:
            messages (List[Dict[str, str]]): A list of dictionaries where each dictionary represents a message.
                Each dictionary contains 'role' and 'content' keys.

        Returns:
            str: A formatted string combining all messages, structured with roles capitalized and separated by newlines.
        """
        formatted_messages = []
        for message in messages:
            role = message['role'].capitalize()
            content = message['content']
            formatted_messages.append(f"\n\n{role}: {content}")

        return "".join(formatted_messages) + "\n\nAssistant:"

    def _parse_response(self, response, tools) -> str:
        """
        Process the response based on whether tools are used or not.

        Args:
            response: The raw response from the API.
            tools: The list of tools provided in the request.

        Returns:
            str or dict: The processed response.
        """
        if tools:
            processed_response = {
                "tool_calls": []
            }

            if response["output"]["message"]["content"]:
                for item in response["output"]["message"]["content"]:
                    if "toolUse" in item:
                        processed_response["tool_calls"].append({
                            "name": item["toolUse"]["name"],
                            "arguments": item["toolUse"]["input"]
                        })

            return processed_response

        response_body = json.loads(response['body'].read().decode())
        return response_body.get('completion', '')

    def _prepare_input(
        self,
        provider: str,
        model: str,
        prompt: str,
        model_kwargs: Optional[Dict[str, Any]] = {},
    ) -> Dict[str, Any]:
        """
        Prepares the input dictionary for the specified provider's model by mapping and renaming
        keys in the input based on the provider's requirements.

        Args:
            provider (str): The name of the service provider (e.g., "meta", "ai21", "mistral", "cohere", "amazon").
            model (str): The name or identifier of the model being used.
            prompt (str): The text prompt to be processed by the model.
            model_kwargs (Dict[str, Any]): Additional keyword arguments specific to the model's requirements.

        Returns:
            Dict[str, Any]: The prepared input dictionary with the correct keys and values for the specified provider.
        """
        input_body = {"prompt": prompt, **model_kwargs}

        provider_mappings = {
            "meta": {"max_tokens_to_sample": "max_gen_len"},
            "ai21": {"max_tokens_to_sample": "maxTokens", "top_p": "topP"},
            "mistral": {"max_tokens_to_sample": "max_tokens"},
            "cohere": {"max_tokens_to_sample": "max_tokens", "top_p": "p"},
        }

        if provider in provider_mappings:
            for old_key, new_key in provider_mappings[provider].items():
                if old_key in input_body:
                    input_body[new_key] = input_body.pop(old_key)

        if provider == "cohere" and "cohere.command-r" in model:
            input_body["message"] = input_body.pop("prompt")

        if provider == "amazon":
            input_body = {
                "inputText": prompt,
                "textGenerationConfig": {
                    "maxTokenCount": model_kwargs.get("max_tokens_to_sample"),
                    "topP": model_kwargs.get("top_p"),
                    "temperature": model_kwargs.get("temperature")
                }
            }
            input_body["textGenerationConfig"] = {
                k: v for k, v in input_body["textGenerationConfig"].items() if v is not None
            }

        return input_body

    def _convert_tool_format(self, original_tools):
        """
        Converts a list of tools from their original format to a new standardized format.

        Args:
            original_tools (list): A list of dictionaries representing the original tools, each containing a 'type' key and corresponding details.

        Returns:
            list: A list of dictionaries representing the tools in the new standardized format.
        """
        new_tools = []

        for tool in original_tools:
            if tool['type'] == 'function':
                function = tool['function']
                new_tool = {
                    "toolSpec": {
                        "name": function['name'],
                        "description": function['description'],
                        "inputSchema": {
                            "json": {
                                "type": "object",
                                "properties": {},
                                "required": function['parameters'].get('required', [])
                            }
                        }
                    }
                }

                for prop, details in function['parameters'].get('properties', {}).items():
                    new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = {
                        "type": details.get('type', 'string'),
                        "description": details.get('description', '')
                    }

                new_tools.append(new_tool)

        return new_tools

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
    ):
        """
        Generate a response based on the given messages using AWS Bedrock.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            tools (list, optional): List of tools that the model can call. Defaults to None.
            tool_choice (str, optional): Tool choice method. Defaults to "auto".

        Returns:
            str: The generated response.
        """
        if tools:
            # Use converse method when tools are provided
            messages = [{"role": "user", "content": [{"text": message["content"]} for message in messages]}]
            tools_config = {"tools": self._convert_tool_format(tools)}

            response = self.client.converse(
                modelId=self.model,
                messages=messages,
                toolConfig=tools_config
            )
            print("Tools response: ", response)
        else:
            # Use invoke_model method when no tools are provided
            prompt = self._format_messages(messages)
            provider = self.model.split(".")[0]
            input_body = self._prepare_input(provider, self.model, prompt)
            body = json.dumps(input_body)

            response = self.client.invoke_model(
                body=body,
                modelId=self.model,
                accept='application/json',
                contentType='application/json'
            )

        return self._parse_response(response, tools)
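
For reference, a minimal usage sketch of the new class (not part of the commit). It assumes AWS_REGION, AWS_ACCESS_KEY, and AWS_SECRET_ACCESS_KEY are exported and that the account has been granted access to the chosen Bedrock model; the model id and prompt are placeholders.

from mem0.llms.aws_bedrock import AWSBedrockLLM

llm = AWSBedrockLLM(model="cohere.command-r-v1:0")

# Without tools, the messages are flattened by _format_messages and sent via invoke_model;
# _parse_response then returns the plain completion text.
answer = llm.generate_response(
    messages=[
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Summarize what a vector store does."},
    ]
)
print(answer)

# With tools, generate_response switches to the Bedrock converse API and returns a dict
# of the form {"tool_calls": [{"name": ..., "arguments": {...}}]}.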

mem0/llms/groq.py

@@ -1,3 +1,4 @@
import json
from typing import Dict, List, Optional

from groq import Groq

@@ -10,6 +11,34 @@ class GroqLLM(LLMBase):
        self.client = Groq()
        self.model = model

    def _parse_response(self, response, tools):
        """
        Process the response based on whether tools are used or not.

        Args:
            response: The raw response from the API.
            tools: The list of tools provided in the request.

        Returns:
            str or dict: The processed response.
        """
        if tools:
            processed_response = {
                "content": response.choices[0].message.content,
                "tool_calls": []
            }

            if response.choices[0].message.tool_calls:
                for tool_call in response.choices[0].message.tool_calls:
                    processed_response["tool_calls"].append({
                        "name": tool_call.function.name,
                        "arguments": json.loads(tool_call.function.arguments)
                    })

            return processed_response
        else:
            return response.choices[0].message.content

    def generate_response(
        self,
        messages: List[Dict[str, str]],

@@ -37,4 +66,4 @@ class GroqLLM(LLMBase):
        params["tool_choice"] = tool_choice

        response = self.client.chat.completions.create(**params)
-        return response
+        return self._parse_response(response, tools)
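
The same _parse_response helper is added to the Groq wrapper above and to the OpenAI and Together wrappers below, so every backend now hands Memory the same structure. Roughly, the returned value looks like this (illustrative only; the tool name, argument fields, and values are placeholders):

# When tools are passed:
processed_response = {
    "content": "Okay, I will update that memory.",
    "tool_calls": [
        {
            "name": "update_memory",
            "arguments": {"memory_id": "<id>", "data": "Prefers green tea"},
        }
    ],
}
# Without tools, generate_response returns only the assistant's content string.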

mem0/llms/openai.py

@@ -1,3 +1,4 @@
import json
from typing import Dict, List, Optional

from openai import OpenAI

@@ -9,6 +10,34 @@ class OpenAILLM(LLMBase):
    def __init__(self, model="gpt-4o"):
        self.client = OpenAI()
        self.model = model

    def _parse_response(self, response, tools):
        """
        Process the response based on whether tools are used or not.

        Args:
            response: The raw response from the API.
            tools: The list of tools provided in the request.

        Returns:
            str or dict: The processed response.
        """
        if tools:
            processed_response = {
                "content": response.choices[0].message.content,
                "tool_calls": []
            }

            if response.choices[0].message.tool_calls:
                for tool_call in response.choices[0].message.tool_calls:
                    processed_response["tool_calls"].append({
                        "name": tool_call.function.name,
                        "arguments": json.loads(tool_call.function.arguments)
                    })

            return processed_response
        else:
            return response.choices[0].message.content

    def generate_response(
        self,

@@ -37,4 +66,4 @@ class OpenAILLM(LLMBase):
        params["tool_choice"] = tool_choice

        response = self.client.chat.completions.create(**params)
-        return response
+        return self._parse_response(response, tools)
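
As a concrete illustration of the new path (not part of the commit), calling the updated OpenAI wrapper with a tool routes the raw completion through _parse_response. The tool definition below is a simplified stand-in for the memory tools, and the snippet assumes OPENAI_API_KEY is set:

from mem0.llms.openai import OpenAILLM

llm = OpenAILLM()
result = llm.generate_response(
    messages=[{"role": "user", "content": "Remember that I prefer green tea."}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "add_memory",
                "description": "Store a new memory.",
                "parameters": {
                    "type": "object",
                    "properties": {"data": {"type": "string", "description": "Text to remember"}},
                    "required": ["data"],
                },
            },
        }
    ],
)
# If the model elects to call the tool, result["tool_calls"] holds the parsed name and arguments.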

mem0/llms/together.py

@@ -1,3 +1,4 @@
import json
from typing import Dict, List, Optional

from together import Together

@@ -9,6 +10,34 @@ class TogetherLLM(LLMBase):
    def __init__(self, model="mistralai/Mixtral-8x7B-Instruct-v0.1"):
        self.client = Together()
        self.model = model

    def _parse_response(self, response, tools):
        """
        Process the response based on whether tools are used or not.

        Args:
            response: The raw response from the API.
            tools: The list of tools provided in the request.

        Returns:
            str or dict: The processed response.
        """
        if tools:
            processed_response = {
                "content": response.choices[0].message.content,
                "tool_calls": []
            }

            if response.choices[0].message.tool_calls:
                for tool_call in response.choices[0].message.tool_calls:
                    processed_response["tool_calls"].append({
                        "name": tool_call.function.name,
                        "arguments": json.loads(tool_call.function.arguments)
                    })

            return processed_response
        else:
            return response.choices[0].message.content

    def generate_response(
        self,

@@ -37,4 +66,4 @@ class TogetherLLM(LLMBase):
        params["tool_choice"] = tool_choice

        response = self.client.chat.completions.create(**params)
-        return response
+        return self._parse_response(response, tools)

mem0/memory/main.py

@@ -149,7 +149,6 @@ class Memory(MemoryBase):
                {"role": "user", "content": prompt},
            ]
        )
-        extracted_memories = extracted_memories.choices[0].message.content
        existing_memories = self.vector_store.search(
            name=self.collection_name,
            query=embeddings,

@@ -176,8 +175,7 @@ class Memory(MemoryBase):
        # Add tools for noop, add, update, delete memory.
        tools = [ADD_MEMORY_TOOL, UPDATE_MEMORY_TOOL, DELETE_MEMORY_TOOL]
        response = self.llm.generate_response(messages=messages, tools=tools)
-        response_message = response.choices[0].message
-        tool_calls = response_message.tool_calls
+        tool_calls = response["tool_calls"]

        response = []
        if tool_calls:

@@ -188,9 +186,9 @@ class Memory(MemoryBase):
                "delete_memory": self._delete_memory_tool,
            }
            for tool_call in tool_calls:
-                function_name = tool_call.function.name
+                function_name = tool_call["name"]
                function_to_call = available_functions[function_name]
-                function_args = json.loads(tool_call.function.arguments)
+                function_args = tool_call["arguments"]
                logging.info(
                    f"[openai_func] func: {function_name}, args: {function_args}"
                )
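
Because every wrapper now returns tool calls as plain dicts with the arguments already deserialized (the Bedrock path passes the toolUse input dict through, the OpenAI-style paths json.loads the arguments string), the dispatch loop above no longer touches OpenAI response objects. A standalone sketch of the idea, using a hypothetical helper name:

from typing import Callable, Dict, List


def dispatch_tool_calls(tool_calls: List[Dict], handlers: Dict[str, Callable]) -> list:
    # handlers maps tool names (e.g. "add_memory") to the functions that implement them.
    results = []
    for tool_call in tool_calls:
        handler = handlers[tool_call["name"]]
        # "arguments" is already a dict, so no json.loads is needed here.
        results.append(handler(**tool_call["arguments"]))
    return results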

mem0/utils/factory.py

@@ -12,7 +12,8 @@ class LlmFactory:
        "ollama": "mem0.llms.ollama.OllamaLLM",
        "openai": "mem0.llms.openai.OpenAILLM",
        "groq": "mem0.llms.groq.GroqLLM",
-        "together": "mem0.llms.together.TogetherLLM"
+        "together": "mem0.llms.together.TogetherLLM",
+        "aws_bedrock": "mem0.llms.aws_bedrock.AWSBedrockLLM"
    }

    @classmethod
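
The factory stores dotted class paths, so the new "aws_bedrock" key resolves to the Bedrock wrapper at runtime. The classmethod that performs the lookup is cut off in this diff; a generic sketch of how such a registry entry is typically resolved (illustrative only, not the project's actual factory code):

import importlib


def load_class(dotted_path: str):
    module_path, class_name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), class_name)


llm_cls = load_class("mem0.llms.aws_bedrock.AWSBedrockLLM")
llm = llm_cls()  # defaults to "cohere.command-r-v1:0", per the new class above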