Remove tools from LLMs (#2363)

Anusha Yella
2025-03-14 17:42:48 +05:30
committed by GitHub
parent 4be426f762
commit ee80a43810
21 changed files with 418 additions and 1071 deletions
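In short: every provider class loses its tools / tool_choice plumbing and its tool-call branch in _parse_response, so generate_response now takes the message list (plus response_format where the provider supports it) and always returns the generated text. A rough before/after sketch of the caller-facing change, assuming the usual mem0 import paths and an OpenAI-style config (the commented-out tool call is illustrative only):

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.openai import OpenAILLM

llm = OpenAILLM(BaseLlmConfig(model="gpt-4o-mini"))
messages = [{"role": "user", "content": "Summarize today's notes."}]

# Before this commit, tool definitions could be forwarded and a dict of tool calls returned:
# response = llm.generate_response(messages, tools=[...], tool_choice="auto")

# After this commit, only messages (and optionally response_format) are accepted,
# and the return value is always a plain string.
text = llm.generate_response(messages)
print(text)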

View File

@@ -4,14 +4,26 @@ from typing import Dict, List, Optional
 try:
     import anthropic
 except ImportError:
-    raise ImportError("The 'anthropic' library is required. Please install it using 'pip install anthropic'.")
+    raise ImportError(
+        "The 'anthropic' library is required. Please install it using 'pip install anthropic'."
+    )

 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase


 class AnthropicLLM(LLMBase):
+    """
+    A class for interacting with Anthropic's Claude models using the specified configuration.
+    """
+
     def __init__(self, config: Optional[BaseLlmConfig] = None):
+        """
+        Initializes the AnthropicLLM instance with the given configuration.
+
+        Args:
+            config (Optional[BaseLlmConfig]): Configuration settings for the language model.
+        """
         super().__init__(config)

         if not self.config.model:
@@ -23,23 +35,17 @@ class AnthropicLLM(LLMBase):
     def generate_response(
         self,
         messages: List[Dict[str, str]],
-        response_format=None,
-        tools: Optional[List[Dict]] = None,
-        tool_choice: str = "auto",
-    ):
+    ) -> str:
         """
-        Generate a response based on the given messages using Anthropic.
+        Generates a response using Anthropic's Claude model based on the provided messages.

         Args:
-            messages (list): List of message dicts containing 'role' and 'content'.
-            response_format (str or object, optional): Format of the response. Defaults to "text".
-            tools (list, optional): List of tools that the model can call. Defaults to None.
-            tool_choice (str, optional): Tool choice method. Defaults to "auto".
+            messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.

         Returns:
-            str: The generated response.
+            str: The generated response from the model.
         """
-        # Separate system message from other messages
+        # Extract system message separately
         system_message = ""
         filtered_messages = []
         for message in messages:
@@ -56,9 +62,6 @@ class AnthropicLLM(LLMBase):
             "max_tokens": self.config.max_tokens,
             "top_p": self.config.top_p,
         }
-        if tools:  # TODO: Remove tools if no issues found with new memory addition logic
-            params["tools"] = tools
-            params["tool_choice"] = tool_choice

         response = self.client.messages.create(**params)
         return response.content[0].text

View File

@@ -4,14 +4,26 @@ from typing import Any, Dict, List, Optional
 try:
     import boto3
 except ImportError:
-    raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.")
+    raise ImportError(
+        "The 'boto3' library is required. Please install it using 'pip install boto3'."
+    )

 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase


 class AWSBedrockLLM(LLMBase):
+    """
+    A wrapper for AWS Bedrock's language models, integrating them with the LLMBase class.
+    """
+
     def __init__(self, config: Optional[BaseLlmConfig] = None):
+        """
+        Initializes the AWS Bedrock LLM with the provided configuration.
+
+        Args:
+            config (Optional[BaseLlmConfig]): Configuration object for the model.
+        """
         super().__init__(config)

         if not self.config.model:
@@ -25,49 +37,29 @@ class AWSBedrockLLM(LLMBase):
     def _format_messages(self, messages: List[Dict[str, str]]) -> str:
         """
-        Formats a list of messages into the required prompt structure for the model.
+        Formats a list of messages into a structured prompt for the model.

         Args:
-            messages (List[Dict[str, str]]): A list of dictionaries where each dictionary represents a message.
-                                             Each dictionary contains 'role' and 'content' keys.
+            messages (List[Dict[str, str]]): A list of dictionaries containing 'role' and 'content'.

         Returns:
             str: A formatted string combining all messages, structured with roles capitalized and separated by newlines.
         """
-        formatted_messages = []
-        for message in messages:
-            role = message["role"].capitalize()
-            content = message["content"]
-            formatted_messages.append(f"\n\n{role}: {content}")
-
+        formatted_messages = [
+            f"\n\n{msg['role'].capitalize()}: {msg['content']}" for msg in messages
+        ]
         return "".join(formatted_messages) + "\n\nAssistant:"

-    def _parse_response(self, response, tools) -> str:
+    def _parse_response(self, response) -> str:
         """
-        Process the response based on whether tools are used or not.
+        Extracts the generated response from the API response.

         Args:
-            response: The raw response from API.
-            tools: The list of tools provided in the request.
+            response: The raw response from the AWS Bedrock API.

         Returns:
-            str or dict: The processed response.
+            str: The generated response text.
         """
-        if tools:
-            processed_response = {"tool_calls": []}
-
-            if response["output"]["message"]["content"]:
-                for item in response["output"]["message"]["content"]:
-                    if "toolUse" in item:
-                        processed_response["tool_calls"].append(
-                            {
-                                "name": item["toolUse"]["name"],
-                                "arguments": item["toolUse"]["input"],
-                            }
-                        )
-
-            return processed_response
-
         response_body = json.loads(response["body"].read().decode())
         return response_body.get("completion", "")
@@ -76,22 +68,21 @@ class AWSBedrockLLM(LLMBase):
         provider: str,
         model: str,
         prompt: str,
-        model_kwargs: Optional[Dict[str, Any]] = {},
+        model_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Dict[str, Any]:
         """
-        Prepares the input dictionary for the specified provider's model by mapping and renaming
-        keys in the input based on the provider's requirements.
+        Prepares the input dictionary for the specified provider's model.

         Args:
-            provider (str): The name of the service provider (e.g., "meta", "ai21", "mistral", "cohere", "amazon").
-            model (str): The name or identifier of the model being used.
-            prompt (str): The text prompt to be processed by the model.
-            model_kwargs (Dict[str, Any]): Additional keyword arguments specific to the model's requirements.
+            provider (str): The model provider (e.g., "meta", "ai21", "mistral", "cohere", "amazon").
+            model (str): The model identifier.
+            prompt (str): The input prompt.
+            model_kwargs (Optional[Dict[str, Any]]): Additional model parameters.

         Returns:
-            Dict[str, Any]: The prepared input dictionary with the correct keys and values for the specified provider.
+            Dict[str, Any]: The prepared input dictionary.
         """
+        model_kwargs = model_kwargs or {}
         input_body = {"prompt": prompt, **model_kwargs}

         provider_mappings = {
@@ -119,102 +110,35 @@ class AWSBedrockLLM(LLMBase):
             },
         }

         input_body["textGenerationConfig"] = {
-            k: v for k, v in input_body["textGenerationConfig"].items() if v is not None
+            k: v
+            for k, v in input_body["textGenerationConfig"].items()
+            if v is not None
         }

         return input_body

-    def _convert_tool_format(self, original_tools):
-        """
-        Converts a list of tools from their original format to a new standardized format.
-
-        Args:
-            original_tools (list): A list of dictionaries representing the original tools, each containing a 'type' key and corresponding details.
-
-        Returns:
-            list: A list of dictionaries representing the tools in the new standardized format.
-        """
-        new_tools = []
-
-        for tool in original_tools:
-            if tool["type"] == "function":
-                function = tool["function"]
-                new_tool = {
-                    "toolSpec": {
-                        "name": function["name"],
-                        "description": function["description"],
-                        "inputSchema": {
-                            "json": {
-                                "type": "object",
-                                "properties": {},
-                                "required": function["parameters"].get("required", []),
-                            }
-                        },
-                    }
-                }
-
-                for prop, details in function["parameters"].get("properties", {}).items():
-                    new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = {
-                        "type": details.get("type", "string"),
-                        "description": details.get("description", ""),
-                    }
-
-                new_tools.append(new_tool)
-
-        return new_tools
-
-    def generate_response(
-        self,
-        messages: List[Dict[str, str]],
-        response_format=None,
-        tools: Optional[List[Dict]] = None,
-        tool_choice: str = "auto",
-    ):
+    def generate_response(self, messages: List[Dict[str, str]]) -> str:
         """
-        Generate a response based on the given messages using AWS Bedrock.
+        Generates a response using AWS Bedrock based on the provided messages.

         Args:
-            messages (list): List of message dicts containing 'role' and 'content'.
-            tools (list, optional): List of tools that the model can call. Defaults to None.
-            tool_choice (str, optional): Tool choice method. Defaults to "auto".
+            messages (List[Dict[str, str]]): List of message dictionaries containing 'role' and 'content'.

         Returns:
-            str: The generated response.
+            str: The generated response text.
         """
-        if tools:
-            # Use converse method when tools are provided
-            messages = [
-                {
-                    "role": "user",
-                    "content": [{"text": message["content"]} for message in messages],
-                }
-            ]
-            inference_config = {
-                "temperature": self.model_kwargs["temperature"],
-                "maxTokens": self.model_kwargs["max_tokens_to_sample"],
-                "topP": self.model_kwargs["top_p"],
-            }
-            tools_config = {"tools": self._convert_tool_format(tools)}
-
-            response = self.client.converse(
-                modelId=self.config.model,
-                messages=messages,
-                inferenceConfig=inference_config,
-                toolConfig=tools_config,
-            )
-        else:
-            # Use invoke_model method when no tools are provided
-            prompt = self._format_messages(messages)
-            provider = self.model.split(".")[0]
-            input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs)
-            body = json.dumps(input_body)
-
-            response = self.client.invoke_model(
-                body=body,
-                modelId=self.model,
-                accept="application/json",
-                contentType="application/json",
-            )
-
-        return self._parse_response(response, tools)
+        prompt = self._format_messages(messages)
+        provider = self.config.model.split(".")[0]
+        input_body = self._prepare_input(
+            provider, self.config.model, prompt, self.model_kwargs
+        )
+        body = json.dumps(input_body)

+        response = self.client.invoke_model(
+            body=body,
+            modelId=self.config.model,
+            accept="application/json",
+            contentType="application/json",
+        )

+        return self._parse_response(response)
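Besides dropping the converse/tool path, this hunk also fixes a mutable default argument: model_kwargs: Optional[Dict[str, Any]] = {} becomes None plus model_kwargs = model_kwargs or {}. A small standalone sketch of why that matters (function names here are illustrative, not from the codebase):

from typing import Any, Dict, Optional

def prepare_shared_default(prompt: str, model_kwargs: Dict[str, Any] = {}) -> Dict[str, Any]:
    # The default dict is created once and shared by every call, so mutations persist.
    model_kwargs.setdefault("temperature", 0.1)
    return {"prompt": prompt, **model_kwargs}

def prepare_fresh_default(prompt: str, model_kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    # A fresh dict per call; callers that pass nothing are unaffected by earlier calls.
    model_kwargs = model_kwargs or {}
    model_kwargs.setdefault("temperature", 0.1)
    return {"prompt": prompt, **model_kwargs}

prepare_shared_default("a")
print(prepare_shared_default.__defaults__)  # ({'temperature': 0.1},) - state leaked into the default
print(prepare_fresh_default("a"))           # {'prompt': 'a', 'temperature': 0.1}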

View File

@@ -9,17 +9,35 @@ from mem0.llms.base import LLMBase
 class AzureOpenAILLM(LLMBase):
+    """
+    A class for interacting with Azure OpenAI models using the specified configuration.
+    """
+
     def __init__(self, config: Optional[BaseLlmConfig] = None):
+        """
+        Initializes the AzureOpenAILLM instance with the given configuration.
+
+        Args:
+            config (Optional[BaseLlmConfig]): Configuration settings for the language model.
+        """
         super().__init__(config)

-        # Model name should match the custom deployment name chosen for it.
+        # Ensure model name is set; it should match the Azure OpenAI deployment name.
         if not self.config.model:
             self.config.model = "gpt-4o"

-        api_key = self.config.azure_kwargs.api_key or os.getenv("LLM_AZURE_OPENAI_API_KEY")
-        azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv("LLM_AZURE_DEPLOYMENT")
-        azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("LLM_AZURE_ENDPOINT")
-        api_version = self.config.azure_kwargs.api_version or os.getenv("LLM_AZURE_API_VERSION")
+        api_key = self.config.azure_kwargs.api_key or os.getenv(
+            "LLM_AZURE_OPENAI_API_KEY"
+        )
+        azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv(
+            "LLM_AZURE_DEPLOYMENT"
+        )
+        azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv(
+            "LLM_AZURE_ENDPOINT"
+        )
+        api_version = self.config.azure_kwargs.api_version or os.getenv(
+            "LLM_AZURE_API_VERSION"
+        )
         default_headers = self.config.azure_kwargs.default_headers

         self.client = AzureOpenAI(
@@ -31,54 +49,20 @@ class AzureOpenAILLM(LLMBase):
             default_headers=default_headers,
         )

-    def _parse_response(self, response, tools):
-        """
-        Process the response based on whether tools are used or not.
-
-        Args:
-            response: The raw response from API.
-            tools: The list of tools provided in the request.
-
-        Returns:
-            str or dict: The processed response.
-        """
-        if tools:
-            processed_response = {
-                "content": response.choices[0].message.content,
-                "tool_calls": [],
-            }
-
-            if response.choices[0].message.tool_calls:
-                for tool_call in response.choices[0].message.tool_calls:
-                    processed_response["tool_calls"].append(
-                        {
-                            "name": tool_call.function.name,
-                            "arguments": json.loads(tool_call.function.arguments),
-                        }
-                    )
-
-            return processed_response
-        else:
-            return response.choices[0].message.content
-
     def generate_response(
         self,
         messages: List[Dict[str, str]],
-        response_format=None,
-        tools: Optional[List[Dict]] = None,
-        tool_choice: str = "auto",
-    ):
+        response_format: Optional[str] = None,
+    ) -> str:
         """
-        Generate a response based on the given messages using Azure OpenAI.
+        Generates a response using Azure OpenAI based on the provided messages.

         Args:
-            messages (list): List of message dicts containing 'role' and 'content'.
-            response_format (str or object, optional): Format of the response. Defaults to "text".
-            tools (list, optional): List of tools that the model can call. Defaults to None.
-            tool_choice (str, optional): Tool choice method. Defaults to "auto".
+            messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
+            response_format (Optional[str]): The desired format of the response. Defaults to None.

         Returns:
-            str: The generated response.
+            str: The generated response from the model.
         """
         params = {
             "model": self.config.model,
@@ -87,11 +71,8 @@ class AzureOpenAILLM(LLMBase):
             "max_tokens": self.config.max_tokens,
             "top_p": self.config.top_p,
         }

         if response_format:
             params["response_format"] = response_format
-        if tools:  # TODO: Remove tools if no issues found with new memory addition logic
-            params["tools"] = tools
-            params["tool_choice"] = tool_choice

         response = self.client.chat.completions.create(**params)
-        return self._parse_response(response, tools)
+        return response.choices[0].message.content
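The response_format passthrough survives the cleanup on the OpenAI-compatible providers, so structured JSON output still works. A minimal call sketch, assuming an Azure deployment and the LLM_AZURE_* environment variables are configured (the deployment name is illustrative; note the new annotation says Optional[str], though the OpenAI SDK itself expects a dict such as {"type": "json_object"}):

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.azure_openai import AzureOpenAILLM

llm = AzureOpenAILLM(BaseLlmConfig(model="my-gpt-4o-deployment"))
messages = [
    {"role": "system", "content": "Reply in JSON."},
    {"role": "user", "content": "List two colors."},
]

# response_format is forwarded straight to chat.completions.create();
# tools/tool_choice are no longer accepted here.
reply = llm.generate_response(messages, response_format={"type": "json_object"})
print(reply)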

View File

@@ -9,20 +9,38 @@ from mem0.llms.base import LLMBase
 class AzureOpenAIStructuredLLM(LLMBase):
+    """
+    A class for interacting with Azure OpenAI models using the specified configuration.
+    """
+
     def __init__(self, config: Optional[BaseLlmConfig] = None):
+        """
+        Initializes the AzureOpenAIStructuredLLM instance with the given configuration.
+
+        Args:
+            config (Optional[BaseLlmConfig]): Configuration settings for the language model.
+        """
         super().__init__(config)

-        # Model name should match the custom deployment name chosen for it.
+        # Ensure model name is set; it should match the Azure OpenAI deployment name.
         if not self.config.model:
             self.config.model = "gpt-4o-2024-08-06"

-        api_key = os.getenv("LLM_AZURE_OPENAI_API_KEY") or self.config.azure_kwargs.api_key
-        azure_deployment = os.getenv("LLM_AZURE_DEPLOYMENT") or self.config.azure_kwargs.azure_deployment
-        azure_endpoint = os.getenv("LLM_AZURE_ENDPOINT") or self.config.azure_kwargs.azure_endpoint
-        api_version = os.getenv("LLM_AZURE_API_VERSION") or self.config.azure_kwargs.api_version
+        api_key = (
+            os.getenv("LLM_AZURE_OPENAI_API_KEY") or self.config.azure_kwargs.api_key
+        )
+        azure_deployment = (
+            os.getenv("LLM_AZURE_DEPLOYMENT")
+            or self.config.azure_kwargs.azure_deployment
+        )
+        azure_endpoint = (
+            os.getenv("LLM_AZURE_ENDPOINT") or self.config.azure_kwargs.azure_endpoint
+        )
+        api_version = (
+            os.getenv("LLM_AZURE_API_VERSION") or self.config.azure_kwargs.api_version
+        )
         default_headers = self.config.azure_kwargs.default_headers

-        # Can display a warning if API version is of model and api-version
         self.client = AzureOpenAI(
             azure_deployment=azure_deployment,
             azure_endpoint=azure_endpoint,
@@ -32,54 +50,20 @@ class AzureOpenAIStructuredLLM(LLMBase):
             default_headers=default_headers,
         )

-    def _parse_response(self, response, tools):
-        """
-        Process the response based on whether tools are used or not.
-
-        Args:
-            response: The raw response from API.
-            tools: The list of tools provided in the request.
-
-        Returns:
-            str or dict: The processed response.
-        """
-        if tools:
-            processed_response = {
-                "content": response.choices[0].message.content,
-                "tool_calls": [],
-            }
-
-            if response.choices[0].message.tool_calls:
-                for tool_call in response.choices[0].message.tool_calls:
-                    processed_response["tool_calls"].append(
-                        {
-                            "name": tool_call.function.name,
-                            "arguments": json.loads(tool_call.function.arguments),
-                        }
-                    )
-
-            return processed_response
-        else:
-            return response.choices[0].message.content
-
     def generate_response(
         self,
         messages: List[Dict[str, str]],
-        response_format=None,
-        tools: Optional[List[Dict]] = None,
-        tool_choice: str = "auto",
-    ):
+        response_format: Optional[str] = None,
+    ) -> str:
         """
-        Generate a response based on the given messages using Azure OpenAI.
+        Generates a response using Azure OpenAI based on the provided messages.

         Args:
-            messages (list): List of message dicts containing 'role' and 'content'.
-            response_format (str or object, optional): Format of the response. Defaults to "text".
-            tools (list, optional): List of tools that the model can call. Defaults to None.
-            tool_choice (str, optional): Tool choice method. Defaults to "auto".
+            messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
+            response_format (Optional[str]): The desired format of the response. Defaults to None.

         Returns:
-            str: The generated response.
+            str: The generated response from the model.
         """
         params = {
             "model": self.config.model,
@@ -88,11 +72,9 @@ class AzureOpenAIStructuredLLM(LLMBase):
             "max_tokens": self.config.max_tokens,
             "top_p": self.config.top_p,
         }

         if response_format:
             params["response_format"] = response_format
-        if tools:
-            params["tools"] = tools
-            params["tool_choice"] = tool_choice

         response = self.client.chat.completions.create(**params)
-        return self._parse_response(response, tools)
+        return response.choices[0].message.content

View File

@@ -4,8 +4,12 @@ from pydantic import BaseModel, Field, field_validator
 class LlmConfig(BaseModel):
-    provider: str = Field(description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai")
-    config: Optional[dict] = Field(description="Configuration for the specific LLM", default={})
+    provider: str = Field(
+        description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai"
+    )
+    config: Optional[dict] = Field(
+        description="Configuration for the specific LLM", default={}
+    )

     @field_validator("config")
     def validate_config(cls, v, values):

View File

@@ -9,64 +9,42 @@ from mem0.llms.base import LLMBase
 class DeepSeekLLM(LLMBase):
+    """
+    A class for interacting with DeepSeek's language models using the specified configuration.
+    """
+
     def __init__(self, config: Optional[BaseLlmConfig] = None):
+        """
+        Initializes the DeepSeekLLM instance with the given configuration.
+
+        Args:
+            config (Optional[BaseLlmConfig]): Configuration settings for the language model.
+        """
         super().__init__(config)

         if not self.config.model:
             self.config.model = "deepseek-chat"

         api_key = self.config.api_key or os.getenv("DEEPSEEK_API_KEY")
-        base_url = self.config.deepseek_base_url or os.getenv("DEEPSEEK_API_BASE") or "https://api.deepseek.com"
+        base_url = (
+            self.config.deepseek_base_url
+            or os.getenv("DEEPSEEK_API_BASE")
+            or "https://api.deepseek.com"
+        )
         self.client = OpenAI(api_key=api_key, base_url=base_url)

-    def _parse_response(self, response, tools):
-        """
-        Process the response based on whether tools are used or not.
-
-        Args:
-            response: The raw response from API.
-            tools: The list of tools provided in the request.
-
-        Returns:
-            str or dict: The processed response.
-        """
-        if tools:
-            processed_response = {
-                "content": response.choices[0].message.content,
-                "tool_calls": [],
-            }
-
-            if response.choices[0].message.tool_calls:
-                for tool_call in response.choices[0].message.tool_calls:
-                    processed_response["tool_calls"].append(
-                        {
-                            "name": tool_call.function.name,
-                            "arguments": json.loads(tool_call.function.arguments),
-                        }
-                    )
-
-            return processed_response
-        else:
-            return response.choices[0].message.content
-
     def generate_response(
         self,
         messages: List[Dict[str, str]],
-        response_format=None,
-        tools: Optional[List[Dict]] = None,
-        tool_choice: str = "auto",
-    ):
+    ) -> str:
         """
-        Generate a response based on the given messages using DeepSeek.
+        Generates a response using DeepSeek based on the provided messages.

         Args:
-            messages (list): List of message dicts containing 'role' and 'content'.
-            response_format (str or object, optional): Format of the response. Defaults to "text".
-            tools (list, optional): List of tools that the model can call. Defaults to None.
-            tool_choice (str, optional): Tool choice method. Defaults to "auto".
+            messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.

         Returns:
-            str: The generated response.
+            str: The generated response from the model.
         """
         params = {
             "model": self.config.model,
@@ -75,10 +53,5 @@ class DeepSeekLLM(LLMBase):
             "max_tokens": self.config.max_tokens,
             "top_p": self.config.top_p,
         }
-        if tools:
-            params["tools"] = tools
-            params["tool_choice"] = tool_choice

         response = self.client.chat.completions.create(**params)
-        return self._parse_response(response, tools)
+        return response.choices[0].message.content

View File

@@ -3,8 +3,7 @@ from typing import Dict, List, Optional
 try:
     import google.generativeai as genai
-    from google.generativeai import GenerativeModel, protos
-    from google.generativeai.types import content_types
+    from google.generativeai import GenerativeModel
 except ImportError:
     raise ImportError(
         "The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'."
@@ -15,7 +14,17 @@ from mem0.llms.base import LLMBase
 class GeminiLLM(LLMBase):
+    """
+    A wrapper for Google's Gemini language model, integrating it with the LLMBase class.
+    """
+
     def __init__(self, config: Optional[BaseLlmConfig] = None):
+        """
+        Initializes the Gemini LLM with the provided configuration.
+
+        Args:
+            config (Optional[BaseLlmConfig]): Configuration object for the model.
+        """
         super().__init__(config)

         if not self.config.model:
@@ -25,51 +34,25 @@ class GeminiLLM(LLMBase):
         genai.configure(api_key=api_key)
         self.client = GenerativeModel(model_name=self.config.model)

-    def _parse_response(self, response, tools):
-        """
-        Process the response based on whether tools are used or not.
-
-        Args:
-            response: The raw response from API.
-            tools: The list of tools provided in the request.
-
-        Returns:
-            str or dict: The processed response.
-        """
-        if tools:
-            processed_response = {
-                "content": (content if (content := response.candidates[0].content.parts[0].text) else None),
-                "tool_calls": [],
-            }
-
-            for part in response.candidates[0].content.parts:
-                if fn := part.function_call:
-                    if isinstance(fn, protos.FunctionCall):
-                        fn_call = type(fn).to_dict(fn)
-                        processed_response["tool_calls"].append({"name": fn_call["name"], "arguments": fn_call["args"]})
-                        continue
-                    processed_response["tool_calls"].append({"name": fn.name, "arguments": fn.args})
-
-            return processed_response
-        else:
-            return response.candidates[0].content.parts[0].text
-
-    def _reformat_messages(self, messages: List[Dict[str, str]]):
+    def _reformat_messages(
+        self, messages: List[Dict[str, str]]
+    ) -> List[Dict[str, str]]:
         """
-        Reformat messages for Gemini.
+        Reformats messages to match the Gemini API's expected structure.

         Args:
-            messages: The list of messages provided in the request.
+            messages (List[Dict[str, str]]): A list of messages with 'role' and 'content' keys.

         Returns:
-            list: The list of messages in the required format.
+            List[Dict[str, str]]: Reformatted messages in the required format.
         """
         new_messages = []

         for message in messages:
             if message["role"] == "system":
-                content = "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"]
+                content = (
+                    "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"]
+                )
             else:
                 content = message["content"]
@@ -82,90 +65,33 @@ class GeminiLLM(LLMBase):
         return new_messages

-    def _reformat_tools(self, tools: Optional[List[Dict]]):
-        """
-        Reformat tools for Gemini.
-
-        Args:
-            tools: The list of tools provided in the request.
-
-        Returns:
-            list: The list of tools in the required format.
-        """
-
-        def remove_additional_properties(data):
-            """Recursively removes 'additionalProperties' from nested dictionaries."""
-
-            if isinstance(data, dict):
-                filtered_dict = {
-                    key: remove_additional_properties(value)
-                    for key, value in data.items()
-                    if not (key == "additionalProperties")
-                }
-                return filtered_dict
-            else:
-                return data
-
-        new_tools = []
-        if tools:
-            for tool in tools:
-                func = tool["function"].copy()
-                new_tools.append({"function_declarations": [remove_additional_properties(func)]})
-
-            # TODO: temporarily ignore it to pass tests, will come back to update according to standards later.
-            # return content_types.to_function_library(new_tools)
-            return new_tools
-        else:
-            return None
-
     def generate_response(
-        self,
-        messages: List[Dict[str, str]],
-        response_format=None,
-        tools: Optional[List[Dict]] = None,
-        tool_choice: str = "auto",
-    ):
+        self, messages: List[Dict[str, str]], response_format: Optional[Dict] = None
+    ) -> str:
         """
-        Generate a response based on the given messages using Gemini.
+        Generates a response from Gemini based on the given conversation history.

         Args:
-            messages (list): List of message dicts containing 'role' and 'content'.
-            response_format (str or object, optional): Format for the response. Defaults to "text".
-            tools (list, optional): List of tools that the model can call. Defaults to None.
-            tool_choice (str, optional): Tool choice method. Defaults to "auto".
+            messages (List[Dict[str, str]]): List of message dictionaries containing 'role' and 'content'.
+            response_format (Optional[Dict]): Specifies the response format (e.g., JSON schema).

         Returns:
-            str: The generated response.
+            str: The generated response as text.
         """
         params = {
             "temperature": self.config.temperature,
             "max_output_tokens": self.config.max_tokens,
             "top_p": self.config.top_p,
         }

-        if response_format is not None and response_format["type"] == "json_object":
+        if response_format and response_format.get("type") == "json_object":
             params["response_mime_type"] = "application/json"
             if "schema" in response_format:
                 params["response_schema"] = response_format["schema"]
-        if tool_choice:
-            tool_config = content_types.to_tool_config(
-                {
-                    "function_calling_config": {
-                        "mode": tool_choice,
-                        "allowed_function_names": (
-                            [tool["function"]["name"] for tool in tools] if tool_choice == "any" else None
-                        ),
-                    }
-                }
-            )

         response = self.client.generate_content(
             contents=self._reformat_messages(messages),
-            tools=self._reformat_tools(tools),
             generation_config=genai.GenerationConfig(**params),
-            tool_config=tool_config,
         )

-        return self._parse_response(response, tools)
+        return response.candidates[0].content.parts[0].text
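Gemini keeps its own JSON handling: a response_format whose type is json_object is mapped onto response_mime_type (and response_schema when a "schema" key is present) in genai.GenerationConfig. A minimal usage sketch, assuming an API key is available via config.api_key or the environment and the repo's import layout (the model name is illustrative):

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.gemini import GeminiLLM

llm = GeminiLLM(BaseLlmConfig(model="gemini-1.5-flash"))
messages = [
    {"role": "system", "content": "You extract facts."},
    {"role": "user", "content": "I moved to Berlin last year."},
]

# {"type": "json_object"} becomes response_mime_type="application/json";
# a "schema" key, if supplied, would be passed as response_schema.
facts_json = llm.generate_response(messages, response_format={"type": "json_object"})
print(facts_json)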

View File

@@ -5,14 +5,26 @@ from typing import Dict, List, Optional
 try:
     from groq import Groq
 except ImportError:
-    raise ImportError("The 'groq' library is required. Please install it using 'pip install groq'.")
+    raise ImportError(
+        "The 'groq' library is required. Please install it using 'pip install groq'."
+    )

 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase


 class GroqLLM(LLMBase):
+    """
+    A class for interacting with Groq's language models using the specified configuration.
+    """
+
     def __init__(self, config: Optional[BaseLlmConfig] = None):
+        """
+        Initializes the GroqLLM instance with the given configuration.
+
+        Args:
+            config (Optional[BaseLlmConfig]): Configuration settings for the language model.
+        """
         super().__init__(config)

         if not self.config.model:
@@ -21,54 +33,20 @@ class GroqLLM(LLMBase):
         api_key = self.config.api_key or os.getenv("GROQ_API_KEY")
         self.client = Groq(api_key=api_key)

-    def _parse_response(self, response, tools):
-        """
-        Process the response based on whether tools are used or not.
-
-        Args:
-            response: The raw response from API.
-            tools: The list of tools provided in the request.
-
-        Returns:
-            str or dict: The processed response.
-        """
-        if tools:
-            processed_response = {
-                "content": response.choices[0].message.content,
-                "tool_calls": [],
-            }
-
-            if response.choices[0].message.tool_calls:
-                for tool_call in response.choices[0].message.tool_calls:
-                    processed_response["tool_calls"].append(
-                        {
-                            "name": tool_call.function.name,
-                            "arguments": json.loads(tool_call.function.arguments),
-                        }
-                    )
-
-            return processed_response
-        else:
-            return response.choices[0].message.content
-
     def generate_response(
         self,
         messages: List[Dict[str, str]],
-        response_format=None,
-        tools: Optional[List[Dict]] = None,
-        tool_choice: str = "auto",
-    ):
+        response_format: Optional[str] = None,
+    ) -> str:
         """
-        Generate a response based on the given messages using Groq.
+        Generates a response using Groq based on the provided messages.

         Args:
-            messages (list): List of message dicts containing 'role' and 'content'.
-            response_format (str or object, optional): Format of the response. Defaults to "text".
-            tools (list, optional): List of tools that the model can call. Defaults to None.
-            tool_choice (str, optional): Tool choice method. Defaults to "auto".
+            messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
+            response_format (Optional[str]): The desired format of the response. Defaults to None.

         Returns:
-            str: The generated response.
+            str: The generated response from the model.
         """
         params = {
             "model": self.config.model,
@@ -79,9 +57,5 @@ class GroqLLM(LLMBase):
         }
         if response_format:
             params["response_format"] = response_format
-        if tools:
-            params["tools"] = tools
-            params["tool_choice"] = tool_choice

         response = self.client.chat.completions.create(**params)
-        return self._parse_response(response, tools)
+        return response.choices[0].message.content

View File

@@ -4,70 +4,50 @@ from typing import Dict, List, Optional
 try:
     import litellm
 except ImportError:
-    raise ImportError("The 'litellm' library is required. Please install it using 'pip install litellm'.")
+    raise ImportError(
+        "The 'litellm' library is required. Please install it using 'pip install litellm'."
+    )

 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase


 class LiteLLM(LLMBase):
+    """
+    A class for interacting with LiteLLM's language models using the specified configuration.
+    """
+
     def __init__(self, config: Optional[BaseLlmConfig] = None):
+        """
+        Initializes the LiteLLM instance with the given configuration.
+
+        Args:
+            config (Optional[BaseLlmConfig]): Configuration settings for the language model.
+        """
         super().__init__(config)

         if not self.config.model:
             self.config.model = "gpt-4o-mini"

-    def _parse_response(self, response, tools):
-        """
-        Process the response based on whether tools are used or not.
-
-        Args:
-            response: The raw response from API.
-            tools: The list of tools provided in the request.
-
-        Returns:
-            str or dict: The processed response.
-        """
-        if tools:
-            processed_response = {
-                "content": response.choices[0].message.content,
-                "tool_calls": [],
-            }
-
-            if response.choices[0].message.tool_calls:
-                for tool_call in response.choices[0].message.tool_calls:
-                    processed_response["tool_calls"].append(
-                        {
-                            "name": tool_call.function.name,
-                            "arguments": json.loads(tool_call.function.arguments),
-                        }
-                    )
-
-            return processed_response
-        else:
-            return response.choices[0].message.content
-
     def generate_response(
         self,
         messages: List[Dict[str, str]],
-        response_format=None,
-        tools: Optional[List[Dict]] = None,
-        tool_choice: str = "auto",
-    ):
+        response_format: Optional[str] = None,
+    ) -> str:
         """
-        Generate a response based on the given messages using Litellm.
+        Generates a response using LiteLLM based on the provided messages.

         Args:
-            messages (list): List of message dicts containing 'role' and 'content'.
-            response_format (str or object, optional): Format of the response. Defaults to "text".
-            tools (list, optional): List of tools that the model can call. Defaults to None.
-            tool_choice (str, optional): Tool choice method. Defaults to "auto".
+            messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
+            response_format (Optional[str]): The desired format of the response. Defaults to None.

         Returns:
-            str: The generated response.
+            str: The generated response from the model.
         """
         if not litellm.supports_function_calling(self.config.model):
-            raise ValueError(f"Model '{self.config.model}' in litellm does not support function calling.")
+            raise ValueError(
+                f"Model '{self.config.model}' in LiteLLM does not support function calling."
+            )

         params = {
             "model": self.config.model,
@@ -78,9 +58,6 @@ class LiteLLM(LLMBase):
         }
         if response_format:
             params["response_format"] = response_format
-        if tools:  # TODO: Remove tools if no issues found with new memory addition logic
-            params["tools"] = tools
-            params["tool_choice"] = tool_choice

         response = litellm.completion(**params)
-        return self._parse_response(response, tools)
+        return response.choices[0].message.content

View File

@@ -3,77 +3,56 @@ from typing import Dict, List, Optional
 try:
     from ollama import Client
 except ImportError:
-    raise ImportError("The 'ollama' library is required. Please install it using 'pip install ollama'.")
+    raise ImportError(
+        "The 'ollama' library is required. Please install it using 'pip install ollama'."
+    )

 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase


 class OllamaLLM(LLMBase):
+    """
+    A class for interacting with Ollama's language models using the specified configuration.
+    """
+
     def __init__(self, config: Optional[BaseLlmConfig] = None):
+        """
+        Initializes the OllamaLLM instance with the given configuration.
+
+        Args:
+            config (Optional[BaseLlmConfig]): Configuration settings for the language model.
+        """
         super().__init__(config)

         if not self.config.model:
             self.config.model = "llama3.1:70b"

         self.client = Client(host=self.config.ollama_base_url)
         self._ensure_model_exists()

     def _ensure_model_exists(self):
         """
-        Ensure the specified model exists locally. If not, pull it from Ollama.
+        Ensures the specified model exists locally. If not, pulls it from Ollama.
         """
         local_models = self.client.list()["models"]
         if not any(model.get("name") == self.config.model for model in local_models):
             self.client.pull(self.config.model)

-    def _parse_response(self, response, tools):
-        """
-        Process the response based on whether tools are used or not.
-
-        Args:
-            response: The raw response from API.
-            tools: The list of tools provided in the request.
-
-        Returns:
-            str or dict: The processed response.
-        """
-        if tools:
-            processed_response = {
-                "content": response["message"]["content"],
-                "tool_calls": [],
-            }
-
-            if response["message"].get("tool_calls"):
-                for tool_call in response["message"]["tool_calls"]:
-                    processed_response["tool_calls"].append(
-                        {
-                            "name": tool_call["function"]["name"],
-                            "arguments": tool_call["function"]["arguments"],
-                        }
-                    )
-
-            return processed_response
-        else:
-            return response["message"]["content"]
-
     def generate_response(
         self,
         messages: List[Dict[str, str]],
-        response_format=None,
-        tools: Optional[List[Dict]] = None,
-        tool_choice: str = "auto",
-    ):
+        response_format: Optional[str] = None,
+    ) -> str:
         """
-        Generate a response based on the given messages using OpenAI.
+        Generates a response using Ollama based on the provided messages.

         Args:
-            messages (list): List of message dicts containing 'role' and 'content'.
-            response_format (str or object, optional): Format of the response. Defaults to "text".
-            tools (list, optional): List of tools that the model can call. Defaults to None.
-            tool_choice (str, optional): Tool choice method. Defaults to "auto".
+            messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
+            response_format (Optional[str]): The desired format of the response. Defaults to None.

         Returns:
-            str: The generated response.
+            str: The generated response from the model.
         """
         params = {
             "model": self.config.model,
@@ -87,8 +66,5 @@ class OllamaLLM(LLMBase):
         if response_format:
             params["format"] = "json"
-        if tools:
-            params["tools"] = tools

         response = self.client.chat(**params)
-        return self._parse_response(response, tools)
+        return response["message"]["content"]
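Ollama differs slightly from the OpenAI-style providers: any truthy response_format is collapsed to format="json" on the chat call, and the result is always response["message"]["content"]. A minimal sketch, assuming a local Ollama server and that the model tag is already pulled (the tag shown is illustrative):

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.ollama import OllamaLLM

llm = OllamaLLM(BaseLlmConfig(model="llama3.1:8b"))
messages = [{"role": "user", "content": "Return a JSON object with a single key 'ok'."}]

# Any response_format value switches the chat call to format="json";
# the return value is still the message content as a string.
print(llm.generate_response(messages, response_format={"type": "json_object"}))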

View File

@@ -9,7 +9,17 @@ from mem0.llms.base import LLMBase
 class OpenAILLM(LLMBase):
+    """
+    A class to interact with OpenAI or OpenRouter APIs for generating responses using LLMs.
+    """
+
     def __init__(self, config: Optional[BaseLlmConfig] = None):
+        """
+        Initializes the OpenAILLM instance.
+
+        Args:
+            config (Optional[BaseLlmConfig]): Configuration for the LLM, including model, API key, and base URLs.
+        """
         super().__init__(config)

         if not self.config.model:
@@ -24,57 +34,27 @@ class OpenAILLM(LLMBase):
             )
         else:
             api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
-            base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE") or "https://api.openai.com/v1"
+            base_url = (
+                self.config.openai_base_url
+                or os.getenv("OPENAI_API_BASE")
+                or "https://api.openai.com/v1"
+            )
             self.client = OpenAI(api_key=api_key, base_url=base_url)

-    def _parse_response(self, response, tools):
-        """
-        Process the response based on whether tools are used or not.
-
-        Args:
-            response: The raw response from API.
-            tools: The list of tools provided in the request.
-
-        Returns:
-            str or dict: The processed response.
-        """
-        if tools:
-            processed_response = {
-                "content": response.choices[0].message.content,
-                "tool_calls": [],
-            }
-
-            if response.choices[0].message.tool_calls:
-                for tool_call in response.choices[0].message.tool_calls:
-                    processed_response["tool_calls"].append(
-                        {
-                            "name": tool_call.function.name,
-                            "arguments": json.loads(tool_call.function.arguments),
-                        }
-                    )
-
-            return processed_response
-        else:
-            return response.choices[0].message.content
-
     def generate_response(
         self,
         messages: List[Dict[str, str]],
-        response_format=None,
-        tools: Optional[List[Dict]] = None,
-        tool_choice: str = "auto",
-    ):
+        response_format: Optional[str] = None,
+    ) -> str:
         """
-        Generate a response based on the given messages using OpenAI.
+        Generates a response based on the provided messages using OpenAI or OpenRouter.

         Args:
-            messages (list): List of message dicts containing 'role' and 'content'.
-            response_format (str or object, optional): Format of the response. Defaults to "text".
-            tools (list, optional): List of tools that the model can call. Defaults to None.
-            tool_choice (str, optional): Tool choice method. Defaults to "auto".
+            messages (List[Dict[str, str]]): A list of message dictionaries containing 'role' and 'content'.
+            response_format (Optional[str]): The format of the response. Defaults to None.

         Returns:
-            str: The generated response.
+            str: The generated response from the model.
         """
         params = {
             "model": self.config.model,
@@ -102,9 +82,6 @@ class OpenAILLM(LLMBase):
         if response_format:
             params["response_format"] = response_format
-        if tools:  # TODO: Remove tools if no issues found with new memory addition logic
-            params["tools"] = tools
-            params["tool_choice"] = tool_choice

         response = self.client.chat.completions.create(**params)
-        return self._parse_response(response, tools)
+        return response.choices[0].message.content

View File

@@ -9,66 +9,45 @@ from mem0.llms.base import LLMBase
 class OpenAIStructuredLLM(LLMBase):
+    """
+    A class for interacting with OpenAI's structured language models using the specified configuration.
+    """
+
     def __init__(self, config: Optional[BaseLlmConfig] = None):
+        """
+        Initializes the OpenAIStructuredLLM instance with the given configuration.
+
+        Args:
+            config (Optional[BaseLlmConfig]): Configuration settings for the language model.
+        """
         super().__init__(config)

         if not self.config.model:
             self.config.model = "gpt-4o-2024-08-06"

         api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
-        base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE") or "https://api.openai.com/v1"
+        base_url = (
+            self.config.openai_base_url
+            or os.getenv("OPENAI_API_BASE")
+            or "https://api.openai.com/v1"
+        )
         self.client = OpenAI(api_key=api_key, base_url=base_url)

-    def _parse_response(self, response, tools):
-        """
-        Process the response based on whether tools are used or not.
-
-        Args:
-            response: The raw response from API.
-            tools (list, optional): List of tools that the model can call.
-
-        Returns:
-            str or dict: The processed response.
-        """
-        if tools:
-            processed_response = {
-                "content": response.choices[0].message.content,
-                "tool_calls": [],
-            }
-
-            if response.choices[0].message.tool_calls:
-                for tool_call in response.choices[0].message.tool_calls:
-                    processed_response["tool_calls"].append(
-                        {
-                            "name": tool_call.function.name,
-                            "arguments": json.loads(tool_call.function.arguments),
-                        }
-                    )
-
-            return processed_response
-        else:
-            return response.choices[0].message.content
-
     def generate_response(
         self,
         messages: List[Dict[str, str]],
-        response_format=None,
-        tools: Optional[List[Dict]] = None,
-        tool_choice: str = "auto",
-    ):
+        response_format: Optional[str] = None,
+    ) -> str:
         """
-        Generate a response based on the given messages using OpenAI.
+        Generates a response using OpenAI based on the provided messages.

         Args:
-            messages (list): List of message dicts containing 'role' and 'content'.
-            response_format (str or object, optional): Format of the response. Defaults to "text".
-            tools (list, optional): List of tools that the model can call. Defaults to None.
-            tool_choice (str, optional): Tool choice method. Defaults to "auto".
+            messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
+            response_format (Optional[str]): The desired format of the response. Defaults to None.

         Returns:
-            str: The generated response.
+            str: The generated response from the model.
         """
         params = {
             "model": self.config.model,
@@ -78,10 +57,6 @@ class OpenAIStructuredLLM(LLMBase):
         if response_format:
             params["response_format"] = response_format
-        if tools:
-            params["tools"] = tools
-            params["tool_choice"] = tool_choice

         response = self.client.beta.chat.completions.parse(**params)
-        return self._parse_response(response, tools)
+        return response.choices[0].message.content

View File

@@ -5,14 +5,26 @@ from typing import Dict, List, Optional
 try:
     from together import Together
 except ImportError:
-    raise ImportError("The 'together' library is required. Please install it using 'pip install together'.")
+    raise ImportError(
+        "The 'together' library is required. Please install it using 'pip install together'."
+    )

 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase


 class TogetherLLM(LLMBase):
+    """
+    A class for interacting with the TogetherAI language model using the specified configuration.
+    """
+
     def __init__(self, config: Optional[BaseLlmConfig] = None):
+        """
+        Initializes the TogetherLLM instance with the given configuration.
+
+        Args:
+            config (Optional[BaseLlmConfig]): Configuration settings for the language model.
+        """
         super().__init__(config)

         if not self.config.model:
@@ -21,54 +33,20 @@ class TogetherLLM(LLMBase):
         api_key = self.config.api_key or os.getenv("TOGETHER_API_KEY")
         self.client = Together(api_key=api_key)

-    def _parse_response(self, response, tools):
-        """
-        Process the response based on whether tools are used or not.
-
-        Args:
-            response: The raw response from API.
-            tools: The list of tools provided in the request.
-
-        Returns:
-            str or dict: The processed response.
-        """
-        if tools:
-            processed_response = {
-                "content": response.choices[0].message.content,
-                "tool_calls": [],
-            }
-
-            if response.choices[0].message.tool_calls:
-                for tool_call in response.choices[0].message.tool_calls:
-                    processed_response["tool_calls"].append(
-                        {
-                            "name": tool_call.function.name,
-                            "arguments": json.loads(tool_call.function.arguments),
-                        }
-                    )
-
-            return processed_response
-        else:
-            return response.choices[0].message.content
-
     def generate_response(
         self,
         messages: List[Dict[str, str]],
-        response_format=None,
-        tools: Optional[List[Dict]] = None,
-        tool_choice: str = "auto",
-    ):
+        response_format: Optional[str] = None,
+    ) -> str:
         """
-        Generate a response based on the given messages using TogetherAI.
+        Generates a response using TogetherAI based on the provided messages.

         Args:
-            messages (list): List of message dicts containing 'role' and 'content'.
-            response_format (str or object, optional): Format of the response. Defaults to "text".
-            tools (list, optional): List of tools that the model can call. Defaults to None.
-            tool_choice (str, optional): Tool choice method. Defaults to "auto".
+            messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
+            response_format (Optional[str]): The desired format of the response. Defaults to None.

         Returns:
-            str: The generated response.
+            str: The generated response from the model.
         """
         params = {
             "model": self.config.model,
@@ -79,9 +57,6 @@ class TogetherLLM(LLMBase):
         }
         if response_format:
             params["response_format"] = response_format
-        if tools:  # TODO: Remove tools if no issues found with new memory addition logic
-            params["tools"] = tools
-            params["tool_choice"] = tool_choice

         response = self.client.chat.completions.create(**params)
-        return self._parse_response(response, tools)
+        return response.choices[0].message.content

View File

@@ -15,7 +15,11 @@ class XAILLM(LLMBase):
             self.config.model = "grok-2-latest"

         api_key = self.config.api_key or os.getenv("XAI_API_KEY")
-        base_url = self.config.xai_base_url or os.getenv("XAI_API_BASE") or "https://api.x.ai/v1"
+        base_url = (
+            self.config.xai_base_url
+            or os.getenv("XAI_API_BASE")
+            or "https://api.x.ai/v1"
+        )
         self.client = OpenAI(api_key=api_key, base_url=base_url)

     def generate_response(self, messages: List[Dict[str, str]], response_format=None):

View File

@@ -20,8 +20,10 @@ def mock_openai_client():
     yield mock_client


-def test_generate_response_without_tools(mock_openai_client):
-    config = BaseLlmConfig(model=MODEL, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P)
+def test_generate_response(mock_openai_client):
+    config = BaseLlmConfig(
+        model=MODEL, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P
+    )
     llm = AzureOpenAILLM(config)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
@@ -29,67 +31,21 @@ def test_generate_response_without_tools(mock_openai_client):
     ]

     mock_response = Mock()
-    mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
+    mock_response.choices = [
+        Mock(message=Mock(content="I'm doing well, thank you for asking!"))
+    ]
     mock_openai_client.chat.completions.create.return_value = mock_response

     response = llm.generate_response(messages)

     mock_openai_client.chat.completions.create.assert_called_once_with(
-        model=MODEL, messages=messages, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P
-    )
-    assert response == "I'm doing well, thank you for asking!"
-
-
-def test_generate_response_with_tools(mock_openai_client):
-    config = BaseLlmConfig(model=MODEL, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P)
-    llm = AzureOpenAILLM(config)
-    messages = [
-        {"role": "system", "content": "You are a helpful assistant."},
-        {"role": "user", "content": "Add a new memory: Today is a sunny day."},
-    ]
-    tools = [
-        {
-            "type": "function",
-            "function": {
-                "name": "add_memory",
-                "description": "Add a memory",
-                "parameters": {
-                    "type": "object",
-                    "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
-                    "required": ["data"],
-                },
-            },
-        }
-    ]
-
-    mock_response = Mock()
-    mock_message = Mock()
-    mock_message.content = "I've added the memory for you."
-
-    mock_tool_call = Mock()
-    mock_tool_call.function.name = "add_memory"
-    mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
-
-    mock_message.tool_calls = [mock_tool_call]
-    mock_response.choices = [Mock(message=mock_message)]
-    mock_openai_client.chat.completions.create.return_value = mock_response
-
-    response = llm.generate_response(messages, tools=tools)
-
-    mock_openai_client.chat.completions.create.assert_called_once_with(
         model=MODEL,
         messages=messages,
         temperature=TEMPERATURE,
         max_tokens=MAX_TOKENS,
         top_p=TOP_P,
-        tools=tools,
-        tool_choice="auto",
     )
-
-    assert response["content"] == "I've added the memory for you."
-    assert len(response["tool_calls"]) == 1
-    assert response["tool_calls"][0]["name"] == "add_memory"
-    assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
+    assert response == "I'm doing well, thank you for asking!"


 @pytest.mark.parametrize(
@@ -128,4 +84,6 @@ def test_generate_with_http_proxies(default_headers):
         api_version=None,
         default_headers=default_headers,
     )
-    mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000")
+    mock_http_client.assert_called_once_with(
+        proxies="http://testproxy.mem0.net:8000"
+    )

View File

@@ -16,33 +16,47 @@ def mock_deepseek_client():
def test_deepseek_llm_base_url():
# case1: default config with deepseek official base url
config = BaseLlmConfig(
model="deepseek-chat",
temperature=0.7,
max_tokens=100,
top_p=1.0,
api_key="api_key",
)
llm = DeepSeekLLM(config)
assert str(llm.client.base_url) == "https://api.deepseek.com"
# case2: with env variable DEEPSEEK_API_BASE
provider_base_url = "https://api.provider.com/v1/"
os.environ["DEEPSEEK_API_BASE"] = provider_base_url
config = BaseLlmConfig(
model="deepseek-chat",
temperature=0.7,
max_tokens=100,
top_p=1.0,
api_key="api_key",
)
llm = DeepSeekLLM(config)
assert str(llm.client.base_url) == provider_base_url
# case3: with config.deepseek_base_url
config_base_url = "https://api.config.com/v1/"
config = BaseLlmConfig(
model="deepseek-chat",
temperature=0.7,
max_tokens=100,
top_p=1.0,
api_key="api_key",
deepseek_base_url=config_base_url,
)
llm = DeepSeekLLM(config)
assert str(llm.client.base_url) == config_base_url
def test_generate_response(mock_deepseek_client):
config = BaseLlmConfig(
model="deepseek-chat", temperature=0.7, max_tokens=100, top_p=1.0
)
llm = DeepSeekLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
@@ -50,64 +64,18 @@ def test_generate_response_without_tools(mock_deepseek_client):
]
mock_response = Mock()
mock_response.choices = [
Mock(message=Mock(content="I'm doing well, thank you for asking!"))
]
mock_deepseek_client.chat.completions.create.return_value = mock_response
response = llm.generate_response(messages)
mock_deepseek_client.chat.completions.create.assert_called_once_with(
model="deepseek-chat",
messages=messages,
temperature=0.7,
max_tokens=100,
top_p=1.0,
)
assert response == "I'm doing well, thank you for asking!"
def test_generate_response_with_tools(mock_deepseek_client):
config = BaseLlmConfig(model="deepseek-chat", temperature=0.7, max_tokens=100, top_p=1.0)
llm = DeepSeekLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add a new memory: Today is a sunny day."},
]
tools = [
{
"type": "function",
"function": {
"name": "add_memory",
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
},
}
]
mock_response = Mock()
mock_message = Mock()
mock_message.content = "I've added the memory for you."
mock_tool_call = Mock()
mock_tool_call.function.name = "add_memory"
mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
mock_message.tool_calls = [mock_tool_call]
mock_response.choices = [Mock(message=mock_message)]
mock_deepseek_client.chat.completions.create.return_value = mock_response
response = llm.generate_response(messages, tools=tools)
mock_deepseek_client.chat.completions.create.assert_called_once_with(
model="deepseek-chat",
messages=messages,
temperature=0.7,
max_tokens=100,
top_p=1.0,
tools=tools,
tool_choice="auto"
)
assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
View File
@@ -17,7 +17,9 @@ def mock_gemini_client():
def test_generate_response_without_tools(mock_gemini_client: Mock):
config = BaseLlmConfig(
model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0
)
llm = GeminiLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
@@ -34,86 +36,14 @@ def test_generate_response_without_tools(mock_gemini_client: Mock):
mock_gemini_client.generate_content.assert_called_once_with(
contents=[
{
"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.",
"role": "user",
},
{"parts": "Hello, how are you?", "role": "user"},
],
generation_config=GenerationConfig(
temperature=0.7, max_output_tokens=100, top_p=1.0
),
tools=None,
tool_config=content_types.to_tool_config(
{"function_calling_config": {"mode": "auto", "allowed_function_names": None}}
),
)
assert response == "I'm doing well, thank you for asking!"
def test_generate_response_with_tools(mock_gemini_client: Mock):
config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0)
llm = GeminiLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add a new memory: Today is a sunny day."},
]
tools = [
{
"type": "function",
"function": {
"name": "add_memory",
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
},
}
]
mock_tool_call = Mock()
mock_tool_call.name = "add_memory"
mock_tool_call.args = {"data": "Today is a sunny day."}
mock_part = Mock()
mock_part.function_call = mock_tool_call
mock_part.text = "I've added the memory for you."
mock_content = Mock()
mock_content.parts = [mock_part]
mock_message = Mock()
mock_message.content = mock_content
mock_response = Mock(candidates=[mock_message])
mock_gemini_client.generate_content.return_value = mock_response
response = llm.generate_response(messages, tools=tools)
mock_gemini_client.generate_content.assert_called_once_with(
contents=[
{"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"},
{"parts": "Add a new memory: Today is a sunny day.", "role": "user"},
],
generation_config=GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0),
tools=[
{
"function_declarations": [
{
"name": "add_memory",
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
}
]
}
],
tool_config=content_types.to_tool_config(
{"function_calling_config": {"mode": "auto", "allowed_function_names": None}}
),
)
assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
View File
@@ -14,8 +14,10 @@ def mock_groq_client():
yield mock_client
def test_generate_response(mock_groq_client):
config = BaseLlmConfig(
model="llama3-70b-8192", temperature=0.7, max_tokens=100, top_p=1.0
)
llm = GroqLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
@@ -23,64 +25,18 @@ def test_generate_response_without_tools(mock_groq_client):
]
mock_response = Mock()
mock_response.choices = [
Mock(message=Mock(content="I'm doing well, thank you for asking!"))
]
mock_groq_client.chat.completions.create.return_value = mock_response
response = llm.generate_response(messages)
mock_groq_client.chat.completions.create.assert_called_once_with(
model="llama3-70b-8192", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0
)
assert response == "I'm doing well, thank you for asking!"
def test_generate_response_with_tools(mock_groq_client):
config = BaseLlmConfig(model="llama3-70b-8192", temperature=0.7, max_tokens=100, top_p=1.0)
llm = GroqLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add a new memory: Today is a sunny day."},
]
tools = [
{
"type": "function",
"function": {
"name": "add_memory",
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
},
}
]
mock_response = Mock()
mock_message = Mock()
mock_message.content = "I've added the memory for you."
mock_tool_call = Mock()
mock_tool_call.function.name = "add_memory"
mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
mock_message.tool_calls = [mock_tool_call]
mock_response.choices = [Mock(message=mock_message)]
mock_groq_client.chat.completions.create.return_value = mock_response
response = llm.generate_response(messages, tools=tools)
mock_groq_client.chat.completions.create.assert_called_once_with(
model="llama3-70b-8192",
messages=messages,
temperature=0.7,
max_tokens=100,
top_p=1.0,
tools=tools,
tool_choice="auto",
)
assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
View File
@@ -13,17 +13,22 @@ def mock_litellm():
def test_generate_response_with_unsupported_model(mock_litellm):
config = BaseLlmConfig(
model="unsupported-model", temperature=0.7, max_tokens=100, top_p=1
)
llm = litellm.LiteLLM(config)
messages = [{"role": "user", "content": "Hello"}]
mock_litellm.supports_function_calling.return_value = False
with pytest.raises(
ValueError,
match="Model 'unsupported-model' in LiteLLM does not support function calling.",
):
llm.generate_response(messages)
def test_generate_response(mock_litellm):
config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1)
llm = litellm.LiteLLM(config)
messages = [
@@ -32,7 +37,9 @@ def test_generate_response_without_tools(mock_litellm):
]
mock_response = Mock()
mock_response.choices = [
Mock(message=Mock(content="I'm doing well, thank you for asking!"))
]
mock_litellm.completion.return_value = mock_response
mock_litellm.supports_function_calling.return_value = True
@@ -42,50 +49,3 @@ def test_generate_response_without_tools(mock_litellm):
model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0 model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0
) )
assert response == "I'm doing well, thank you for asking!" assert response == "I'm doing well, thank you for asking!"
def test_generate_response_with_tools(mock_litellm):
config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1)
llm = litellm.LiteLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add a new memory: Today is a sunny day."},
]
tools = [
{
"type": "function",
"function": {
"name": "add_memory",
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
},
}
]
mock_response = Mock()
mock_message = Mock()
mock_message.content = "I've added the memory for you."
mock_tool_call = Mock()
mock_tool_call.function.name = "add_memory"
mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
mock_message.tool_calls = [mock_tool_call]
mock_response.choices = [Mock(message=mock_message)]
mock_litellm.completion.return_value = mock_response
mock_litellm.supports_function_calling.return_value = True
response = llm.generate_response(messages, tools=tools)
mock_litellm.completion.assert_called_once_with(
model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1, tools=tools, tool_choice="auto"
)
assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
View File
@@ -16,7 +16,9 @@ def mock_openai_client():
def test_openai_llm_base_url():
# case1: default config: with openai official base url
config = BaseLlmConfig(
model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key"
)
llm = OpenAILLM(config)
# Note: openai client will parse the raw base_url into a URL object, which will have a trailing slash
assert str(llm.client.base_url) == "https://api.openai.com/v1/"
@@ -24,7 +26,9 @@ def test_openai_llm_base_url():
# case2: with env variable OPENAI_API_BASE
provider_base_url = "https://api.provider.com/v1"
os.environ["OPENAI_API_BASE"] = provider_base_url
config = BaseLlmConfig(
model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key"
)
llm = OpenAILLM(config)
# Note: openai client will parse the raw base_url into a URL object, which will have a trailing slash
assert str(llm.client.base_url) == provider_base_url + "/"
@@ -32,14 +36,19 @@ def test_openai_llm_base_url():
# case3: with config.openai_base_url
config_base_url = "https://api.config.com/v1"
config = BaseLlmConfig(
model="gpt-4o",
temperature=0.7,
max_tokens=100,
top_p=1.0,
api_key="api_key",
openai_base_url=config_base_url,
)
llm = OpenAILLM(config)
# Note: openai client will parse the raw base_url into a URL object, which will have a trailing slash
assert str(llm.client.base_url) == config_base_url + "/"
def test_generate_response(mock_openai_client):
config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0)
llm = OpenAILLM(config)
messages = [
@@ -48,7 +57,9 @@ def test_generate_response_without_tools(mock_openai_client):
]
mock_response = Mock()
mock_response.choices = [
Mock(message=Mock(content="I'm doing well, thank you for asking!"))
]
mock_openai_client.chat.completions.create.return_value = mock_response
response = llm.generate_response(messages)
@@ -57,49 +68,3 @@ def test_generate_response_without_tools(mock_openai_client):
model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0 model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0
) )
assert response == "I'm doing well, thank you for asking!" assert response == "I'm doing well, thank you for asking!"
def test_generate_response_with_tools(mock_openai_client):
config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0)
llm = OpenAILLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add a new memory: Today is a sunny day."},
]
tools = [
{
"type": "function",
"function": {
"name": "add_memory",
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
},
}
]
mock_response = Mock()
mock_message = Mock()
mock_message.content = "I've added the memory for you."
mock_tool_call = Mock()
mock_tool_call.function.name = "add_memory"
mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
mock_message.tool_calls = [mock_tool_call]
mock_response.choices = [Mock(message=mock_message)]
mock_openai_client.chat.completions.create.return_value = mock_response
response = llm.generate_response(messages, tools=tools)
mock_openai_client.chat.completions.create.assert_called_once_with(
model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0, tools=tools, tool_choice="auto"
)
assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
View File
@@ -14,8 +14,13 @@ def mock_together_client():
yield mock_client
def test_generate_response(mock_together_client):
config = BaseLlmConfig(
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
temperature=0.7,
max_tokens=100,
top_p=1.0,
)
llm = TogetherLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
@@ -23,64 +28,18 @@ def test_generate_response_without_tools(mock_together_client):
]
mock_response = Mock()
mock_response.choices = [
Mock(message=Mock(content="I'm doing well, thank you for asking!"))
]
mock_together_client.chat.completions.create.return_value = mock_response
response = llm.generate_response(messages)
mock_together_client.chat.completions.create.assert_called_once_with(
model="mistralai/Mixtral-8x7B-Instruct-v0.1", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0
)
assert response == "I'm doing well, thank you for asking!"
def test_generate_response_with_tools(mock_together_client):
config = BaseLlmConfig(model="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0.7, max_tokens=100, top_p=1.0)
llm = TogetherLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add a new memory: Today is a sunny day."},
]
tools = [
{
"type": "function",
"function": {
"name": "add_memory",
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
},
}
]
mock_response = Mock()
mock_message = Mock()
mock_message.content = "I've added the memory for you."
mock_tool_call = Mock()
mock_tool_call.function.name = "add_memory"
mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
mock_message.tool_calls = [mock_tool_call]
mock_response.choices = [Mock(message=mock_message)]
mock_together_client.chat.completions.create.return_value = mock_response
response = llm.generate_response(messages, tools=tools)
mock_together_client.chat.completions.create.assert_called_once_with(
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
messages=messages,
temperature=0.7,
max_tokens=100,
top_p=1.0,
tools=tools,
tool_choice="auto",
)
assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}