Support model config in LLMs (#1495)
@@ -5,12 +5,16 @@ from typing import Dict, List, Optional, Any

 import boto3

 from mem0.llms.base import LLMBase
+from mem0.configs.llms.base import BaseLlmConfig


 class AWSBedrockLLM(LLMBase):
-    def __init__(self, model="cohere.command-r-v1:0"):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+
+        if not self.config.model:
+            self.config.model="anthropic.claude-3-5-sonnet-20240620-v1:0"
         self.client = boto3.client("bedrock-runtime", region_name=os.environ.get("AWS_REGION"), aws_access_key_id=os.environ.get("AWS_ACCESS_KEY"), aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"))
-        self.model = model
+        self.model_kwargs = {"temperature": self.config.temperature, "max_tokens_to_sample": self.config.max_tokens, "top_p": self.config.top_p}

     def _format_messages(self, messages: List[Dict[str, str]]) -> str:
         """
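For illustration only (not part of the diff): a minimal usage sketch of the new Bedrock constructor. It assumes BaseLlmConfig accepts the fields read above (model, temperature, max_tokens, top_p) as keyword arguments and that the class lives at mem0.llms.aws_bedrock; neither is shown in this hunk.

    # Sketch: constructing the wrapper from a config object instead of a bare
    # model string. Requires boto3 and AWS credentials in the environment,
    # exactly as the client call in the diff does.
    from mem0.configs.llms.base import BaseLlmConfig
    from mem0.llms.aws_bedrock import AWSBedrockLLM  # module path assumed

    config = BaseLlmConfig(
        model="anthropic.claude-3-5-sonnet-20240620-v1:0",
        temperature=0.1,
        max_tokens=2000,
        top_p=0.9,
    )
    llm = AWSBedrockLLM(config)  # with config.model unset, it falls back to the Claude default above
    print(llm.model_kwargs)      # {"temperature": 0.1, "max_tokens_to_sample": 2000, "top_p": 0.9}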
@@ -171,19 +175,20 @@ class AWSBedrockLLM(LLMBase):
         if tools:
             # Use converse method when tools are provided
             messages = [{"role": "user", "content": [{"text": message["content"]} for message in messages]}]
             inference_config = {"temperature": self.model_kwargs["temperature"], "maxTokens": self.model_kwargs["max_tokens_to_sample"], "topP": self.model_kwargs["top_p"]}
             tools_config = {"tools": self._convert_tool_format(tools)}

             response = self.client.converse(
-                modelId=self.model,
+                modelId=self.config.model,
                 messages=messages,
                 inferenceConfig=inference_config,
                 toolConfig=tools_config
             )
             print("Tools response: ", response)
         else:
             # Use invoke_model method when no tools are provided
             prompt = self._format_messages(messages)
             provider = self.model.split(".")[0]
-            input_body = self._prepare_input(provider, self.model, prompt)
+            input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs)
             body = json.dumps(input_body)

             response = self.client.invoke_model(
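For illustration only: the shape of a call that would take the converse() branch above versus the invoke_model() branch. The generate_response keyword signature and the OpenAI-style tool schema consumed by _convert_tool_format are assumptions, not shown in this hunk.

    # Sketch, continuing from the constructor example above.
    from mem0.configs.llms.base import BaseLlmConfig
    from mem0.llms.aws_bedrock import AWSBedrockLLM  # module path assumed

    llm = AWSBedrockLLM(BaseLlmConfig())
    messages = [{"role": "user", "content": "What's the weather in Paris?"}]
    tools = [{
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool, for illustration
            "description": "Look up current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }]
    llm.generate_response(messages, tools=tools)  # tools -> converse(); omit tools -> invoke_model()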
@@ -1,7 +1,21 @@
+from typing import Optional
 from abc import ABC, abstractmethod

+from mem0.configs.llms.base import BaseLlmConfig
+

 class LLMBase(ABC):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        """Initialize a base LLM class
+
+        :param config: LLM configuration option class, defaults to None
+        :type config: Optional[BaseLlmConfig], optional
+        """
+        if config is None:
+            self.config = BaseLlmConfig()
+        else:
+            self.config = config
+
     @abstractmethod
     def generate_response(self, messages):
         """
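The base class now owns the config object, so every provider gets the same fallback behaviour for free. The diff does not include mem0/configs/llms/base.py itself; below is a hypothetical sketch of BaseLlmConfig, inferred only from the attributes the provider classes read (model, temperature, max_tokens, top_p) and from the no-argument BaseLlmConfig() call above. The real class and its defaults may differ.

    # Hypothetical sketch; defaults are placeholders, not taken from the repo.
    from typing import Optional

    class BaseLlmConfig:
        def __init__(
            self,
            model: Optional[str] = None,   # None lets each provider pick its own default
            temperature: float = 0.0,
            max_tokens: int = 3000,
            top_p: float = 1.0,
        ):
            self.model = model
            self.temperature = temperature
            self.max_tokens = max_tokens
            self.top_p = top_p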
@@ -8,7 +8,7 @@ class LlmConfig(BaseModel):
         description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai"
     )
     config: Optional[dict] = Field(
-        description="Configuration for the specific LLM", default=None
+        description="Configuration for the specific LLM", default={}
     )

     @field_validator("config")
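For illustration only: the kind of dict this LlmConfig model is meant to validate, with key names taken from the fields in this hunk. Changing the default from None to {} presumably lets downstream code unpack the config (e.g., BaseLlmConfig(**llm_settings["config"])) without a None guard; the surrounding factory code is not shown here.

    # Sketch of a provider/config pair as validated by LlmConfig.
    llm_settings = {
        "provider": "openai",
        "config": {
            "model": "gpt-4o",
            "temperature": 0.2,
            "max_tokens": 1500,
        },
    }
    # With default={}, an empty "config" still unpacks cleanly:
    # BaseLlmConfig(**llm_settings.get("config", {}))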
@@ -4,12 +4,16 @@ from typing import Dict, List, Optional
 from groq import Groq

 from mem0.llms.base import LLMBase
+from mem0.configs.llms.base import BaseLlmConfig


 class GroqLLM(LLMBase):
-    def __init__(self, model="llama3-70b-8192"):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+
+        if not self.config.model:
+            self.config.model="llama3-70b-8192"
         self.client = Groq()
-        self.model = model

     def _parse_response(self, response, tools):
         """
@@ -58,7 +62,13 @@ class GroqLLM(LLMBase):
         Returns:
             str: The generated response.
         """
-        params = {"model": self.model, "messages": messages}
+        params = {
+            "model": self.config.model,
+            "messages": messages,
+            "temperature": self.config.temperature,
+            "max_tokens": self.config.max_tokens,
+            "top_p": self.config.top_p
+        }
         if response_format:
             params["response_format"] = response_format
         if tools:
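For illustration only: with the params dict now built from the config, sampling settings flow from BaseLlmConfig rather than being hard-coded. The sketch assumes GROQ_API_KEY is set, that the module lives at mem0.llms.groq, and that the elided remainder of generate_response forwards params to Groq's chat-completions client.

    # Sketch: explicit sampling settings, model falls back to "llama3-70b-8192".
    from mem0.configs.llms.base import BaseLlmConfig
    from mem0.llms.groq import GroqLLM  # module path assumed

    llm = GroqLLM(BaseLlmConfig(temperature=0.7, max_tokens=512))
    answer = llm.generate_response([{"role": "user", "content": "Say hi."}])
    print(answer)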
@@ -4,11 +4,15 @@ from typing import Dict, List, Optional
 import litellm

 from mem0.llms.base import LLMBase
+from mem0.configs.llms.base import BaseLlmConfig


 class LiteLLM(LLMBase):
-    def __init__(self, model="gpt-4o"):
-        self.model = model
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+
+        if not self.config.model:
+            self.config.model="gpt-4o"

     def _parse_response(self, response, tools):
         """
@@ -57,10 +61,16 @@ class LiteLLM(LLMBase):
         Returns:
             str: The generated response.
         """
-        if not litellm.supports_function_calling(self.model):
-            raise ValueError(f"Model '{self.model}' in litellm does not support function calling.")
+        if not litellm.supports_function_calling(self.config.model):
+            raise ValueError(f"Model '{self.config.model}' in litellm does not support function calling.")

-        params = {"model": self.model, "messages": messages}
+        params = {
+            "model": self.config.model,
+            "messages": messages,
+            "temperature": self.config.temperature,
+            "max_tokens": self.config.max_tokens,
+            "top_p": self.config.top_p
+        }
         if response_format:
             params["response_format"] = response_format
         if tools:
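For illustration only: the function-calling guard is now keyed off config.model. Models that litellm does not list as tool-capable will make generate_response raise the ValueError shown above. The module path below is assumed.

    # Sketch: perform the same capability check the method performs.
    import litellm

    from mem0.configs.llms.base import BaseLlmConfig
    from mem0.llms.litellm import LiteLLM  # module path assumed

    llm = LiteLLM(BaseLlmConfig(model="gpt-4o"))
    if litellm.supports_function_calling(llm.config.model):
        print("tool calling available for", llm.config.model)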
@@ -4,12 +4,15 @@ from typing import Dict, List, Optional
 from openai import OpenAI

 from mem0.llms.base import LLMBase
+from mem0.configs.llms.base import BaseLlmConfig


 class OpenAILLM(LLMBase):
-    def __init__(self, model="gpt-4o"):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+
+        if not self.config.model:
+            self.config.model="gpt-4o"
         self.client = OpenAI()
-        self.model = model

     def _parse_response(self, response, tools):
         """
@@ -58,7 +61,13 @@ class OpenAILLM(LLMBase):
         Returns:
             str: The generated response.
         """
-        params = {"model": self.model, "messages": messages}
+        params = {
+            "model": self.config.model,
+            "messages": messages,
+            "temperature": self.config.temperature,
+            "max_tokens": self.config.max_tokens,
+            "top_p": self.config.top_p
+        }
        if response_format:
             params["response_format"] = response_format
         if tools:
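For illustration only: response_format is still passed straight through to params, so JSON-mode output can be requested per call alongside the config-driven settings. The sketch assumes generate_response accepts response_format as a keyword argument (as the `if response_format:` check suggests), that the elided remainder forwards params to self.client.chat.completions.create, and that OPENAI_API_KEY is set.

    # Sketch: config-driven model/sampling plus per-call JSON mode.
    from mem0.configs.llms.base import BaseLlmConfig
    from mem0.llms.openai import OpenAILLM  # module path assumed

    llm = OpenAILLM(BaseLlmConfig(model="gpt-4o", temperature=0))
    reply = llm.generate_response(
        [{"role": "user", "content": 'Return {"ok": true} as JSON.'}],
        response_format={"type": "json_object"},
    )
    print(reply)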
@@ -4,12 +4,15 @@ from typing import Dict, List, Optional
 from together import Together

 from mem0.llms.base import LLMBase
+from mem0.configs.llms.base import BaseLlmConfig


 class TogetherLLM(LLMBase):
-    def __init__(self, model="mistralai/Mixtral-8x7B-Instruct-v0.1"):
+    def __init__(self, config: Optional[BaseLlmConfig] = None):
+        super().__init__(config)
+
+        if not self.config.model:
+            self.config.model="mistralai/Mixtral-8x7B-Instruct-v0.1"
         self.client = Together()
-        self.model = model

     def _parse_response(self, response, tools):
         """
@@ -58,7 +61,13 @@ class TogetherLLM(LLMBase):
         Returns:
             str: The generated response.
         """
-        params = {"model": self.model, "messages": messages}
+        params = {
+            "model": self.config.model,
+            "messages": messages,
+            "temperature": self.config.temperature,
+            "max_tokens": self.config.max_tokens,
+            "top_p": self.config.top_p
+        }
         if response_format:
             params["response_format"] = response_format
         if tools:
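For illustration only: all of the providers touched here now share the same default-model fallback, so constructing one without a config (or with config.model unset) lands on its hard-coded default. The sketch assumes the module path mem0.llms.together and that the Together client can be instantiated in the current environment (the SDK typically expects TOGETHER_API_KEY).

    # Sketch: omitting the config exercises the LLMBase fallback plus the
    # provider's own default model.
    from mem0.llms.together import TogetherLLM  # module path assumed

    llm = TogetherLLM()  # LLMBase substitutes BaseLlmConfig()
    assert llm.config.model == "mistralai/Mixtral-8x7B-Instruct-v0.1"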