Support model config in LLMs (#1495)

This commit is contained in:
Dev Khant
2024-07-18 21:51:40 +05:30
committed by GitHub
parent c411dc294e
commit 40c9abe484
15 changed files with 172 additions and 41 deletions

View File

@@ -5,12 +5,16 @@ from typing import Dict, List, Optional, Any
import boto3
from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig
class AWSBedrockLLM(LLMBase):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config)
class AWSBedrockLLM(LLMBase):
def __init__(self, model="cohere.command-r-v1:0"):
if not self.config.model:
self.config.model="anthropic.claude-3-5-sonnet-20240620-v1:0"
self.client = boto3.client("bedrock-runtime", region_name=os.environ.get("AWS_REGION"), aws_access_key_id=os.environ.get("AWS_ACCESS_KEY"), aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"))
self.model = model
self.model_kwargs = {"temperature": self.config.temperature, "max_tokens_to_sample": self.config.max_tokens, "top_p": self.config.top_p}
def _format_messages(self, messages: List[Dict[str, str]]) -> str:
"""
@@ -171,19 +175,20 @@ class AWSBedrockLLM(LLMBase):
if tools:
# Use converse method when tools are provided
messages = [{"role": "user", "content": [{"text": message["content"]} for message in messages]}]
inference_config = {"temperature": self.model_kwargs["temperature"], "maxTokens": self.model_kwargs["max_tokens_to_sample"], "topP": self.model_kwargs["top_p"]}
tools_config = {"tools": self._convert_tool_format(tools)}
response = self.client.converse(
modelId=self.model,
modelId=self.config.model,
messages=messages,
inferenceConfig=inference_config,
toolConfig=tools_config
)
print("Tools response: ", response)
else:
# Use invoke_model method when no tools are provided
prompt = self._format_messages(messages)
provider = self.model.split(".")[0]
input_body = self._prepare_input(provider, self.model, prompt)
input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs)
body = json.dumps(input_body)
response = self.client.invoke_model(