From ee66e0c95423a359214d20e6acfbe1878501c668 Mon Sep 17 00:00:00 2001 From: Parshva Daftari <89991302+parshvadaftari@users.noreply.github.com> Date: Thu, 20 Mar 2025 00:09:00 +0530 Subject: [PATCH] Reverting the `tools` commit (#2404) --- mem0/llms/anthropic.py | 33 +++-- mem0/llms/aws_bedrock.py | 176 +++++++++++++++++++-------- mem0/llms/azure_openai.py | 81 +++++++----- mem0/llms/azure_openai_structured.py | 70 ++--------- mem0/llms/configs.py | 8 +- mem0/llms/deepseek.py | 67 +++++++--- mem0/llms/gemini.py | 132 +++++++++++++++----- mem0/llms/groq.py | 66 +++++++--- mem0/llms/litellm.py | 69 +++++++---- mem0/llms/ollama.py | 68 +++++++---- mem0/llms/openai.py | 67 ++++++---- mem0/llms/openai_structured.py | 57 +-------- mem0/llms/together.py | 65 +++++++--- mem0/llms/xai.py | 6 +- tests/llms/test_azure_openai.py | 64 ++++++++-- tests/llms/test_deepseek.py | 96 ++++++++++----- tests/llms/test_gemini_llm.py | 88 ++++++++++++-- tests/llms/test_groq.py | 60 +++++++-- tests/llms/test_litellm.py | 62 ++++++++-- tests/llms/test_openai.py | 67 +++++++--- tests/llms/test_together.py | 63 ++++++++-- 21 files changed, 990 insertions(+), 475 deletions(-) diff --git a/mem0/llms/anthropic.py b/mem0/llms/anthropic.py index 48a6da95..5f004ae8 100644 --- a/mem0/llms/anthropic.py +++ b/mem0/llms/anthropic.py @@ -4,26 +4,14 @@ from typing import Dict, List, Optional try: import anthropic except ImportError: - raise ImportError( - "The 'anthropic' library is required. Please install it using 'pip install anthropic'." - ) + raise ImportError("The 'anthropic' library is required. Please install it using 'pip install anthropic'.") from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.base import LLMBase class AnthropicLLM(LLMBase): - """ - A class for interacting with Anthropic's Claude models using the specified configuration. - """ - def __init__(self, config: Optional[BaseLlmConfig] = None): - """ - Initializes the AnthropicLLM instance with the given configuration. - - Args: - config (Optional[BaseLlmConfig]): Configuration settings for the language model. - """ super().__init__(config) if not self.config.model: @@ -35,17 +23,23 @@ class AnthropicLLM(LLMBase): def generate_response( self, messages: List[Dict[str, str]], - ) -> str: + response_format=None, + tools: Optional[List[Dict]] = None, + tool_choice: str = "auto", + ): """ - Generates a response using Anthropic's Claude model based on the provided messages. + Generate a response based on the given messages using Anthropic. Args: - messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key. + messages (list): List of message dicts containing 'role' and 'content'. + response_format (str or object, optional): Format of the response. Defaults to "text". + tools (list, optional): List of tools that the model can call. Defaults to None. + tool_choice (str, optional): Tool choice method. Defaults to "auto". Returns: - str: The generated response from the model. + str: The generated response. """ - # Extract system message separately + # Separate system message from other messages system_message = "" filtered_messages = [] for message in messages: @@ -62,6 +56,9 @@ class AnthropicLLM(LLMBase): "max_tokens": self.config.max_tokens, "top_p": self.config.top_p, } + if tools: # TODO: Remove tools if no issues found with new memory addition logic + params["tools"] = tools + params["tool_choice"] = tool_choice response = self.client.messages.create(**params) return response.content[0].text diff --git a/mem0/llms/aws_bedrock.py b/mem0/llms/aws_bedrock.py index 021485e2..8d8bb01d 100644 --- a/mem0/llms/aws_bedrock.py +++ b/mem0/llms/aws_bedrock.py @@ -4,26 +4,14 @@ from typing import Any, Dict, List, Optional try: import boto3 except ImportError: - raise ImportError( - "The 'boto3' library is required. Please install it using 'pip install boto3'." - ) + raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.") from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.base import LLMBase class AWSBedrockLLM(LLMBase): - """ - A wrapper for AWS Bedrock's language models, integrating them with the LLMBase class. - """ - def __init__(self, config: Optional[BaseLlmConfig] = None): - """ - Initializes the AWS Bedrock LLM with the provided configuration. - - Args: - config (Optional[BaseLlmConfig]): Configuration object for the model. - """ super().__init__(config) if not self.config.model: @@ -37,29 +25,49 @@ class AWSBedrockLLM(LLMBase): def _format_messages(self, messages: List[Dict[str, str]]) -> str: """ - Formats a list of messages into a structured prompt for the model. + Formats a list of messages into the required prompt structure for the model. Args: - messages (List[Dict[str, str]]): A list of dictionaries containing 'role' and 'content'. + messages (List[Dict[str, str]]): A list of dictionaries where each dictionary represents a message. + Each dictionary contains 'role' and 'content' keys. Returns: str: A formatted string combining all messages, structured with roles capitalized and separated by newlines. """ - formatted_messages = [ - f"\n\n{msg['role'].capitalize()}: {msg['content']}" for msg in messages - ] + formatted_messages = [] + for message in messages: + role = message["role"].capitalize() + content = message["content"] + formatted_messages.append(f"\n\n{role}: {content}") + return "".join(formatted_messages) + "\n\nAssistant:" - def _parse_response(self, response) -> str: + def _parse_response(self, response, tools) -> str: """ - Extracts the generated response from the API response. + Process the response based on whether tools are used or not. Args: - response: The raw response from the AWS Bedrock API. + response: The raw response from API. + tools: The list of tools provided in the request. Returns: - str: The generated response text. + str or dict: The processed response. """ + if tools: + processed_response = {"tool_calls": []} + + if response["output"]["message"]["content"]: + for item in response["output"]["message"]["content"]: + if "toolUse" in item: + processed_response["tool_calls"].append( + { + "name": item["toolUse"]["name"], + "arguments": item["toolUse"]["input"], + } + ) + + return processed_response + response_body = json.loads(response["body"].read().decode()) return response_body.get("completion", "") @@ -68,21 +76,22 @@ class AWSBedrockLLM(LLMBase): provider: str, model: str, prompt: str, - model_kwargs: Optional[Dict[str, Any]] = None, + model_kwargs: Optional[Dict[str, Any]] = {}, ) -> Dict[str, Any]: """ - Prepares the input dictionary for the specified provider's model. + Prepares the input dictionary for the specified provider's model by mapping and renaming + keys in the input based on the provider's requirements. Args: - provider (str): The model provider (e.g., "meta", "ai21", "mistral", "cohere", "amazon"). - model (str): The model identifier. - prompt (str): The input prompt. - model_kwargs (Optional[Dict[str, Any]]): Additional model parameters. + provider (str): The name of the service provider (e.g., "meta", "ai21", "mistral", "cohere", "amazon"). + model (str): The name or identifier of the model being used. + prompt (str): The text prompt to be processed by the model. + model_kwargs (Dict[str, Any]): Additional keyword arguments specific to the model's requirements. Returns: - Dict[str, Any]: The prepared input dictionary. + Dict[str, Any]: The prepared input dictionary with the correct keys and values for the specified provider. """ - model_kwargs = model_kwargs or {} + input_body = {"prompt": prompt, **model_kwargs} provider_mappings = { @@ -110,35 +119,102 @@ class AWSBedrockLLM(LLMBase): }, } input_body["textGenerationConfig"] = { - k: v - for k, v in input_body["textGenerationConfig"].items() - if v is not None + k: v for k, v in input_body["textGenerationConfig"].items() if v is not None } return input_body - def generate_response(self, messages: List[Dict[str, str]]) -> str: + def _convert_tool_format(self, original_tools): """ - Generates a response using AWS Bedrock based on the provided messages. + Converts a list of tools from their original format to a new standardized format. Args: - messages (List[Dict[str, str]]): List of message dictionaries containing 'role' and 'content'. + original_tools (list): A list of dictionaries representing the original tools, each containing a 'type' key and corresponding details. Returns: - str: The generated response text. + list: A list of dictionaries representing the tools in the new standardized format. """ - prompt = self._format_messages(messages) - provider = self.config.model.split(".")[0] - input_body = self._prepare_input( - provider, self.config.model, prompt, self.model_kwargs - ) - body = json.dumps(input_body) + new_tools = [] - response = self.client.invoke_model( - body=body, - modelId=self.config.model, - accept="application/json", - contentType="application/json", - ) + for tool in original_tools: + if tool["type"] == "function": + function = tool["function"] + new_tool = { + "toolSpec": { + "name": function["name"], + "description": function["description"], + "inputSchema": { + "json": { + "type": "object", + "properties": {}, + "required": function["parameters"].get("required", []), + } + }, + } + } - return self._parse_response(response) + for prop, details in function["parameters"].get("properties", {}).items(): + new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = { + "type": details.get("type", "string"), + "description": details.get("description", ""), + } + + new_tools.append(new_tool) + + return new_tools + + def generate_response( + self, + messages: List[Dict[str, str]], + response_format=None, + tools: Optional[List[Dict]] = None, + tool_choice: str = "auto", + ): + """ + Generate a response based on the given messages using AWS Bedrock. + + Args: + messages (list): List of message dicts containing 'role' and 'content'. + tools (list, optional): List of tools that the model can call. Defaults to None. + tool_choice (str, optional): Tool choice method. Defaults to "auto". + + Returns: + str: The generated response. + """ + + if tools: + # Use converse method when tools are provided + messages = [ + { + "role": "user", + "content": [{"text": message["content"]} for message in messages], + } + ] + inference_config = { + "temperature": self.model_kwargs["temperature"], + "maxTokens": self.model_kwargs["max_tokens_to_sample"], + "topP": self.model_kwargs["top_p"], + } + tools_config = {"tools": self._convert_tool_format(tools)} + + response = self.client.converse( + modelId=self.config.model, + messages=messages, + inferenceConfig=inference_config, + toolConfig=tools_config, + ) + else: + # Use invoke_model method when no tools are provided + prompt = self._format_messages(messages) + provider = self.model.split(".")[0] + input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs) + body = json.dumps(input_body) + + response = self.client.invoke_model( + body=body, + modelId=self.model, + accept="application/json", + contentType="application/json", + ) + + return self._parse_response(response, tools) diff --git a/mem0/llms/azure_openai.py b/mem0/llms/azure_openai.py index 112ec8d0..3400b382 100644 --- a/mem0/llms/azure_openai.py +++ b/mem0/llms/azure_openai.py @@ -9,35 +9,17 @@ from mem0.llms.base import LLMBase class AzureOpenAILLM(LLMBase): - """ - A class for interacting with Azure OpenAI models using the specified configuration. - """ - def __init__(self, config: Optional[BaseLlmConfig] = None): - """ - Initializes the AzureOpenAILLM instance with the given configuration. - - Args: - config (Optional[BaseLlmConfig]): Configuration settings for the language model. - """ super().__init__(config) - # Ensure model name is set; it should match the Azure OpenAI deployment name. + # Model name should match the custom deployment name chosen for it. if not self.config.model: self.config.model = "gpt-4o" - api_key = self.config.azure_kwargs.api_key or os.getenv( - "LLM_AZURE_OPENAI_API_KEY" - ) - azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv( - "LLM_AZURE_DEPLOYMENT" - ) - azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv( - "LLM_AZURE_ENDPOINT" - ) - api_version = self.config.azure_kwargs.api_version or os.getenv( - "LLM_AZURE_API_VERSION" - ) + api_key = self.config.azure_kwargs.api_key or os.getenv("LLM_AZURE_OPENAI_API_KEY") + azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv("LLM_AZURE_DEPLOYMENT") + azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("LLM_AZURE_ENDPOINT") + api_version = self.config.azure_kwargs.api_version or os.getenv("LLM_AZURE_API_VERSION") default_headers = self.config.azure_kwargs.default_headers self.client = AzureOpenAI( @@ -49,20 +31,54 @@ class AzureOpenAILLM(LLMBase): default_headers=default_headers, ) + def _parse_response(self, response, tools): + """ + Process the response based on whether tools are used or not. + + Args: + response: The raw response from API. + tools: The list of tools provided in the request. + + Returns: + str or dict: The processed response. + """ + if tools: + processed_response = { + "content": response.choices[0].message.content, + "tool_calls": [], + } + + if response.choices[0].message.tool_calls: + for tool_call in response.choices[0].message.tool_calls: + processed_response["tool_calls"].append( + { + "name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments), + } + ) + + return processed_response + else: + return response.choices[0].message.content + def generate_response( self, messages: List[Dict[str, str]], - response_format: Optional[str] = None, - ) -> str: + response_format=None, + tools: Optional[List[Dict]] = None, + tool_choice: str = "auto", + ): """ - Generates a response using Azure OpenAI based on the provided messages. + Generate a response based on the given messages using Azure OpenAI. Args: - messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key. - response_format (Optional[str]): The desired format of the response. Defaults to None. + messages (list): List of message dicts containing 'role' and 'content'. + response_format (str or object, optional): Format of the response. Defaults to "text". + tools (list, optional): List of tools that the model can call. Defaults to None. + tool_choice (str, optional): Tool choice method. Defaults to "auto". Returns: - str: The generated response from the model. + str: The generated response. """ params = { "model": self.config.model, @@ -71,8 +87,11 @@ class AzureOpenAILLM(LLMBase): "max_tokens": self.config.max_tokens, "top_p": self.config.top_p, } - if response_format: params["response_format"] = response_format + if tools: # TODO: Remove tools if no issues found with new memory addition logic + params["tools"] = tools + params["tool_choice"] = tool_choice + response = self.client.chat.completions.create(**params) - return response.choices[0].message.content + return self._parse_response(response, tools) diff --git a/mem0/llms/azure_openai_structured.py b/mem0/llms/azure_openai_structured.py index 9e07bf0b..29d82e65 100644 --- a/mem0/llms/azure_openai_structured.py +++ b/mem0/llms/azure_openai_structured.py @@ -9,38 +9,20 @@ from mem0.llms.base import LLMBase class AzureOpenAIStructuredLLM(LLMBase): - """ - A class for interacting with Azure OpenAI models using the specified configuration. - """ - def __init__(self, config: Optional[BaseLlmConfig] = None): - """ - Initializes the AzureOpenAIStructuredLLM instance with the given configuration. - - Args: - config (Optional[BaseLlmConfig]): Configuration settings for the language model. - """ super().__init__(config) - # Ensure model name is set; it should match the Azure OpenAI deployment name. + # Model name should match the custom deployment name chosen for it. if not self.config.model: self.config.model = "gpt-4o-2024-08-06" - api_key = ( - os.getenv("LLM_AZURE_OPENAI_API_KEY") or self.config.azure_kwargs.api_key - ) - azure_deployment = ( - os.getenv("LLM_AZURE_DEPLOYMENT") - or self.config.azure_kwargs.azure_deployment - ) - azure_endpoint = ( - os.getenv("LLM_AZURE_ENDPOINT") or self.config.azure_kwargs.azure_endpoint - ) - api_version = ( - os.getenv("LLM_AZURE_API_VERSION") or self.config.azure_kwargs.api_version - ) + api_key = os.getenv("LLM_AZURE_OPENAI_API_KEY") or self.config.azure_kwargs.api_key + azure_deployment = os.getenv("LLM_AZURE_DEPLOYMENT") or self.config.azure_kwargs.azure_deployment + azure_endpoint = os.getenv("LLM_AZURE_ENDPOINT") or self.config.azure_kwargs.azure_endpoint + api_version = os.getenv("LLM_AZURE_API_VERSION") or self.config.azure_kwargs.api_version default_headers = self.config.azure_kwargs.default_headers + # Can display a warning if API version is of model and api-version self.client = AzureOpenAI( azure_deployment=azure_deployment, azure_endpoint=azure_endpoint, @@ -50,52 +32,20 @@ class AzureOpenAIStructuredLLM(LLMBase): default_headers=default_headers, ) - def _parse_response(self, response, tools): - """ - Process the response based on whether tools are used or not. - Args: - response: The raw response from API. - tools: The list of tools provided in the request. - Returns: - str or dict: The processed response. - """ - if tools: - processed_response = { - "content": response.choices[0].message.content, - "tool_calls": [], - } - - if response.choices[0].message.tool_calls: - for tool_call in response.choices[0].message.tool_calls: - processed_response["tool_calls"].append( - { - "name": tool_call.function.name, - "arguments": json.loads(tool_call.function.arguments), - } - ) - - return processed_response - else: - return response.choices[0].message.content - def generate_response( self, messages: List[Dict[str, str]], response_format: Optional[str] = None, - tools: Optional[List[Dict]] = None, - tool_choice: str = "auto", ) -> str: """ - Generates a response using Azure OpenAI based on the provided messages. + Generate a response based on the given messages using Azure OpenAI. Args: messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key. response_format (Optional[str]): The desired format of the response. Defaults to None. - tools (Optional[List[Dict]]): A list of dictionaries, each containing a 'name' and 'arguments' key. - tool_choice (str): The choice of tool to use. Defaults to "auto". Returns: - str: The generated response from the model. + str: The generated response. """ params = { "model": self.config.model, @@ -104,9 +54,11 @@ class AzureOpenAIStructuredLLM(LLMBase): "max_tokens": self.config.max_tokens, "top_p": self.config.top_p, } - if response_format: params["response_format"] = response_format + if tools: + params["tools"] = tools + params["tool_choice"] = tool_choice if tools: params["tools"] = tools diff --git a/mem0/llms/configs.py b/mem0/llms/configs.py index ef89551b..40287ce1 100644 --- a/mem0/llms/configs.py +++ b/mem0/llms/configs.py @@ -4,12 +4,8 @@ from pydantic import BaseModel, Field, field_validator class LlmConfig(BaseModel): - provider: str = Field( - description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai" - ) - config: Optional[dict] = Field( - description="Configuration for the specific LLM", default={} - ) + provider: str = Field(description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai") + config: Optional[dict] = Field(description="Configuration for the specific LLM", default={}) @field_validator("config") def validate_config(cls, v, values): diff --git a/mem0/llms/deepseek.py b/mem0/llms/deepseek.py index 7f80f755..46a805f0 100644 --- a/mem0/llms/deepseek.py +++ b/mem0/llms/deepseek.py @@ -9,42 +9,64 @@ from mem0.llms.base import LLMBase class DeepSeekLLM(LLMBase): - """ - A class for interacting with DeepSeek's language models using the specified configuration. - """ - def __init__(self, config: Optional[BaseLlmConfig] = None): - """ - Initializes the DeepSeekLLM instance with the given configuration. - - Args: - config (Optional[BaseLlmConfig]): Configuration settings for the language model. - """ super().__init__(config) if not self.config.model: self.config.model = "deepseek-chat" api_key = self.config.api_key or os.getenv("DEEPSEEK_API_KEY") - base_url = ( - self.config.deepseek_base_url - or os.getenv("DEEPSEEK_API_BASE") - or "https://api.deepseek.com" - ) + base_url = self.config.deepseek_base_url or os.getenv("DEEPSEEK_API_BASE") or "https://api.deepseek.com" self.client = OpenAI(api_key=api_key, base_url=base_url) + def _parse_response(self, response, tools): + """ + Process the response based on whether tools are used or not. + + Args: + response: The raw response from API. + tools: The list of tools provided in the request. + + Returns: + str or dict: The processed response. + """ + if tools: + processed_response = { + "content": response.choices[0].message.content, + "tool_calls": [], + } + + if response.choices[0].message.tool_calls: + for tool_call in response.choices[0].message.tool_calls: + processed_response["tool_calls"].append( + { + "name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments), + } + ) + + return processed_response + else: + return response.choices[0].message.content + def generate_response( self, messages: List[Dict[str, str]], - ) -> str: + response_format=None, + tools: Optional[List[Dict]] = None, + tool_choice: str = "auto", + ): """ - Generates a response using DeepSeek based on the provided messages. + Generate a response based on the given messages using DeepSeek. Args: - messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key. + messages (list): List of message dicts containing 'role' and 'content'. + response_format (str or object, optional): Format of the response. Defaults to "text". + tools (list, optional): List of tools that the model can call. Defaults to None. + tool_choice (str, optional): Tool choice method. Defaults to "auto". Returns: - str: The generated response from the model. + str: The generated response. """ params = { "model": self.config.model, @@ -53,5 +75,10 @@ class DeepSeekLLM(LLMBase): "max_tokens": self.config.max_tokens, "top_p": self.config.top_p, } + + if tools: + params["tools"] = tools + params["tool_choice"] = tool_choice + response = self.client.chat.completions.create(**params) - return response.choices[0].message.content + return self._parse_response(response, tools) diff --git a/mem0/llms/gemini.py b/mem0/llms/gemini.py index 3285652b..7881cf05 100644 --- a/mem0/llms/gemini.py +++ b/mem0/llms/gemini.py @@ -3,7 +3,8 @@ from typing import Dict, List, Optional try: import google.generativeai as genai - from google.generativeai import GenerativeModel + from google.generativeai import GenerativeModel, protos + from google.generativeai.types import content_types except ImportError: raise ImportError( "The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'." @@ -14,17 +15,7 @@ from mem0.llms.base import LLMBase class GeminiLLM(LLMBase): - """ - A wrapper for Google's Gemini language model, integrating it with the LLMBase class. - """ - def __init__(self, config: Optional[BaseLlmConfig] = None): - """ - Initializes the Gemini LLM with the provided configuration. - - Args: - config (Optional[BaseLlmConfig]): Configuration object for the model. - """ super().__init__(config) if not self.config.model: @@ -34,25 +25,51 @@ class GeminiLLM(LLMBase): genai.configure(api_key=api_key) self.client = GenerativeModel(model_name=self.config.model) - def _reformat_messages( - self, messages: List[Dict[str, str]] - ) -> List[Dict[str, str]]: + def _parse_response(self, response, tools): """ - Reformats messages to match the Gemini API's expected structure. + Process the response based on whether tools are used or not. Args: - messages (List[Dict[str, str]]): A list of messages with 'role' and 'content' keys. + response: The raw response from API. + tools: The list of tools provided in the request. Returns: - List[Dict[str, str]]: Reformatted messages in the required format. + str or dict: The processed response. + """ + if tools: + processed_response = { + "content": (content if (content := response.candidates[0].content.parts[0].text) else None), + "tool_calls": [], + } + + for part in response.candidates[0].content.parts: + if fn := part.function_call: + if isinstance(fn, protos.FunctionCall): + fn_call = type(fn).to_dict(fn) + processed_response["tool_calls"].append({"name": fn_call["name"], "arguments": fn_call["args"]}) + continue + processed_response["tool_calls"].append({"name": fn.name, "arguments": fn.args}) + + return processed_response + else: + return response.candidates[0].content.parts[0].text + + def _reformat_messages(self, messages: List[Dict[str, str]]): + """ + Reformat messages for Gemini. + + Args: + messages: The list of messages provided in the request. + + Returns: + list: The list of messages in the required format. """ new_messages = [] for message in messages: if message["role"] == "system": - content = ( - "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"] - ) + content = "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"] + else: content = message["content"] @@ -65,33 +82,90 @@ class GeminiLLM(LLMBase): return new_messages - def generate_response( - self, messages: List[Dict[str, str]], response_format: Optional[Dict] = None - ) -> str: + def _reformat_tools(self, tools: Optional[List[Dict]]): """ - Generates a response from Gemini based on the given conversation history. + Reformat tools for Gemini. Args: - messages (List[Dict[str, str]]): List of message dictionaries containing 'role' and 'content'. - response_format (Optional[Dict]): Specifies the response format (e.g., JSON schema). + tools: The list of tools provided in the request. Returns: - str: The generated response as text. + list: The list of tools in the required format. """ + + def remove_additional_properties(data): + """Recursively removes 'additionalProperties' from nested dictionaries.""" + + if isinstance(data, dict): + filtered_dict = { + key: remove_additional_properties(value) + for key, value in data.items() + if not (key == "additionalProperties") + } + return filtered_dict + else: + return data + + new_tools = [] + if tools: + for tool in tools: + func = tool["function"].copy() + new_tools.append({"function_declarations": [remove_additional_properties(func)]}) + + # TODO: temporarily ignore it to pass tests, will come back to update according to standards later. + # return content_types.to_function_library(new_tools) + + return new_tools + else: + return None + + def generate_response( + self, + messages: List[Dict[str, str]], + response_format=None, + tools: Optional[List[Dict]] = None, + tool_choice: str = "auto", + ): + """ + Generate a response based on the given messages using Gemini. + + Args: + messages (list): List of message dicts containing 'role' and 'content'. + response_format (str or object, optional): Format for the response. Defaults to "text". + tools (list, optional): List of tools that the model can call. Defaults to None. + tool_choice (str, optional): Tool choice method. Defaults to "auto". + + Returns: + str: The generated response. + """ + params = { "temperature": self.config.temperature, "max_output_tokens": self.config.max_tokens, "top_p": self.config.top_p, } - if response_format and response_format.get("type") == "json_object": + if response_format is not None and response_format["type"] == "json_object": params["response_mime_type"] = "application/json" if "schema" in response_format: params["response_schema"] = response_format["schema"] + if tool_choice: + tool_config = content_types.to_tool_config( + { + "function_calling_config": { + "mode": tool_choice, + "allowed_function_names": ( + [tool["function"]["name"] for tool in tools] if tool_choice == "any" else None + ), + } + } + ) response = self.client.generate_content( contents=self._reformat_messages(messages), + tools=self._reformat_tools(tools), generation_config=genai.GenerationConfig(**params), + tool_config=tool_config, ) - return response.candidates[0].content.parts[0].text + return self._parse_response(response, tools) diff --git a/mem0/llms/groq.py b/mem0/llms/groq.py index 31992488..38a1c8a0 100644 --- a/mem0/llms/groq.py +++ b/mem0/llms/groq.py @@ -5,26 +5,14 @@ from typing import Dict, List, Optional try: from groq import Groq except ImportError: - raise ImportError( - "The 'groq' library is required. Please install it using 'pip install groq'." - ) + raise ImportError("The 'groq' library is required. Please install it using 'pip install groq'.") from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.base import LLMBase class GroqLLM(LLMBase): - """ - A class for interacting with Groq's language models using the specified configuration. - """ - def __init__(self, config: Optional[BaseLlmConfig] = None): - """ - Initializes the GroqLLM instance with the given configuration. - - Args: - config (Optional[BaseLlmConfig]): Configuration settings for the language model. - """ super().__init__(config) if not self.config.model: @@ -33,20 +21,54 @@ class GroqLLM(LLMBase): api_key = self.config.api_key or os.getenv("GROQ_API_KEY") self.client = Groq(api_key=api_key) + def _parse_response(self, response, tools): + """ + Process the response based on whether tools are used or not. + + Args: + response: The raw response from API. + tools: The list of tools provided in the request. + + Returns: + str or dict: The processed response. + """ + if tools: + processed_response = { + "content": response.choices[0].message.content, + "tool_calls": [], + } + + if response.choices[0].message.tool_calls: + for tool_call in response.choices[0].message.tool_calls: + processed_response["tool_calls"].append( + { + "name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments), + } + ) + + return processed_response + else: + return response.choices[0].message.content + def generate_response( self, messages: List[Dict[str, str]], - response_format: Optional[str] = None, - ) -> str: + response_format=None, + tools: Optional[List[Dict]] = None, + tool_choice: str = "auto", + ): """ - Generates a response using Groq based on the provided messages. + Generate a response based on the given messages using Groq. Args: - messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key. - response_format (Optional[str]): The desired format of the response. Defaults to None. + messages (list): List of message dicts containing 'role' and 'content'. + response_format (str or object, optional): Format of the response. Defaults to "text". + tools (list, optional): List of tools that the model can call. Defaults to None. + tool_choice (str, optional): Tool choice method. Defaults to "auto". Returns: - str: The generated response from the model. + str: The generated response. """ params = { "model": self.config.model, @@ -57,5 +79,9 @@ class GroqLLM(LLMBase): } if response_format: params["response_format"] = response_format + if tools: + params["tools"] = tools + params["tool_choice"] = tool_choice + response = self.client.chat.completions.create(**params) - return response.choices[0].message.content + return self._parse_response(response, tools) \ No newline at end of file diff --git a/mem0/llms/litellm.py b/mem0/llms/litellm.py index 3d9368db..d5896ff8 100644 --- a/mem0/llms/litellm.py +++ b/mem0/llms/litellm.py @@ -4,50 +4,70 @@ from typing import Dict, List, Optional try: import litellm except ImportError: - raise ImportError( - "The 'litellm' library is required. Please install it using 'pip install litellm'." - ) + raise ImportError("The 'litellm' library is required. Please install it using 'pip install litellm'.") from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.base import LLMBase class LiteLLM(LLMBase): - """ - A class for interacting with LiteLLM's language models using the specified configuration. - """ - def __init__(self, config: Optional[BaseLlmConfig] = None): - """ - Initializes the LiteLLM instance with the given configuration. - - Args: - config (Optional[BaseLlmConfig]): Configuration settings for the language model. - """ super().__init__(config) if not self.config.model: self.config.model = "gpt-4o-mini" + def _parse_response(self, response, tools): + """ + Process the response based on whether tools are used or not. + + Args: + response: The raw response from API. + tools: The list of tools provided in the request. + + Returns: + str or dict: The processed response. + """ + if tools: + processed_response = { + "content": response.choices[0].message.content, + "tool_calls": [], + } + + if response.choices[0].message.tool_calls: + for tool_call in response.choices[0].message.tool_calls: + processed_response["tool_calls"].append( + { + "name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments), + } + ) + + return processed_response + else: + return response.choices[0].message.content + def generate_response( self, messages: List[Dict[str, str]], - response_format: Optional[str] = None, - ) -> str: + response_format=None, + tools: Optional[List[Dict]] = None, + tool_choice: str = "auto", + ): """ - Generates a response using LiteLLM based on the provided messages. + Generate a response based on the given messages using Litellm. Args: - messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key. - response_format (Optional[str]): The desired format of the response. Defaults to None. + messages (list): List of message dicts containing 'role' and 'content'. + response_format (str or object, optional): Format of the response. Defaults to "text". + tools (list, optional): List of tools that the model can call. Defaults to None. + tool_choice (str, optional): Tool choice method. Defaults to "auto". Returns: - str: The generated response from the model. + str: The generated response. """ if not litellm.supports_function_calling(self.config.model): - raise ValueError( - f"Model '{self.config.model}' in LiteLLM does not support function calling." - ) + raise ValueError(f"Model '{self.config.model}' in litellm does not support function calling.") params = { "model": self.config.model, @@ -58,6 +78,9 @@ class LiteLLM(LLMBase): } if response_format: params["response_format"] = response_format + if tools: # TODO: Remove tools if no issues found with new memory addition logic + params["tools"] = tools + params["tool_choice"] = tool_choice response = litellm.completion(**params) - return response.choices[0].message.content + return self._parse_response(response, tools) diff --git a/mem0/llms/ollama.py b/mem0/llms/ollama.py index 5fe31001..54d8b719 100644 --- a/mem0/llms/ollama.py +++ b/mem0/llms/ollama.py @@ -3,56 +3,77 @@ from typing import Dict, List, Optional try: from ollama import Client except ImportError: - raise ImportError( - "The 'ollama' library is required. Please install it using 'pip install ollama'." - ) + raise ImportError("The 'ollama' library is required. Please install it using 'pip install ollama'.") from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.base import LLMBase class OllamaLLM(LLMBase): - """ - A class for interacting with Ollama's language models using the specified configuration. - """ - def __init__(self, config: Optional[BaseLlmConfig] = None): - """ - Initializes the OllamaLLM instance with the given configuration. - - Args: - config (Optional[BaseLlmConfig]): Configuration settings for the language model. - """ super().__init__(config) if not self.config.model: self.config.model = "llama3.1:70b" - self.client = Client(host=self.config.ollama_base_url) self._ensure_model_exists() def _ensure_model_exists(self): """ - Ensures the specified model exists locally. If not, pulls it from Ollama. + Ensure the specified model exists locally. If not, pull it from Ollama. """ local_models = self.client.list()["models"] if not any(model.get("name") == self.config.model for model in local_models): self.client.pull(self.config.model) + def _parse_response(self, response, tools): + """ + Process the response based on whether tools are used or not. + + Args: + response: The raw response from API. + tools: The list of tools provided in the request. + + Returns: + str or dict: The processed response. + """ + if tools: + processed_response = { + "content": response["message"]["content"], + "tool_calls": [], + } + + if response["message"].get("tool_calls"): + for tool_call in response["message"]["tool_calls"]: + processed_response["tool_calls"].append( + { + "name": tool_call["function"]["name"], + "arguments": tool_call["function"]["arguments"], + } + ) + + return processed_response + else: + return response["message"]["content"] + def generate_response( self, messages: List[Dict[str, str]], - response_format: Optional[str] = None, - ) -> str: + response_format=None, + tools: Optional[List[Dict]] = None, + tool_choice: str = "auto", + ): """ - Generates a response using Ollama based on the provided messages. + Generate a response based on the given messages using OpenAI. Args: - messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key. - response_format (Optional[str]): The desired format of the response. Defaults to None. + messages (list): List of message dicts containing 'role' and 'content'. + response_format (str or object, optional): Format of the response. Defaults to "text". + tools (list, optional): List of tools that the model can call. Defaults to None. + tool_choice (str, optional): Tool choice method. Defaults to "auto". Returns: - str: The generated response from the model. + str: The generated response. """ params = { "model": self.config.model, @@ -66,5 +87,8 @@ class OllamaLLM(LLMBase): if response_format: params["format"] = "json" + if tools: + params["tools"] = tools + response = self.client.chat(**params) - return response["message"]["content"] + return self._parse_response(response, tools) diff --git a/mem0/llms/openai.py b/mem0/llms/openai.py index 08a8d654..a9c302f8 100644 --- a/mem0/llms/openai.py +++ b/mem0/llms/openai.py @@ -9,17 +9,7 @@ from mem0.llms.base import LLMBase class OpenAILLM(LLMBase): - """ - A class to interact with OpenAI or OpenRouter APIs for generating responses using LLMs. - """ - def __init__(self, config: Optional[BaseLlmConfig] = None): - """ - Initializes the OpenAILLM instance. - - Args: - config (Optional[BaseLlmConfig]): Configuration for the LLM, including model, API key, and base URLs. - """ super().__init__(config) if not self.config.model: @@ -34,27 +24,57 @@ class OpenAILLM(LLMBase): ) else: api_key = self.config.api_key or os.getenv("OPENAI_API_KEY") - base_url = ( - self.config.openai_base_url - or os.getenv("OPENAI_API_BASE") - or "https://api.openai.com/v1" - ) + base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE") or "https://api.openai.com/v1" self.client = OpenAI(api_key=api_key, base_url=base_url) + def _parse_response(self, response, tools): + """ + Process the response based on whether tools are used or not. + + Args: + response: The raw response from API. + tools: The list of tools provided in the request. + + Returns: + str or dict: The processed response. + """ + if tools: + processed_response = { + "content": response.choices[0].message.content, + "tool_calls": [], + } + + if response.choices[0].message.tool_calls: + for tool_call in response.choices[0].message.tool_calls: + processed_response["tool_calls"].append( + { + "name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments), + } + ) + + return processed_response + else: + return response.choices[0].message.content + def generate_response( self, messages: List[Dict[str, str]], - response_format: Optional[str] = None, - ) -> str: + response_format=None, + tools: Optional[List[Dict]] = None, + tool_choice: str = "auto", + ): """ - Generates a response based on the provided messages using OpenAI or OpenRouter. + Generate a response based on the given messages using OpenAI. Args: - messages (List[Dict[str, str]]): A list of message dictionaries containing 'role' and 'content'. - response_format (Optional[str]): The format of the response. Defaults to None. + messages (list): List of message dicts containing 'role' and 'content'. + response_format (str or object, optional): Format of the response. Defaults to "text". + tools (list, optional): List of tools that the model can call. Defaults to None. + tool_choice (str, optional): Tool choice method. Defaults to "auto". Returns: - str: The generated response from the model. + str: The generated response. """ params = { "model": self.config.model, @@ -82,6 +102,9 @@ class OpenAILLM(LLMBase): if response_format: params["response_format"] = response_format + if tools: # TODO: Remove tools if no issues found with new memory addition logic + params["tools"] = tools + params["tool_choice"] = tool_choice response = self.client.chat.completions.create(**params) - return response.choices[0].message.content + return self._parse_response(response, tools) diff --git a/mem0/llms/openai_structured.py b/mem0/llms/openai_structured.py index 08c2edfb..dd381c59 100644 --- a/mem0/llms/openai_structured.py +++ b/mem0/llms/openai_structured.py @@ -9,78 +9,31 @@ from mem0.llms.base import LLMBase class OpenAIStructuredLLM(LLMBase): - """ - A class for interacting with OpenAI's structured language models using the specified configuration. - """ - def __init__(self, config: Optional[BaseLlmConfig] = None): - """ - Initializes the OpenAIStructuredLLM instance with the given configuration. - - Args: - config (Optional[BaseLlmConfig]): Configuration settings for the language model. - """ super().__init__(config) if not self.config.model: self.config.model = "gpt-4o-2024-08-06" api_key = self.config.api_key or os.getenv("OPENAI_API_KEY") - base_url = ( - self.config.openai_base_url - or os.getenv("OPENAI_API_BASE") - or "https://api.openai.com/v1" - ) + base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE") or "https://api.openai.com/v1" self.client = OpenAI(api_key=api_key, base_url=base_url) - def _parse_response(self, response, tools): - """ - Process the response based on whether tools are used or not. - Args: - response: The raw response from API. - tools (list, optional): List of tools that the model can call. - Returns: - str or dict: The processed response. - """ - - if tools: - processed_response = { - "content": response.choices[0].message.content, - "tool_calls": [], - } - - if response.choices[0].message.tool_calls: - for tool_call in response.choices[0].message.tool_calls: - processed_response["tool_calls"].append( - { - "name": tool_call.function.name, - "arguments": json.loads(tool_call.function.arguments), - } - ) - - return processed_response - - else: - return response.choices[0].message.content - def generate_response( self, messages: List[Dict[str, str]], response_format: Optional[str] = None, - tools: Optional[List[Dict]] = None, - tool_choice: str = "auto", ) -> str: """ - Generates a response using OpenAI based on the provided messages. + Generate a response based on the given messages using OpenAI. Args: messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key. response_format (Optional[str]): The desired format of the response. Defaults to None. - tools (Optional[List[Dict]]): A list of dictionaries, each containing a 'name' and 'arguments' key. - tool_choice (str): The choice of tool to use. Defaults to "auto". + Returns: - str: The generated response from the model. + str: The generated response. """ params = { "model": self.config.model, @@ -95,4 +48,4 @@ class OpenAIStructuredLLM(LLMBase): params["tool_choice"] = tool_choice response = self.client.beta.chat.completions.parse(**params) - return self._parse_response(response, tools) + return response.choices[0].message.content diff --git a/mem0/llms/together.py b/mem0/llms/together.py index 2e794d44..922a30d2 100644 --- a/mem0/llms/together.py +++ b/mem0/llms/together.py @@ -5,26 +5,14 @@ from typing import Dict, List, Optional try: from together import Together except ImportError: - raise ImportError( - "The 'together' library is required. Please install it using 'pip install together'." - ) + raise ImportError("The 'together' library is required. Please install it using 'pip install together'.") from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.base import LLMBase class TogetherLLM(LLMBase): - """ - A class for interacting with the TogetherAI language model using the specified configuration. - """ - def __init__(self, config: Optional[BaseLlmConfig] = None): - """ - Initializes the TogetherLLM instance with the given configuration. - - Args: - config (Optional[BaseLlmConfig]): Configuration settings for the language model. - """ super().__init__(config) if not self.config.model: @@ -33,20 +21,54 @@ class TogetherLLM(LLMBase): api_key = self.config.api_key or os.getenv("TOGETHER_API_KEY") self.client = Together(api_key=api_key) + def _parse_response(self, response, tools): + """ + Process the response based on whether tools are used or not. + + Args: + response: The raw response from API. + tools: The list of tools provided in the request. + + Returns: + str or dict: The processed response. + """ + if tools: + processed_response = { + "content": response.choices[0].message.content, + "tool_calls": [], + } + + if response.choices[0].message.tool_calls: + for tool_call in response.choices[0].message.tool_calls: + processed_response["tool_calls"].append( + { + "name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments), + } + ) + + return processed_response + else: + return response.choices[0].message.content + def generate_response( self, messages: List[Dict[str, str]], - response_format: Optional[str] = None, - ) -> str: + response_format=None, + tools: Optional[List[Dict]] = None, + tool_choice: str = "auto", + ): """ - Generates a response using TogetherAI based on the provided messages. + Generate a response based on the given messages using TogetherAI. Args: - messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key. - response_format (Optional[str]): The desired format of the response. Defaults to None. + messages (list): List of message dicts containing 'role' and 'content'. + response_format (str or object, optional): Format of the response. Defaults to "text". + tools (list, optional): List of tools that the model can call. Defaults to None. + tool_choice (str, optional): Tool choice method. Defaults to "auto". Returns: - str: The generated response from the model. + str: The generated response. """ params = { "model": self.config.model, @@ -57,6 +79,9 @@ class TogetherLLM(LLMBase): } if response_format: params["response_format"] = response_format + if tools: # TODO: Remove tools if no issues found with new memory addition logic + params["tools"] = tools + params["tool_choice"] = tool_choice response = self.client.chat.completions.create(**params) - return response.choices[0].message.content + return self._parse_response(response, tools) diff --git a/mem0/llms/xai.py b/mem0/llms/xai.py index 1c85ba0d..5309eb57 100644 --- a/mem0/llms/xai.py +++ b/mem0/llms/xai.py @@ -15,11 +15,7 @@ class XAILLM(LLMBase): self.config.model = "grok-2-latest" api_key = self.config.api_key or os.getenv("XAI_API_KEY") - base_url = ( - self.config.xai_base_url - or os.getenv("XAI_API_BASE") - or "https://api.x.ai/v1" - ) + base_url = self.config.xai_base_url or os.getenv("XAI_API_BASE") or "https://api.x.ai/v1" self.client = OpenAI(api_key=api_key, base_url=base_url) def generate_response(self, messages: List[Dict[str, str]], response_format=None): diff --git a/tests/llms/test_azure_openai.py b/tests/llms/test_azure_openai.py index 19091fdf..77cb88e9 100644 --- a/tests/llms/test_azure_openai.py +++ b/tests/llms/test_azure_openai.py @@ -20,10 +20,8 @@ def mock_openai_client(): yield mock_client -def test_generate_response(mock_openai_client): - config = BaseLlmConfig( - model=MODEL, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P - ) +def test_generate_response_without_tools(mock_openai_client): + config = BaseLlmConfig(model=MODEL, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P) llm = AzureOpenAILLM(config) messages = [ {"role": "system", "content": "You are a helpful assistant."}, @@ -31,21 +29,67 @@ def test_generate_response(mock_openai_client): ] mock_response = Mock() - mock_response.choices = [ - Mock(message=Mock(content="I'm doing well, thank you for asking!")) - ] + mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))] mock_openai_client.chat.completions.create.return_value = mock_response response = llm.generate_response(messages) + mock_openai_client.chat.completions.create.assert_called_once_with( + model=MODEL, messages=messages, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P + ) + assert response == "I'm doing well, thank you for asking!" + + +def test_generate_response_with_tools(mock_openai_client): + config = BaseLlmConfig(model=MODEL, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P) + llm = AzureOpenAILLM(config) + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Add a new memory: Today is a sunny day."}, + ] + tools = [ + { + "type": "function", + "function": { + "name": "add_memory", + "description": "Add a memory", + "parameters": { + "type": "object", + "properties": {"data": {"type": "string", "description": "Data to add to memory"}}, + "required": ["data"], + }, + }, + } + ] + + mock_response = Mock() + mock_message = Mock() + mock_message.content = "I've added the memory for you." + + mock_tool_call = Mock() + mock_tool_call.function.name = "add_memory" + mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}' + + mock_message.tool_calls = [mock_tool_call] + mock_response.choices = [Mock(message=mock_message)] + mock_openai_client.chat.completions.create.return_value = mock_response + + response = llm.generate_response(messages, tools=tools) + mock_openai_client.chat.completions.create.assert_called_once_with( model=MODEL, messages=messages, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P, + tools=tools, + tool_choice="auto", ) - assert response == "I'm doing well, thank you for asking!" + + assert response["content"] == "I've added the memory for you." + assert len(response["tool_calls"]) == 1 + assert response["tool_calls"][0]["name"] == "add_memory" + assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."} @pytest.mark.parametrize( @@ -84,6 +128,4 @@ def test_generate_with_http_proxies(default_headers): api_version=None, default_headers=default_headers, ) - mock_http_client.assert_called_once_with( - proxies="http://testproxy.mem0.net:8000" - ) + mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000") \ No newline at end of file diff --git a/tests/llms/test_deepseek.py b/tests/llms/test_deepseek.py index 2d6469fe..47e60cd3 100644 --- a/tests/llms/test_deepseek.py +++ b/tests/llms/test_deepseek.py @@ -16,47 +16,33 @@ def mock_deepseek_client(): def test_deepseek_llm_base_url(): # case1: default config with deepseek official base url - config = BaseLlmConfig( - model="deepseek-chat", - temperature=0.7, - max_tokens=100, - top_p=1.0, - api_key="api_key", - ) + config = BaseLlmConfig(model="deepseek-chat", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key") llm = DeepSeekLLM(config) assert str(llm.client.base_url) == "https://api.deepseek.com" # case2: with env variable DEEPSEEK_API_BASE provider_base_url = "https://api.provider.com/v1/" os.environ["DEEPSEEK_API_BASE"] = provider_base_url - config = BaseLlmConfig( - model="deepseek-chat", - temperature=0.7, - max_tokens=100, - top_p=1.0, - api_key="api_key", - ) + config = BaseLlmConfig(model="deepseek-chat", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key") llm = DeepSeekLLM(config) assert str(llm.client.base_url) == provider_base_url # case3: with config.deepseek_base_url config_base_url = "https://api.config.com/v1/" config = BaseLlmConfig( - model="deepseek-chat", - temperature=0.7, - max_tokens=100, - top_p=1.0, - api_key="api_key", - deepseek_base_url=config_base_url, + model="deepseek-chat", + temperature=0.7, + max_tokens=100, + top_p=1.0, + api_key="api_key", + deepseek_base_url=config_base_url ) llm = DeepSeekLLM(config) assert str(llm.client.base_url) == config_base_url -def test_generate_response(mock_deepseek_client): - config = BaseLlmConfig( - model="deepseek-chat", temperature=0.7, max_tokens=100, top_p=1.0 - ) +def test_generate_response_without_tools(mock_deepseek_client): + config = BaseLlmConfig(model="deepseek-chat", temperature=0.7, max_tokens=100, top_p=1.0) llm = DeepSeekLLM(config) messages = [ {"role": "system", "content": "You are a helpful assistant."}, @@ -64,18 +50,64 @@ def test_generate_response(mock_deepseek_client): ] mock_response = Mock() - mock_response.choices = [ - Mock(message=Mock(content="I'm doing well, thank you for asking!")) - ] + mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))] mock_deepseek_client.chat.completions.create.return_value = mock_response response = llm.generate_response(messages) mock_deepseek_client.chat.completions.create.assert_called_once_with( - model="deepseek-chat", - messages=messages, - temperature=0.7, - max_tokens=100, - top_p=1.0, + model="deepseek-chat", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0 ) assert response == "I'm doing well, thank you for asking!" + + +def test_generate_response_with_tools(mock_deepseek_client): + config = BaseLlmConfig(model="deepseek-chat", temperature=0.7, max_tokens=100, top_p=1.0) + llm = DeepSeekLLM(config) + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Add a new memory: Today is a sunny day."}, + ] + tools = [ + { + "type": "function", + "function": { + "name": "add_memory", + "description": "Add a memory", + "parameters": { + "type": "object", + "properties": {"data": {"type": "string", "description": "Data to add to memory"}}, + "required": ["data"], + }, + }, + } + ] + + mock_response = Mock() + mock_message = Mock() + mock_message.content = "I've added the memory for you." + + mock_tool_call = Mock() + mock_tool_call.function.name = "add_memory" + mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}' + + mock_message.tool_calls = [mock_tool_call] + mock_response.choices = [Mock(message=mock_message)] + mock_deepseek_client.chat.completions.create.return_value = mock_response + + response = llm.generate_response(messages, tools=tools) + + mock_deepseek_client.chat.completions.create.assert_called_once_with( + model="deepseek-chat", + messages=messages, + temperature=0.7, + max_tokens=100, + top_p=1.0, + tools=tools, + tool_choice="auto" + ) + + assert response["content"] == "I've added the memory for you." + assert len(response["tool_calls"]) == 1 + assert response["tool_calls"][0]["name"] == "add_memory" + assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."} \ No newline at end of file diff --git a/tests/llms/test_gemini_llm.py b/tests/llms/test_gemini_llm.py index d6ea76fd..ffdec4fb 100644 --- a/tests/llms/test_gemini_llm.py +++ b/tests/llms/test_gemini_llm.py @@ -17,9 +17,7 @@ def mock_gemini_client(): def test_generate_response_without_tools(mock_gemini_client: Mock): - config = BaseLlmConfig( - model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0 - ) + config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0) llm = GeminiLLM(config) messages = [ {"role": "system", "content": "You are a helpful assistant."}, @@ -36,14 +34,86 @@ def test_generate_response_without_tools(mock_gemini_client: Mock): mock_gemini_client.generate_content.assert_called_once_with( contents=[ - { - "parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", - "role": "user", - }, + {"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"}, {"parts": "Hello, how are you?", "role": "user"}, ], - generation_config=GenerationConfig( - temperature=0.7, max_output_tokens=100, top_p=1.0 + generation_config=GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0), + tools=None, + tool_config=content_types.to_tool_config( + {"function_calling_config": {"mode": "auto", "allowed_function_names": None}} ), ) assert response == "I'm doing well, thank you for asking!" + + +def test_generate_response_with_tools(mock_gemini_client: Mock): + config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0) + llm = GeminiLLM(config) + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Add a new memory: Today is a sunny day."}, + ] + tools = [ + { + "type": "function", + "function": { + "name": "add_memory", + "description": "Add a memory", + "parameters": { + "type": "object", + "properties": {"data": {"type": "string", "description": "Data to add to memory"}}, + "required": ["data"], + }, + }, + } + ] + + mock_tool_call = Mock() + mock_tool_call.name = "add_memory" + mock_tool_call.args = {"data": "Today is a sunny day."} + + mock_part = Mock() + mock_part.function_call = mock_tool_call + mock_part.text = "I've added the memory for you." + + mock_content = Mock() + mock_content.parts = [mock_part] + + mock_message = Mock() + mock_message.content = mock_content + + mock_response = Mock(candidates=[mock_message]) + mock_gemini_client.generate_content.return_value = mock_response + + response = llm.generate_response(messages, tools=tools) + + mock_gemini_client.generate_content.assert_called_once_with( + contents=[ + {"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"}, + {"parts": "Add a new memory: Today is a sunny day.", "role": "user"}, + ], + generation_config=GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0), + tools=[ + { + "function_declarations": [ + { + "name": "add_memory", + "description": "Add a memory", + "parameters": { + "type": "object", + "properties": {"data": {"type": "string", "description": "Data to add to memory"}}, + "required": ["data"], + }, + } + ] + } + ], + tool_config=content_types.to_tool_config( + {"function_calling_config": {"mode": "auto", "allowed_function_names": None}} + ), + ) + + assert response["content"] == "I've added the memory for you." + assert len(response["tool_calls"]) == 1 + assert response["tool_calls"][0]["name"] == "add_memory" + assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."} diff --git a/tests/llms/test_groq.py b/tests/llms/test_groq.py index 9e5ceb32..288b37f8 100644 --- a/tests/llms/test_groq.py +++ b/tests/llms/test_groq.py @@ -14,10 +14,8 @@ def mock_groq_client(): yield mock_client -def test_generate_response(mock_groq_client): - config = BaseLlmConfig( - model="llama3-70b-8192", temperature=0.7, max_tokens=100, top_p=1.0 - ) +def test_generate_response_without_tools(mock_groq_client): + config = BaseLlmConfig(model="llama3-70b-8192", temperature=0.7, max_tokens=100, top_p=1.0) llm = GroqLLM(config) messages = [ {"role": "system", "content": "You are a helpful assistant."}, @@ -25,18 +23,64 @@ def test_generate_response(mock_groq_client): ] mock_response = Mock() - mock_response.choices = [ - Mock(message=Mock(content="I'm doing well, thank you for asking!")) - ] + mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))] mock_groq_client.chat.completions.create.return_value = mock_response response = llm.generate_response(messages) + mock_groq_client.chat.completions.create.assert_called_once_with( + model="llama3-70b-8192", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0 + ) + assert response == "I'm doing well, thank you for asking!" + + +def test_generate_response_with_tools(mock_groq_client): + config = BaseLlmConfig(model="llama3-70b-8192", temperature=0.7, max_tokens=100, top_p=1.0) + llm = GroqLLM(config) + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Add a new memory: Today is a sunny day."}, + ] + tools = [ + { + "type": "function", + "function": { + "name": "add_memory", + "description": "Add a memory", + "parameters": { + "type": "object", + "properties": {"data": {"type": "string", "description": "Data to add to memory"}}, + "required": ["data"], + }, + }, + } + ] + + mock_response = Mock() + mock_message = Mock() + mock_message.content = "I've added the memory for you." + + mock_tool_call = Mock() + mock_tool_call.function.name = "add_memory" + mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}' + + mock_message.tool_calls = [mock_tool_call] + mock_response.choices = [Mock(message=mock_message)] + mock_groq_client.chat.completions.create.return_value = mock_response + + response = llm.generate_response(messages, tools=tools) + mock_groq_client.chat.completions.create.assert_called_once_with( model="llama3-70b-8192", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0, + tools=tools, + tool_choice="auto", ) - assert response == "I'm doing well, thank you for asking!" + + assert response["content"] == "I've added the memory for you." + assert len(response["tool_calls"]) == 1 + assert response["tool_calls"][0]["name"] == "add_memory" + assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."} diff --git a/tests/llms/test_litellm.py b/tests/llms/test_litellm.py index 58c638a6..d7be93c9 100644 --- a/tests/llms/test_litellm.py +++ b/tests/llms/test_litellm.py @@ -13,22 +13,17 @@ def mock_litellm(): def test_generate_response_with_unsupported_model(mock_litellm): - config = BaseLlmConfig( - model="unsupported-model", temperature=0.7, max_tokens=100, top_p=1 - ) + config = BaseLlmConfig(model="unsupported-model", temperature=0.7, max_tokens=100, top_p=1) llm = litellm.LiteLLM(config) messages = [{"role": "user", "content": "Hello"}] mock_litellm.supports_function_calling.return_value = False - with pytest.raises( - ValueError, - match="Model 'unsupported-model' in LiteLLM does not support function calling.", - ): + with pytest.raises(ValueError, match="Model 'unsupported-model' in litellm does not support function calling."): llm.generate_response(messages) -def test_generate_response(mock_litellm): +def test_generate_response_without_tools(mock_litellm): config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1) llm = litellm.LiteLLM(config) messages = [ @@ -37,9 +32,7 @@ def test_generate_response(mock_litellm): ] mock_response = Mock() - mock_response.choices = [ - Mock(message=Mock(content="I'm doing well, thank you for asking!")) - ] + mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))] mock_litellm.completion.return_value = mock_response mock_litellm.supports_function_calling.return_value = True @@ -49,3 +42,50 @@ def test_generate_response(mock_litellm): model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0 ) assert response == "I'm doing well, thank you for asking!" + + +def test_generate_response_with_tools(mock_litellm): + config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1) + llm = litellm.LiteLLM(config) + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Add a new memory: Today is a sunny day."}, + ] + tools = [ + { + "type": "function", + "function": { + "name": "add_memory", + "description": "Add a memory", + "parameters": { + "type": "object", + "properties": {"data": {"type": "string", "description": "Data to add to memory"}}, + "required": ["data"], + }, + }, + } + ] + + mock_response = Mock() + mock_message = Mock() + mock_message.content = "I've added the memory for you." + + mock_tool_call = Mock() + mock_tool_call.function.name = "add_memory" + mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}' + + mock_message.tool_calls = [mock_tool_call] + mock_response.choices = [Mock(message=mock_message)] + mock_litellm.completion.return_value = mock_response + mock_litellm.supports_function_calling.return_value = True + + response = llm.generate_response(messages, tools=tools) + + mock_litellm.completion.assert_called_once_with( + model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1, tools=tools, tool_choice="auto" + ) + + assert response["content"] == "I've added the memory for you." + assert len(response["tool_calls"]) == 1 + assert response["tool_calls"][0]["name"] == "add_memory" + assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."} diff --git a/tests/llms/test_openai.py b/tests/llms/test_openai.py index bd86e718..42f8aa5b 100644 --- a/tests/llms/test_openai.py +++ b/tests/llms/test_openai.py @@ -16,9 +16,7 @@ def mock_openai_client(): def test_openai_llm_base_url(): # case1: default config: with openai official base url - config = BaseLlmConfig( - model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key" - ) + config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key") llm = OpenAILLM(config) # Note: openai client will parse the raw base_url into a URL object, which will have a trailing slash assert str(llm.client.base_url) == "https://api.openai.com/v1/" @@ -26,9 +24,7 @@ def test_openai_llm_base_url(): # case2: with env variable OPENAI_API_BASE provider_base_url = "https://api.provider.com/v1" os.environ["OPENAI_API_BASE"] = provider_base_url - config = BaseLlmConfig( - model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key" - ) + config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key") llm = OpenAILLM(config) # Note: openai client will parse the raw base_url into a URL object, which will have a trailing slash assert str(llm.client.base_url) == provider_base_url + "/" @@ -36,19 +32,14 @@ def test_openai_llm_base_url(): # case3: with config.openai_base_url config_base_url = "https://api.config.com/v1" config = BaseLlmConfig( - model="gpt-4o", - temperature=0.7, - max_tokens=100, - top_p=1.0, - api_key="api_key", - openai_base_url=config_base_url, + model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key", openai_base_url=config_base_url ) llm = OpenAILLM(config) # Note: openai client will parse the raw base_url into a URL object, which will have a trailing slash assert str(llm.client.base_url) == config_base_url + "/" -def test_generate_response(mock_openai_client): +def test_generate_response_without_tools(mock_openai_client): config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0) llm = OpenAILLM(config) messages = [ @@ -57,9 +48,7 @@ def test_generate_response(mock_openai_client): ] mock_response = Mock() - mock_response.choices = [ - Mock(message=Mock(content="I'm doing well, thank you for asking!")) - ] + mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))] mock_openai_client.chat.completions.create.return_value = mock_response response = llm.generate_response(messages) @@ -68,3 +57,49 @@ def test_generate_response(mock_openai_client): model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0 ) assert response == "I'm doing well, thank you for asking!" + + +def test_generate_response_with_tools(mock_openai_client): + config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0) + llm = OpenAILLM(config) + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Add a new memory: Today is a sunny day."}, + ] + tools = [ + { + "type": "function", + "function": { + "name": "add_memory", + "description": "Add a memory", + "parameters": { + "type": "object", + "properties": {"data": {"type": "string", "description": "Data to add to memory"}}, + "required": ["data"], + }, + }, + } + ] + + mock_response = Mock() + mock_message = Mock() + mock_message.content = "I've added the memory for you." + + mock_tool_call = Mock() + mock_tool_call.function.name = "add_memory" + mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}' + + mock_message.tool_calls = [mock_tool_call] + mock_response.choices = [Mock(message=mock_message)] + mock_openai_client.chat.completions.create.return_value = mock_response + + response = llm.generate_response(messages, tools=tools) + + mock_openai_client.chat.completions.create.assert_called_once_with( + model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0, tools=tools, tool_choice="auto" + ) + + assert response["content"] == "I've added the memory for you." + assert len(response["tool_calls"]) == 1 + assert response["tool_calls"][0]["name"] == "add_memory" + assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."} diff --git a/tests/llms/test_together.py b/tests/llms/test_together.py index 67d9a6c3..7c59ee41 100644 --- a/tests/llms/test_together.py +++ b/tests/llms/test_together.py @@ -14,13 +14,8 @@ def mock_together_client(): yield mock_client -def test_generate_response(mock_together_client): - config = BaseLlmConfig( - model="mistralai/Mixtral-8x7B-Instruct-v0.1", - temperature=0.7, - max_tokens=100, - top_p=1.0, - ) +def test_generate_response_without_tools(mock_together_client): + config = BaseLlmConfig(model="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0.7, max_tokens=100, top_p=1.0) llm = TogetherLLM(config) messages = [ {"role": "system", "content": "You are a helpful assistant."}, @@ -28,18 +23,64 @@ def test_generate_response(mock_together_client): ] mock_response = Mock() - mock_response.choices = [ - Mock(message=Mock(content="I'm doing well, thank you for asking!")) - ] + mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))] mock_together_client.chat.completions.create.return_value = mock_response response = llm.generate_response(messages) + mock_together_client.chat.completions.create.assert_called_once_with( + model="mistralai/Mixtral-8x7B-Instruct-v0.1", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0 + ) + assert response == "I'm doing well, thank you for asking!" + + +def test_generate_response_with_tools(mock_together_client): + config = BaseLlmConfig(model="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0.7, max_tokens=100, top_p=1.0) + llm = TogetherLLM(config) + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Add a new memory: Today is a sunny day."}, + ] + tools = [ + { + "type": "function", + "function": { + "name": "add_memory", + "description": "Add a memory", + "parameters": { + "type": "object", + "properties": {"data": {"type": "string", "description": "Data to add to memory"}}, + "required": ["data"], + }, + }, + } + ] + + mock_response = Mock() + mock_message = Mock() + mock_message.content = "I've added the memory for you." + + mock_tool_call = Mock() + mock_tool_call.function.name = "add_memory" + mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}' + + mock_message.tool_calls = [mock_tool_call] + mock_response.choices = [Mock(message=mock_message)] + mock_together_client.chat.completions.create.return_value = mock_response + + response = llm.generate_response(messages, tools=tools) + mock_together_client.chat.completions.create.assert_called_once_with( model="mistralai/Mixtral-8x7B-Instruct-v0.1", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0, + tools=tools, + tool_choice="auto", ) - assert response == "I'm doing well, thank you for asking!" + + assert response["content"] == "I've added the memory for you." + assert len(response["tool_calls"]) == 1 + assert response["tool_calls"][0]["name"] == "add_memory" + assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}