Remove tools from LLMs (#2363)
@@ -4,14 +4,26 @@ from typing import Dict, List, Optional
try:
import anthropic
except ImportError:
raise ImportError("The 'anthropic' library is required. Please install it using 'pip install anthropic'.")
raise ImportError(
"The 'anthropic' library is required. Please install it using 'pip install anthropic'."
)

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase


class AnthropicLLM(LLMBase):
"""
A class for interacting with Anthropic's Claude models using the specified configuration.
"""

def __init__(self, config: Optional[BaseLlmConfig] = None):
"""
Initializes the AnthropicLLM instance with the given configuration.

Args:
config (Optional[BaseLlmConfig]): Configuration settings for the language model.
"""
super().__init__(config)

if not self.config.model:
@@ -23,23 +35,17 @@ class AnthropicLLM(LLMBase):
def generate_response(
self,
messages: List[Dict[str, str]],
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
) -> str:
"""
Generate a response based on the given messages using Anthropic.
Generates a response using Anthropic's Claude model based on the provided messages.

Args:
messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format of the response. Defaults to "text".
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.

Returns:
str: The generated response.
str: The generated response from the model.
"""
# Separate system message from other messages
# Extract system message separately
system_message = ""
filtered_messages = []
for message in messages:
@@ -56,9 +62,6 @@ class AnthropicLLM(LLMBase):
"max_tokens": self.config.max_tokens,
"top_p": self.config.top_p,
}
if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools
params["tool_choice"] = tool_choice

response = self.client.messages.create(**params)
return response.content[0].text
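For orientation, here is a minimal usage sketch of the Anthropic wrapper after this change. The import path, model name, and config values are assumptions (they are not shown in this diff); the configuration fields mirror those used in the tests further below.

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.anthropic import AnthropicLLM  # import path assumed

# Model name and sampling values are illustrative, not taken from the diff.
config = BaseLlmConfig(model="claude-3-5-sonnet-20240620", temperature=0.1, max_tokens=2000, top_p=0.1)
llm = AnthropicLLM(config)

# generate_response now always returns the plain text of the first content block
# (response.content[0].text); tools/tool_choice are still accepted but flagged with a TODO for removal.
reply = llm.generate_response(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]
)
print(reply)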
@@ -4,14 +4,26 @@ from typing import Any, Dict, List, Optional
try:
import boto3
except ImportError:
raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.")
raise ImportError(
"The 'boto3' library is required. Please install it using 'pip install boto3'."
)

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase


class AWSBedrockLLM(LLMBase):
"""
A wrapper for AWS Bedrock's language models, integrating them with the LLMBase class.
"""

def __init__(self, config: Optional[BaseLlmConfig] = None):
"""
Initializes the AWS Bedrock LLM with the provided configuration.

Args:
config (Optional[BaseLlmConfig]): Configuration object for the model.
"""
super().__init__(config)

if not self.config.model:
@@ -25,49 +37,29 @@ class AWSBedrockLLM(LLMBase):

def _format_messages(self, messages: List[Dict[str, str]]) -> str:
"""
Formats a list of messages into the required prompt structure for the model.
Formats a list of messages into a structured prompt for the model.

Args:
messages (List[Dict[str, str]]): A list of dictionaries where each dictionary represents a message.
Each dictionary contains 'role' and 'content' keys.
messages (List[Dict[str, str]]): A list of dictionaries containing 'role' and 'content'.

Returns:
str: A formatted string combining all messages, structured with roles capitalized and separated by newlines.
"""
formatted_messages = []
for message in messages:
role = message["role"].capitalize()
content = message["content"]
formatted_messages.append(f"\n\n{role}: {content}")

formatted_messages = [
f"\n\n{msg['role'].capitalize()}: {msg['content']}" for msg in messages
]
return "".join(formatted_messages) + "\n\nAssistant:"

def _parse_response(self, response, tools) -> str:
def _parse_response(self, response) -> str:
"""
Process the response based on whether tools are used or not.
Extracts the generated response from the API response.

Args:
response: The raw response from API.
tools: The list of tools provided in the request.
response: The raw response from the AWS Bedrock API.

Returns:
str or dict: The processed response.
str: The generated response text.
"""
if tools:
processed_response = {"tool_calls": []}

if response["output"]["message"]["content"]:
for item in response["output"]["message"]["content"]:
if "toolUse" in item:
processed_response["tool_calls"].append(
{
"name": item["toolUse"]["name"],
"arguments": item["toolUse"]["input"],
}
)

return processed_response

response_body = json.loads(response["body"].read().decode())
return response_body.get("completion", "")
@@ -76,22 +68,21 @@ class AWSBedrockLLM(LLMBase):
provider: str,
model: str,
prompt: str,
model_kwargs: Optional[Dict[str, Any]] = {},
model_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""
Prepares the input dictionary for the specified provider's model by mapping and renaming
keys in the input based on the provider's requirements.
Prepares the input dictionary for the specified provider's model.

Args:
provider (str): The name of the service provider (e.g., "meta", "ai21", "mistral", "cohere", "amazon").
model (str): The name or identifier of the model being used.
prompt (str): The text prompt to be processed by the model.
model_kwargs (Dict[str, Any]): Additional keyword arguments specific to the model's requirements.
provider (str): The model provider (e.g., "meta", "ai21", "mistral", "cohere", "amazon").
model (str): The model identifier.
prompt (str): The input prompt.
model_kwargs (Optional[Dict[str, Any]]): Additional model parameters.

Returns:
Dict[str, Any]: The prepared input dictionary with the correct keys and values for the specified provider.
Dict[str, Any]: The prepared input dictionary.
"""

model_kwargs = model_kwargs or {}
input_body = {"prompt": prompt, **model_kwargs}

provider_mappings = {
@@ -119,102 +110,35 @@ class AWSBedrockLLM(LLMBase):
},
}
input_body["textGenerationConfig"] = {
k: v for k, v in input_body["textGenerationConfig"].items() if v is not None
k: v
for k, v in input_body["textGenerationConfig"].items()
if v is not None
}

return input_body

def _convert_tool_format(self, original_tools):
def generate_response(self, messages: List[Dict[str, str]]) -> str:
"""
Converts a list of tools from their original format to a new standardized format.
Generates a response using AWS Bedrock based on the provided messages.

Args:
original_tools (list): A list of dictionaries representing the original tools, each containing a 'type' key and corresponding details.
messages (List[Dict[str, str]]): List of message dictionaries containing 'role' and 'content'.

Returns:
list: A list of dictionaries representing the tools in the new standardized format.
str: The generated response text.
"""
new_tools = []
prompt = self._format_messages(messages)
provider = self.config.model.split(".")[0]
input_body = self._prepare_input(
provider, self.config.model, prompt, self.model_kwargs
)
body = json.dumps(input_body)

for tool in original_tools:
if tool["type"] == "function":
function = tool["function"]
new_tool = {
"toolSpec": {
"name": function["name"],
"description": function["description"],
"inputSchema": {
"json": {
"type": "object",
"properties": {},
"required": function["parameters"].get("required", []),
}
},
}
}
response = self.client.invoke_model(
body=body,
modelId=self.config.model,
accept="application/json",
contentType="application/json",
)

for prop, details in function["parameters"].get("properties", {}).items():
new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = {
"type": details.get("type", "string"),
"description": details.get("description", ""),
}

new_tools.append(new_tool)

return new_tools

def generate_response(
self,
messages: List[Dict[str, str]],
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
"""
Generate a response based on the given messages using AWS Bedrock.

Args:
messages (list): List of message dicts containing 'role' and 'content'.
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".

Returns:
str: The generated response.
"""

if tools:
# Use converse method when tools are provided
messages = [
{
"role": "user",
"content": [{"text": message["content"]} for message in messages],
}
]
inference_config = {
"temperature": self.model_kwargs["temperature"],
"maxTokens": self.model_kwargs["max_tokens_to_sample"],
"topP": self.model_kwargs["top_p"],
}
tools_config = {"tools": self._convert_tool_format(tools)}

response = self.client.converse(
modelId=self.config.model,
messages=messages,
inferenceConfig=inference_config,
toolConfig=tools_config,
)
else:
# Use invoke_model method when no tools are provided
prompt = self._format_messages(messages)
provider = self.model.split(".")[0]
input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs)
body = json.dumps(input_body)

response = self.client.invoke_model(
body=body,
modelId=self.model,
accept="application/json",
contentType="application/json",
)

return self._parse_response(response, tools)
return self._parse_response(response)
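As a standalone illustration of the prompt string the rewritten _format_messages builds (this mirrors the list comprehension in the hunk above and is not part of the diff; the sample messages are illustrative):

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi"},
]
# Same expression as in _format_messages: capitalized role, double-newline separators,
# and a trailing "Assistant:" turn for the model to complete.
prompt = "".join(f"\n\n{msg['role'].capitalize()}: {msg['content']}" for msg in messages) + "\n\nAssistant:"
assert prompt == "\n\nSystem: You are a helpful assistant.\n\nUser: Hi\n\nAssistant:"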
@@ -9,17 +9,35 @@ from mem0.llms.base import LLMBase


class AzureOpenAILLM(LLMBase):
"""
A class for interacting with Azure OpenAI models using the specified configuration.
"""

def __init__(self, config: Optional[BaseLlmConfig] = None):
"""
Initializes the AzureOpenAILLM instance with the given configuration.

Args:
config (Optional[BaseLlmConfig]): Configuration settings for the language model.
"""
super().__init__(config)

# Model name should match the custom deployment name chosen for it.
# Ensure model name is set; it should match the Azure OpenAI deployment name.
if not self.config.model:
self.config.model = "gpt-4o"

api_key = self.config.azure_kwargs.api_key or os.getenv("LLM_AZURE_OPENAI_API_KEY")
azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv("LLM_AZURE_DEPLOYMENT")
azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("LLM_AZURE_ENDPOINT")
api_version = self.config.azure_kwargs.api_version or os.getenv("LLM_AZURE_API_VERSION")
api_key = self.config.azure_kwargs.api_key or os.getenv(
"LLM_AZURE_OPENAI_API_KEY"
)
azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv(
"LLM_AZURE_DEPLOYMENT"
)
azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv(
"LLM_AZURE_ENDPOINT"
)
api_version = self.config.azure_kwargs.api_version or os.getenv(
"LLM_AZURE_API_VERSION"
)
default_headers = self.config.azure_kwargs.default_headers

self.client = AzureOpenAI(
@@ -31,54 +49,20 @@ class AzureOpenAILLM(LLMBase):
default_headers=default_headers,
)

def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.

Args:
response: The raw response from API.
tools: The list of tools provided in the request.

Returns:
str or dict: The processed response.
"""
if tools:
processed_response = {
"content": response.choices[0].message.content,
"tool_calls": [],
}

if response.choices[0].message.tool_calls:
for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append(
{
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments),
}
)

return processed_response
else:
return response.choices[0].message.content

def generate_response(
self,
messages: List[Dict[str, str]],
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
response_format: Optional[str] = None,
) -> str:
"""
Generate a response based on the given messages using Azure OpenAI.
Generates a response using Azure OpenAI based on the provided messages.

Args:
messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format of the response. Defaults to "text".
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
response_format (Optional[str]): The desired format of the response. Defaults to None.

Returns:
str: The generated response.
str: The generated response from the model.
"""
params = {
"model": self.config.model,
@@ -87,11 +71,8 @@ class AzureOpenAILLM(LLMBase):
"max_tokens": self.config.max_tokens,
"top_p": self.config.top_p,
}

if response_format:
params["response_format"] = response_format
if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools
params["tool_choice"] = tool_choice

response = self.client.chat.completions.create(**params)
return self._parse_response(response, tools)
return response.choices[0].message.content
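The credential reflow above does not change the lookup order: explicit azure_kwargs values win, and the LLM_AZURE_* environment variables are the fallback. A condensed sketch of that order (variable names taken from the hunk, the azure_kwargs stand-in below is simplified):

import os
from types import SimpleNamespace

# Stand-in for config.azure_kwargs; in the real config these come from BaseLlmConfig.
azure_kwargs = SimpleNamespace(api_key=None, azure_deployment=None, azure_endpoint=None, api_version=None)
api_key = azure_kwargs.api_key or os.getenv("LLM_AZURE_OPENAI_API_KEY")
azure_deployment = azure_kwargs.azure_deployment or os.getenv("LLM_AZURE_DEPLOYMENT")
azure_endpoint = azure_kwargs.azure_endpoint or os.getenv("LLM_AZURE_ENDPOINT")
api_version = azure_kwargs.api_version or os.getenv("LLM_AZURE_API_VERSION")

Note that the structured Azure class in the next file checks the environment variables first and falls back to azure_kwargs, the reverse of the order shown here.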
@@ -9,20 +9,38 @@ from mem0.llms.base import LLMBase


class AzureOpenAIStructuredLLM(LLMBase):
"""
A class for interacting with Azure OpenAI models using the specified configuration.
"""

def __init__(self, config: Optional[BaseLlmConfig] = None):
"""
Initializes the AzureOpenAIStructuredLLM instance with the given configuration.

Args:
config (Optional[BaseLlmConfig]): Configuration settings for the language model.
"""
super().__init__(config)

# Model name should match the custom deployment name chosen for it.
# Ensure model name is set; it should match the Azure OpenAI deployment name.
if not self.config.model:
self.config.model = "gpt-4o-2024-08-06"

api_key = os.getenv("LLM_AZURE_OPENAI_API_KEY") or self.config.azure_kwargs.api_key
azure_deployment = os.getenv("LLM_AZURE_DEPLOYMENT") or self.config.azure_kwargs.azure_deployment
azure_endpoint = os.getenv("LLM_AZURE_ENDPOINT") or self.config.azure_kwargs.azure_endpoint
api_version = os.getenv("LLM_AZURE_API_VERSION") or self.config.azure_kwargs.api_version
api_key = (
os.getenv("LLM_AZURE_OPENAI_API_KEY") or self.config.azure_kwargs.api_key
)
azure_deployment = (
os.getenv("LLM_AZURE_DEPLOYMENT")
or self.config.azure_kwargs.azure_deployment
)
azure_endpoint = (
os.getenv("LLM_AZURE_ENDPOINT") or self.config.azure_kwargs.azure_endpoint
)
api_version = (
os.getenv("LLM_AZURE_API_VERSION") or self.config.azure_kwargs.api_version
)
default_headers = self.config.azure_kwargs.default_headers

# Can display a warning if API version is of model and api-version
self.client = AzureOpenAI(
azure_deployment=azure_deployment,
azure_endpoint=azure_endpoint,
@@ -32,54 +50,20 @@ class AzureOpenAIStructuredLLM(LLMBase):
default_headers=default_headers,
)

def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.

Args:
response: The raw response from API.
tools: The list of tools provided in the request.

Returns:
str or dict: The processed response.
"""
if tools:
processed_response = {
"content": response.choices[0].message.content,
"tool_calls": [],
}

if response.choices[0].message.tool_calls:
for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append(
{
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments),
}
)

return processed_response
else:
return response.choices[0].message.content

def generate_response(
self,
messages: List[Dict[str, str]],
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
response_format: Optional[str] = None,
) -> str:
"""
Generate a response based on the given messages using Azure OpenAI.
Generates a response using Azure OpenAI based on the provided messages.

Args:
messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format of the response. Defaults to "text".
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
response_format (Optional[str]): The desired format of the response. Defaults to None.

Returns:
str: The generated response.
str: The generated response from the model.
"""
params = {
"model": self.config.model,
@@ -88,11 +72,9 @@ class AzureOpenAIStructuredLLM(LLMBase):
"max_tokens": self.config.max_tokens,
"top_p": self.config.top_p,
}

if response_format:
params["response_format"] = response_format
if tools:
params["tools"] = tools
params["tool_choice"] = tool_choice

response = self.client.chat.completions.create(**params)
return self._parse_response(response, tools)
return response.choices[0].message.content
@@ -4,8 +4,12 @@ from pydantic import BaseModel, Field, field_validator


class LlmConfig(BaseModel):
provider: str = Field(description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai")
config: Optional[dict] = Field(description="Configuration for the specific LLM", default={})
provider: str = Field(
description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai"
)
config: Optional[dict] = Field(
description="Configuration for the specific LLM", default={}
)

@field_validator("config")
def validate_config(cls, v, values):
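A small sketch of how the reflowed LlmConfig fields are typically populated. The import path and the contents of the config dict are assumptions; only the provider and config fields appear in the hunk above.

from mem0.configs.base import LlmConfig  # import path assumed, not shown in this diff

# provider defaults to "openai"; config is an arbitrary dict checked by the
# validate_config field_validator against the chosen provider.
cfg = LlmConfig(provider="openai", config={"model": "gpt-4o-mini", "temperature": 0.1})
assert cfg.provider == "openai"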
@@ -9,64 +9,42 @@ from mem0.llms.base import LLMBase


class DeepSeekLLM(LLMBase):
"""
A class for interacting with DeepSeek's language models using the specified configuration.
"""

def __init__(self, config: Optional[BaseLlmConfig] = None):
"""
Initializes the DeepSeekLLM instance with the given configuration.

Args:
config (Optional[BaseLlmConfig]): Configuration settings for the language model.
"""
super().__init__(config)

if not self.config.model:
self.config.model = "deepseek-chat"

api_key = self.config.api_key or os.getenv("DEEPSEEK_API_KEY")
base_url = self.config.deepseek_base_url or os.getenv("DEEPSEEK_API_BASE") or "https://api.deepseek.com"
base_url = (
self.config.deepseek_base_url
or os.getenv("DEEPSEEK_API_BASE")
or "https://api.deepseek.com"
)
self.client = OpenAI(api_key=api_key, base_url=base_url)

def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.

Args:
response: The raw response from API.
tools: The list of tools provided in the request.

Returns:
str or dict: The processed response.
"""
if tools:
processed_response = {
"content": response.choices[0].message.content,
"tool_calls": [],
}

if response.choices[0].message.tool_calls:
for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append(
{
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments),
}
)

return processed_response
else:
return response.choices[0].message.content

def generate_response(
self,
messages: List[Dict[str, str]],
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
) -> str:
"""
Generate a response based on the given messages using DeepSeek.
Generates a response using DeepSeek based on the provided messages.

Args:
messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format of the response. Defaults to "text".
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.

Returns:
str: The generated response.
str: The generated response from the model.
"""
params = {
"model": self.config.model,
@@ -75,10 +53,5 @@ class DeepSeekLLM(LLMBase):
"max_tokens": self.config.max_tokens,
"top_p": self.config.top_p,
}

if tools:
params["tools"] = tools
params["tool_choice"] = tool_choice

response = self.client.chat.completions.create(**params)
return self._parse_response(response, tools)
return response.choices[0].message.content
@@ -3,8 +3,7 @@ from typing import Dict, List, Optional

try:
import google.generativeai as genai
from google.generativeai import GenerativeModel, protos
from google.generativeai.types import content_types
from google.generativeai import GenerativeModel
except ImportError:
raise ImportError(
"The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'."
@@ -15,7 +14,17 @@ from mem0.llms.base import LLMBase


class GeminiLLM(LLMBase):
"""
A wrapper for Google's Gemini language model, integrating it with the LLMBase class.
"""

def __init__(self, config: Optional[BaseLlmConfig] = None):
"""
Initializes the Gemini LLM with the provided configuration.

Args:
config (Optional[BaseLlmConfig]): Configuration object for the model.
"""
super().__init__(config)

if not self.config.model:
@@ -25,51 +34,25 @@ class GeminiLLM(LLMBase):
genai.configure(api_key=api_key)
self.client = GenerativeModel(model_name=self.config.model)

def _parse_response(self, response, tools):
def _reformat_messages(
self, messages: List[Dict[str, str]]
) -> List[Dict[str, str]]:
"""
Process the response based on whether tools are used or not.
Reformats messages to match the Gemini API's expected structure.

Args:
response: The raw response from API.
tools: The list of tools provided in the request.
messages (List[Dict[str, str]]): A list of messages with 'role' and 'content' keys.

Returns:
str or dict: The processed response.
"""
if tools:
processed_response = {
"content": (content if (content := response.candidates[0].content.parts[0].text) else None),
"tool_calls": [],
}

for part in response.candidates[0].content.parts:
if fn := part.function_call:
if isinstance(fn, protos.FunctionCall):
fn_call = type(fn).to_dict(fn)
processed_response["tool_calls"].append({"name": fn_call["name"], "arguments": fn_call["args"]})
continue
processed_response["tool_calls"].append({"name": fn.name, "arguments": fn.args})

return processed_response
else:
return response.candidates[0].content.parts[0].text

def _reformat_messages(self, messages: List[Dict[str, str]]):
"""
Reformat messages for Gemini.

Args:
messages: The list of messages provided in the request.

Returns:
list: The list of messages in the required format.
List[Dict[str, str]]: Reformatted messages in the required format.
"""
new_messages = []

for message in messages:
if message["role"] == "system":
content = "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"]

content = (
"THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"]
)
else:
content = message["content"]
@@ -82,90 +65,33 @@ class GeminiLLM(LLMBase):

return new_messages

def _reformat_tools(self, tools: Optional[List[Dict]]):
"""
Reformat tools for Gemini.

Args:
tools: The list of tools provided in the request.

Returns:
list: The list of tools in the required format.
"""

def remove_additional_properties(data):
"""Recursively removes 'additionalProperties' from nested dictionaries."""

if isinstance(data, dict):
filtered_dict = {
key: remove_additional_properties(value)
for key, value in data.items()
if not (key == "additionalProperties")
}
return filtered_dict
else:
return data

new_tools = []
if tools:
for tool in tools:
func = tool["function"].copy()
new_tools.append({"function_declarations": [remove_additional_properties(func)]})

# TODO: temporarily ignore it to pass tests, will come back to update according to standards later.
# return content_types.to_function_library(new_tools)

return new_tools
else:
return None

def generate_response(
self,
messages: List[Dict[str, str]],
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
self, messages: List[Dict[str, str]], response_format: Optional[Dict] = None
) -> str:
"""
Generate a response based on the given messages using Gemini.
Generates a response from Gemini based on the given conversation history.

Args:
messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format for the response. Defaults to "text".
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
messages (List[Dict[str, str]]): List of message dictionaries containing 'role' and 'content'.
response_format (Optional[Dict]): Specifies the response format (e.g., JSON schema).

Returns:
str: The generated response.
str: The generated response as text.
"""

params = {
"temperature": self.config.temperature,
"max_output_tokens": self.config.max_tokens,
"top_p": self.config.top_p,
}

if response_format is not None and response_format["type"] == "json_object":
if response_format and response_format.get("type") == "json_object":
params["response_mime_type"] = "application/json"
if "schema" in response_format:
params["response_schema"] = response_format["schema"]
if tool_choice:
tool_config = content_types.to_tool_config(
{
"function_calling_config": {
"mode": tool_choice,
"allowed_function_names": (
[tool["function"]["name"] for tool in tools] if tool_choice == "any" else None
),
}
}
)

response = self.client.generate_content(
contents=self._reformat_messages(messages),
tools=self._reformat_tools(tools),
generation_config=genai.GenerationConfig(**params),
tool_config=tool_config,
)

return self._parse_response(response, tools)
return response.candidates[0].content.parts[0].text
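A standalone sketch of the response_format handling retained above, which maps an OpenAI-style json_object request onto Gemini generation parameters (the dict literals are illustrative, not taken from the diff):

response_format = {"type": "json_object", "schema": {"type": "object", "properties": {}}}
params = {"temperature": 0.0, "max_output_tokens": 100, "top_p": 1.0}
# Mirrors GeminiLLM.generate_response: JSON mode switches the MIME type and,
# when a schema is supplied, forwards it as response_schema.
if response_format and response_format.get("type") == "json_object":
    params["response_mime_type"] = "application/json"
    if "schema" in response_format:
        params["response_schema"] = response_format["schema"]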
@@ -5,14 +5,26 @@ from typing import Dict, List, Optional
try:
from groq import Groq
except ImportError:
raise ImportError("The 'groq' library is required. Please install it using 'pip install groq'.")
raise ImportError(
"The 'groq' library is required. Please install it using 'pip install groq'."
)

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase


class GroqLLM(LLMBase):
"""
A class for interacting with Groq's language models using the specified configuration.
"""

def __init__(self, config: Optional[BaseLlmConfig] = None):
"""
Initializes the GroqLLM instance with the given configuration.

Args:
config (Optional[BaseLlmConfig]): Configuration settings for the language model.
"""
super().__init__(config)

if not self.config.model:
@@ -21,54 +33,20 @@ class GroqLLM(LLMBase):
api_key = self.config.api_key or os.getenv("GROQ_API_KEY")
self.client = Groq(api_key=api_key)

def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.

Args:
response: The raw response from API.
tools: The list of tools provided in the request.

Returns:
str or dict: The processed response.
"""
if tools:
processed_response = {
"content": response.choices[0].message.content,
"tool_calls": [],
}

if response.choices[0].message.tool_calls:
for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append(
{
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments),
}
)

return processed_response
else:
return response.choices[0].message.content

def generate_response(
self,
messages: List[Dict[str, str]],
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
response_format: Optional[str] = None,
) -> str:
"""
Generate a response based on the given messages using Groq.
Generates a response using Groq based on the provided messages.

Args:
messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format of the response. Defaults to "text".
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
response_format (Optional[str]): The desired format of the response. Defaults to None.

Returns:
str: The generated response.
str: The generated response from the model.
"""
params = {
"model": self.config.model,
@@ -79,9 +57,5 @@ class GroqLLM(LLMBase):
}
if response_format:
params["response_format"] = response_format
if tools:
params["tools"] = tools
params["tool_choice"] = tool_choice

response = self.client.chat.completions.create(**params)
return self._parse_response(response, tools)
return response.choices[0].message.content
@@ -4,70 +4,50 @@ from typing import Dict, List, Optional
try:
import litellm
except ImportError:
raise ImportError("The 'litellm' library is required. Please install it using 'pip install litellm'.")
raise ImportError(
"The 'litellm' library is required. Please install it using 'pip install litellm'."
)

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase


class LiteLLM(LLMBase):
"""
A class for interacting with LiteLLM's language models using the specified configuration.
"""

def __init__(self, config: Optional[BaseLlmConfig] = None):
"""
Initializes the LiteLLM instance with the given configuration.

Args:
config (Optional[BaseLlmConfig]): Configuration settings for the language model.
"""
super().__init__(config)

if not self.config.model:
self.config.model = "gpt-4o-mini"

def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.

Args:
response: The raw response from API.
tools: The list of tools provided in the request.

Returns:
str or dict: The processed response.
"""
if tools:
processed_response = {
"content": response.choices[0].message.content,
"tool_calls": [],
}

if response.choices[0].message.tool_calls:
for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append(
{
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments),
}
)

return processed_response
else:
return response.choices[0].message.content

def generate_response(
self,
messages: List[Dict[str, str]],
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
response_format: Optional[str] = None,
) -> str:
"""
Generate a response based on the given messages using Litellm.
Generates a response using LiteLLM based on the provided messages.

Args:
messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format of the response. Defaults to "text".
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
response_format (Optional[str]): The desired format of the response. Defaults to None.

Returns:
str: The generated response.
str: The generated response from the model.
"""
if not litellm.supports_function_calling(self.config.model):
raise ValueError(f"Model '{self.config.model}' in litellm does not support function calling.")
raise ValueError(
f"Model '{self.config.model}' in LiteLLM does not support function calling."
)

params = {
"model": self.config.model,
@@ -78,9 +58,6 @@ class LiteLLM(LLMBase):
}
if response_format:
params["response_format"] = response_format
if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools
params["tool_choice"] = tool_choice

response = litellm.completion(**params)
return self._parse_response(response, tools)
return response.choices[0].message.content
@@ -3,77 +3,56 @@ from typing import Dict, List, Optional
try:
from ollama import Client
except ImportError:
raise ImportError("The 'ollama' library is required. Please install it using 'pip install ollama'.")
raise ImportError(
"The 'ollama' library is required. Please install it using 'pip install ollama'."
)

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase


class OllamaLLM(LLMBase):
"""
A class for interacting with Ollama's language models using the specified configuration.
"""

def __init__(self, config: Optional[BaseLlmConfig] = None):
"""
Initializes the OllamaLLM instance with the given configuration.

Args:
config (Optional[BaseLlmConfig]): Configuration settings for the language model.
"""
super().__init__(config)

if not self.config.model:
self.config.model = "llama3.1:70b"

self.client = Client(host=self.config.ollama_base_url)
self._ensure_model_exists()

def _ensure_model_exists(self):
"""
Ensure the specified model exists locally. If not, pull it from Ollama.
Ensures the specified model exists locally. If not, pulls it from Ollama.
"""
local_models = self.client.list()["models"]
if not any(model.get("name") == self.config.model for model in local_models):
self.client.pull(self.config.model)

def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.

Args:
response: The raw response from API.
tools: The list of tools provided in the request.

Returns:
str or dict: The processed response.
"""
if tools:
processed_response = {
"content": response["message"]["content"],
"tool_calls": [],
}

if response["message"].get("tool_calls"):
for tool_call in response["message"]["tool_calls"]:
processed_response["tool_calls"].append(
{
"name": tool_call["function"]["name"],
"arguments": tool_call["function"]["arguments"],
}
)

return processed_response
else:
return response["message"]["content"]

def generate_response(
self,
messages: List[Dict[str, str]],
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
response_format: Optional[str] = None,
) -> str:
"""
Generate a response based on the given messages using OpenAI.
Generates a response using Ollama based on the provided messages.

Args:
messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format of the response. Defaults to "text".
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
response_format (Optional[str]): The desired format of the response. Defaults to None.

Returns:
str: The generated response.
str: The generated response from the model.
"""
params = {
"model": self.config.model,
@@ -87,8 +66,5 @@ class OllamaLLM(LLMBase):
if response_format:
params["format"] = "json"

if tools:
params["tools"] = tools

response = self.client.chat(**params)
return self._parse_response(response, tools)
return response["message"]["content"]
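A short sketch of the Ollama-side response_format mapping kept in the hunk above: any truthy response_format simply switches the request to JSON mode (the model and message values are illustrative):

params = {
    "model": "llama3.1:70b",
    "messages": [{"role": "user", "content": "Reply in JSON."}],
}
response_format = {"type": "json_object"}
if response_format:
    params["format"] = "json"  # Ollama's chat API takes a plain format flag rather than a schema
# response = client.chat(**params); the text is response["message"]["content"]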
@@ -9,7 +9,17 @@ from mem0.llms.base import LLMBase


class OpenAILLM(LLMBase):
"""
A class to interact with OpenAI or OpenRouter APIs for generating responses using LLMs.
"""

def __init__(self, config: Optional[BaseLlmConfig] = None):
"""
Initializes the OpenAILLM instance.

Args:
config (Optional[BaseLlmConfig]): Configuration for the LLM, including model, API key, and base URLs.
"""
super().__init__(config)

if not self.config.model:
@@ -24,57 +34,27 @@ class OpenAILLM(LLMBase):
)
else:
api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE") or "https://api.openai.com/v1"
base_url = (
self.config.openai_base_url
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
self.client = OpenAI(api_key=api_key, base_url=base_url)

def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.

Args:
response: The raw response from API.
tools: The list of tools provided in the request.

Returns:
str or dict: The processed response.
"""
if tools:
processed_response = {
"content": response.choices[0].message.content,
"tool_calls": [],
}

if response.choices[0].message.tool_calls:
for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append(
{
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments),
}
)

return processed_response
else:
return response.choices[0].message.content

def generate_response(
self,
messages: List[Dict[str, str]],
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
response_format: Optional[str] = None,
) -> str:
"""
Generate a response based on the given messages using OpenAI.
Generates a response based on the provided messages using OpenAI or OpenRouter.

Args:
messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format of the response. Defaults to "text".
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
messages (List[Dict[str, str]]): A list of message dictionaries containing 'role' and 'content'.
response_format (Optional[str]): The format of the response. Defaults to None.

Returns:
str: The generated response.
str: The generated response from the model.
"""
params = {
"model": self.config.model,
@@ -102,9 +82,6 @@ class OpenAILLM(LLMBase):

if response_format:
params["response_format"] = response_format
if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools
params["tool_choice"] = tool_choice

response = self.client.chat.completions.create(**params)
return self._parse_response(response, tools)
return response.choices[0].message.content
@@ -9,66 +9,45 @@ from mem0.llms.base import LLMBase


class OpenAIStructuredLLM(LLMBase):
"""
A class for interacting with OpenAI's structured language models using the specified configuration.
"""

def __init__(self, config: Optional[BaseLlmConfig] = None):
"""
Initializes the OpenAIStructuredLLM instance with the given configuration.

Args:
config (Optional[BaseLlmConfig]): Configuration settings for the language model.
"""
super().__init__(config)

if not self.config.model:
self.config.model = "gpt-4o-2024-08-06"

api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE") or "https://api.openai.com/v1"
base_url = (
self.config.openai_base_url
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
self.client = OpenAI(api_key=api_key, base_url=base_url)

def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.

Args:
response: The raw response from API.
tools (list, optional): List of tools that the model can call.

Returns:
str or dict: The processed response.
"""

if tools:
processed_response = {
"content": response.choices[0].message.content,
"tool_calls": [],
}

if response.choices[0].message.tool_calls:
for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append(
{
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments),
}
)

return processed_response

else:
return response.choices[0].message.content

def generate_response(
self,
messages: List[Dict[str, str]],
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
response_format: Optional[str] = None,
) -> str:
"""
Generate a response based on the given messages using OpenAI.
Generates a response using OpenAI based on the provided messages.

Args:
messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format of the response. Defaults to "text".
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
response_format (Optional[str]): The desired format of the response. Defaults to None.

Returns:
str: The generated response.
str: The generated response from the model.
"""
params = {
"model": self.config.model,
@@ -78,10 +57,6 @@ class OpenAIStructuredLLM(LLMBase):

if response_format:
params["response_format"] = response_format
if tools:
params["tools"] = tools
params["tool_choice"] = tool_choice

response = self.client.beta.chat.completions.parse(**params)

return self._parse_response(response, tools)
return response.choices[0].message.content
@@ -5,14 +5,26 @@ from typing import Dict, List, Optional
try:
from together import Together
except ImportError:
raise ImportError("The 'together' library is required. Please install it using 'pip install together'.")
raise ImportError(
"The 'together' library is required. Please install it using 'pip install together'."
)

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase


class TogetherLLM(LLMBase):
"""
A class for interacting with the TogetherAI language model using the specified configuration.
"""

def __init__(self, config: Optional[BaseLlmConfig] = None):
"""
Initializes the TogetherLLM instance with the given configuration.

Args:
config (Optional[BaseLlmConfig]): Configuration settings for the language model.
"""
super().__init__(config)

if not self.config.model:
@@ -21,54 +33,20 @@ class TogetherLLM(LLMBase):
api_key = self.config.api_key or os.getenv("TOGETHER_API_KEY")
self.client = Together(api_key=api_key)

def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.

Args:
response: The raw response from API.
tools: The list of tools provided in the request.

Returns:
str or dict: The processed response.
"""
if tools:
processed_response = {
"content": response.choices[0].message.content,
"tool_calls": [],
}

if response.choices[0].message.tool_calls:
for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append(
{
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments),
}
)

return processed_response
else:
return response.choices[0].message.content

def generate_response(
self,
messages: List[Dict[str, str]],
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
response_format: Optional[str] = None,
) -> str:
"""
Generate a response based on the given messages using TogetherAI.
Generates a response using TogetherAI based on the provided messages.

Args:
messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format of the response. Defaults to "text".
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
response_format (Optional[str]): The desired format of the response. Defaults to None.

Returns:
str: The generated response.
str: The generated response from the model.
"""
params = {
"model": self.config.model,
@@ -79,9 +57,6 @@ class TogetherLLM(LLMBase):
}
if response_format:
params["response_format"] = response_format
if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools
params["tool_choice"] = tool_choice

response = self.client.chat.completions.create(**params)
return self._parse_response(response, tools)
return response.choices[0].message.content
@@ -15,7 +15,11 @@ class XAILLM(LLMBase):
self.config.model = "grok-2-latest"

api_key = self.config.api_key or os.getenv("XAI_API_KEY")
base_url = self.config.xai_base_url or os.getenv("XAI_API_BASE") or "https://api.x.ai/v1"
base_url = (
self.config.xai_base_url
or os.getenv("XAI_API_BASE")
or "https://api.x.ai/v1"
)
self.client = OpenAI(api_key=api_key, base_url=base_url)

def generate_response(self, messages: List[Dict[str, str]], response_format=None):
@@ -20,8 +20,10 @@ def mock_openai_client():
yield mock_client


def test_generate_response_without_tools(mock_openai_client):
config = BaseLlmConfig(model=MODEL, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P)
def test_generate_response(mock_openai_client):
config = BaseLlmConfig(
model=MODEL, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P
)
llm = AzureOpenAILLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
@@ -29,67 +31,21 @@ def test_generate_response_without_tools(mock_openai_client):
]

mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
mock_response.choices = [
Mock(message=Mock(content="I'm doing well, thank you for asking!"))
]
mock_openai_client.chat.completions.create.return_value = mock_response

response = llm.generate_response(messages)

mock_openai_client.chat.completions.create.assert_called_once_with(
model=MODEL, messages=messages, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P
)
assert response == "I'm doing well, thank you for asking!"


def test_generate_response_with_tools(mock_openai_client):
config = BaseLlmConfig(model=MODEL, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P)
llm = AzureOpenAILLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add a new memory: Today is a sunny day."},
]
tools = [
{
"type": "function",
"function": {
"name": "add_memory",
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
},
}
]

mock_response = Mock()
mock_message = Mock()
mock_message.content = "I've added the memory for you."

mock_tool_call = Mock()
mock_tool_call.function.name = "add_memory"
mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'

mock_message.tool_calls = [mock_tool_call]
mock_response.choices = [Mock(message=mock_message)]
mock_openai_client.chat.completions.create.return_value = mock_response

response = llm.generate_response(messages, tools=tools)

mock_openai_client.chat.completions.create.assert_called_once_with(
model=MODEL,
messages=messages,
temperature=TEMPERATURE,
max_tokens=MAX_TOKENS,
top_p=TOP_P,
tools=tools,
tool_choice="auto",
)

assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
assert response == "I'm doing well, thank you for asking!"


@pytest.mark.parametrize(
@@ -128,4 +84,6 @@ def test_generate_with_http_proxies(default_headers):
|
||||
api_version=None,
|
||||
default_headers=default_headers,
|
||||
)
|
||||
mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000")
|
||||
mock_http_client.assert_called_once_with(
|
||||
proxies="http://testproxy.mem0.net:8000"
|
||||
)
|
||||
|
||||
@@ -16,14 +16,26 @@ def mock_deepseek_client():

def test_deepseek_llm_base_url():
# case1: default config with deepseek official base url
config = BaseLlmConfig(model="deepseek-chat", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key")
config = BaseLlmConfig(
model="deepseek-chat",
temperature=0.7,
max_tokens=100,
top_p=1.0,
api_key="api_key",
)
llm = DeepSeekLLM(config)
assert str(llm.client.base_url) == "https://api.deepseek.com"

# case2: with env variable DEEPSEEK_API_BASE
provider_base_url = "https://api.provider.com/v1/"
os.environ["DEEPSEEK_API_BASE"] = provider_base_url
config = BaseLlmConfig(model="deepseek-chat", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key")
config = BaseLlmConfig(
model="deepseek-chat",
temperature=0.7,
max_tokens=100,
top_p=1.0,
api_key="api_key",
)
llm = DeepSeekLLM(config)
assert str(llm.client.base_url) == provider_base_url

@@ -35,14 +47,16 @@ def test_deepseek_llm_base_url():
max_tokens=100,
top_p=1.0,
api_key="api_key",
deepseek_base_url=config_base_url
deepseek_base_url=config_base_url,
)
llm = DeepSeekLLM(config)
assert str(llm.client.base_url) == config_base_url


def test_generate_response_without_tools(mock_deepseek_client):
config = BaseLlmConfig(model="deepseek-chat", temperature=0.7, max_tokens=100, top_p=1.0)
def test_generate_response(mock_deepseek_client):
config = BaseLlmConfig(
model="deepseek-chat", temperature=0.7, max_tokens=100, top_p=1.0
)
llm = DeepSeekLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
@@ -50,64 +64,18 @@ def test_generate_response_without_tools(mock_deepseek_client):
]

mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
mock_response.choices = [
Mock(message=Mock(content="I'm doing well, thank you for asking!"))
]
mock_deepseek_client.chat.completions.create.return_value = mock_response

response = llm.generate_response(messages)

mock_deepseek_client.chat.completions.create.assert_called_once_with(
model="deepseek-chat", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0
)
assert response == "I'm doing well, thank you for asking!"


def test_generate_response_with_tools(mock_deepseek_client):
config = BaseLlmConfig(model="deepseek-chat", temperature=0.7, max_tokens=100, top_p=1.0)
llm = DeepSeekLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add a new memory: Today is a sunny day."},
]
tools = [
{
"type": "function",
"function": {
"name": "add_memory",
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
},
}
]

mock_response = Mock()
mock_message = Mock()
mock_message.content = "I've added the memory for you."

mock_tool_call = Mock()
mock_tool_call.function.name = "add_memory"
mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'

mock_message.tool_calls = [mock_tool_call]
mock_response.choices = [Mock(message=mock_message)]
mock_deepseek_client.chat.completions.create.return_value = mock_response

response = llm.generate_response(messages, tools=tools)

mock_deepseek_client.chat.completions.create.assert_called_once_with(
model="deepseek-chat",
messages=messages,
temperature=0.7,
max_tokens=100,
top_p=1.0,
tools=tools,
tool_choice="auto"
)

assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
assert response == "I'm doing well, thank you for asking!"

@@ -17,7 +17,9 @@ def mock_gemini_client():


def test_generate_response_without_tools(mock_gemini_client: Mock):
config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0)
config = BaseLlmConfig(
model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0
)
llm = GeminiLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
@@ -34,86 +36,14 @@ def test_generate_response_without_tools(mock_gemini_client: Mock):

mock_gemini_client.generate_content.assert_called_once_with(
contents=[
{"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"},
{
"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.",
"role": "user",
},
{"parts": "Hello, how are you?", "role": "user"},
],
generation_config=GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0),
tools=None,
tool_config=content_types.to_tool_config(
{"function_calling_config": {"mode": "auto", "allowed_function_names": None}}
generation_config=GenerationConfig(
temperature=0.7, max_output_tokens=100, top_p=1.0
),
)
assert response == "I'm doing well, thank you for asking!"


def test_generate_response_with_tools(mock_gemini_client: Mock):
config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0)
llm = GeminiLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add a new memory: Today is a sunny day."},
]
tools = [
{
"type": "function",
"function": {
"name": "add_memory",
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
},
}
]

mock_tool_call = Mock()
mock_tool_call.name = "add_memory"
mock_tool_call.args = {"data": "Today is a sunny day."}

mock_part = Mock()
mock_part.function_call = mock_tool_call
mock_part.text = "I've added the memory for you."

mock_content = Mock()
mock_content.parts = [mock_part]

mock_message = Mock()
mock_message.content = mock_content

mock_response = Mock(candidates=[mock_message])
mock_gemini_client.generate_content.return_value = mock_response

response = llm.generate_response(messages, tools=tools)

mock_gemini_client.generate_content.assert_called_once_with(
contents=[
{"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"},
{"parts": "Add a new memory: Today is a sunny day.", "role": "user"},
],
generation_config=GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0),
tools=[
{
"function_declarations": [
{
"name": "add_memory",
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
}
]
}
],
tool_config=content_types.to_tool_config(
{"function_calling_config": {"mode": "auto", "allowed_function_names": None}}
),
)

assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}

@@ -14,8 +14,10 @@ def mock_groq_client():
yield mock_client


def test_generate_response_without_tools(mock_groq_client):
config = BaseLlmConfig(model="llama3-70b-8192", temperature=0.7, max_tokens=100, top_p=1.0)
def test_generate_response(mock_groq_client):
config = BaseLlmConfig(
model="llama3-70b-8192", temperature=0.7, max_tokens=100, top_p=1.0
)
llm = GroqLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
@@ -23,64 +25,18 @@ def test_generate_response_without_tools(mock_groq_client):
]

mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
mock_response.choices = [
Mock(message=Mock(content="I'm doing well, thank you for asking!"))
]
mock_groq_client.chat.completions.create.return_value = mock_response

response = llm.generate_response(messages)

mock_groq_client.chat.completions.create.assert_called_once_with(
model="llama3-70b-8192", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0
)
assert response == "I'm doing well, thank you for asking!"


def test_generate_response_with_tools(mock_groq_client):
config = BaseLlmConfig(model="llama3-70b-8192", temperature=0.7, max_tokens=100, top_p=1.0)
llm = GroqLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add a new memory: Today is a sunny day."},
]
tools = [
{
"type": "function",
"function": {
"name": "add_memory",
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
},
}
]

mock_response = Mock()
mock_message = Mock()
mock_message.content = "I've added the memory for you."

mock_tool_call = Mock()
mock_tool_call.function.name = "add_memory"
mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'

mock_message.tool_calls = [mock_tool_call]
mock_response.choices = [Mock(message=mock_message)]
mock_groq_client.chat.completions.create.return_value = mock_response

response = llm.generate_response(messages, tools=tools)

mock_groq_client.chat.completions.create.assert_called_once_with(
model="llama3-70b-8192",
messages=messages,
temperature=0.7,
max_tokens=100,
top_p=1.0,
tools=tools,
tool_choice="auto",
)

assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
assert response == "I'm doing well, thank you for asking!"

@@ -13,17 +13,22 @@ def mock_litellm():


def test_generate_response_with_unsupported_model(mock_litellm):
config = BaseLlmConfig(model="unsupported-model", temperature=0.7, max_tokens=100, top_p=1)
config = BaseLlmConfig(
model="unsupported-model", temperature=0.7, max_tokens=100, top_p=1
)
llm = litellm.LiteLLM(config)
messages = [{"role": "user", "content": "Hello"}]

mock_litellm.supports_function_calling.return_value = False

with pytest.raises(ValueError, match="Model 'unsupported-model' in litellm does not support function calling."):
with pytest.raises(
ValueError,
match="Model 'unsupported-model' in LiteLLM does not support function calling.",
):
llm.generate_response(messages)


def test_generate_response_without_tools(mock_litellm):
def test_generate_response(mock_litellm):
config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1)
llm = litellm.LiteLLM(config)
messages = [
@@ -32,7 +37,9 @@ def test_generate_response_without_tools(mock_litellm):
]

mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
mock_response.choices = [
Mock(message=Mock(content="I'm doing well, thank you for asking!"))
]
mock_litellm.completion.return_value = mock_response
mock_litellm.supports_function_calling.return_value = True

@@ -42,50 +49,3 @@ def test_generate_response_without_tools(mock_litellm):
model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0
)
assert response == "I'm doing well, thank you for asking!"


def test_generate_response_with_tools(mock_litellm):
config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1)
llm = litellm.LiteLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add a new memory: Today is a sunny day."},
]
tools = [
{
"type": "function",
"function": {
"name": "add_memory",
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
},
}
]

mock_response = Mock()
mock_message = Mock()
mock_message.content = "I've added the memory for you."

mock_tool_call = Mock()
mock_tool_call.function.name = "add_memory"
mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'

mock_message.tool_calls = [mock_tool_call]
mock_response.choices = [Mock(message=mock_message)]
mock_litellm.completion.return_value = mock_response
mock_litellm.supports_function_calling.return_value = True

response = llm.generate_response(messages, tools=tools)

mock_litellm.completion.assert_called_once_with(
model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1, tools=tools, tool_choice="auto"
)

assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}

@@ -16,7 +16,9 @@ def mock_openai_client():

def test_openai_llm_base_url():
# case1: default config: with openai official base url
config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key")
config = BaseLlmConfig(
model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key"
)
llm = OpenAILLM(config)
# Note: openai client will parse the raw base_url into a URL object, which will have a trailing slash
assert str(llm.client.base_url) == "https://api.openai.com/v1/"
@@ -24,7 +26,9 @@ def test_openai_llm_base_url():
# case2: with env variable OPENAI_API_BASE
provider_base_url = "https://api.provider.com/v1"
os.environ["OPENAI_API_BASE"] = provider_base_url
config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key")
config = BaseLlmConfig(
model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key"
)
llm = OpenAILLM(config)
# Note: openai client will parse the raw base_url into a URL object, which will have a trailing slash
assert str(llm.client.base_url) == provider_base_url + "/"
@@ -32,14 +36,19 @@ def test_openai_llm_base_url():
# case3: with config.openai_base_url
config_base_url = "https://api.config.com/v1"
config = BaseLlmConfig(
model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key", openai_base_url=config_base_url
model="gpt-4o",
temperature=0.7,
max_tokens=100,
top_p=1.0,
api_key="api_key",
openai_base_url=config_base_url,
)
llm = OpenAILLM(config)
# Note: openai client will parse the raw base_url into a URL object, which will have a trailing slash
assert str(llm.client.base_url) == config_base_url + "/"


def test_generate_response_without_tools(mock_openai_client):
def test_generate_response(mock_openai_client):
config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0)
llm = OpenAILLM(config)
messages = [
@@ -48,7 +57,9 @@ def test_generate_response_without_tools(mock_openai_client):
]

mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
mock_response.choices = [
Mock(message=Mock(content="I'm doing well, thank you for asking!"))
]
mock_openai_client.chat.completions.create.return_value = mock_response

response = llm.generate_response(messages)
@@ -57,49 +68,3 @@ def test_generate_response_without_tools(mock_openai_client):
model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0
)
assert response == "I'm doing well, thank you for asking!"


def test_generate_response_with_tools(mock_openai_client):
config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0)
llm = OpenAILLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add a new memory: Today is a sunny day."},
]
tools = [
{
"type": "function",
"function": {
"name": "add_memory",
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
},
}
]

mock_response = Mock()
mock_message = Mock()
mock_message.content = "I've added the memory for you."

mock_tool_call = Mock()
mock_tool_call.function.name = "add_memory"
mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'

mock_message.tool_calls = [mock_tool_call]
mock_response.choices = [Mock(message=mock_message)]
mock_openai_client.chat.completions.create.return_value = mock_response

response = llm.generate_response(messages, tools=tools)

mock_openai_client.chat.completions.create.assert_called_once_with(
model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0, tools=tools, tool_choice="auto"
)

assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}

@@ -14,8 +14,13 @@ def mock_together_client():
yield mock_client


def test_generate_response_without_tools(mock_together_client):
config = BaseLlmConfig(model="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0.7, max_tokens=100, top_p=1.0)
def test_generate_response(mock_together_client):
config = BaseLlmConfig(
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
temperature=0.7,
max_tokens=100,
top_p=1.0,
)
llm = TogetherLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
@@ -23,64 +28,18 @@ def test_generate_response_without_tools(mock_together_client):
]

mock_response = Mock()
mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
mock_response.choices = [
Mock(message=Mock(content="I'm doing well, thank you for asking!"))
]
mock_together_client.chat.completions.create.return_value = mock_response

response = llm.generate_response(messages)

mock_together_client.chat.completions.create.assert_called_once_with(
model="mistralai/Mixtral-8x7B-Instruct-v0.1", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0
)
assert response == "I'm doing well, thank you for asking!"


def test_generate_response_with_tools(mock_together_client):
config = BaseLlmConfig(model="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0.7, max_tokens=100, top_p=1.0)
llm = TogetherLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add a new memory: Today is a sunny day."},
]
tools = [
{
"type": "function",
"function": {
"name": "add_memory",
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
},
}
]

mock_response = Mock()
mock_message = Mock()
mock_message.content = "I've added the memory for you."

mock_tool_call = Mock()
mock_tool_call.function.name = "add_memory"
mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'

mock_message.tool_calls = [mock_tool_call]
mock_response.choices = [Mock(message=mock_message)]
mock_together_client.chat.completions.create.return_value = mock_response

response = llm.generate_response(messages, tools=tools)

mock_together_client.chat.completions.create.assert_called_once_with(
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
messages=messages,
temperature=0.7,
max_tokens=100,
top_p=1.0,
tools=tools,
tool_choice="auto",
)

assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
assert response == "I'm doing well, thank you for asking!"