Reverting the tools commit (#2404)

Parshva Daftari
2025-03-20 00:09:00 +05:30
committed by GitHub
parent 1aed611539
commit ee66e0c954
21 changed files with 990 additions and 475 deletions

View File

@@ -4,26 +4,14 @@ from typing import Dict, List, Optional
try:
import anthropic
except ImportError:
raise ImportError(
"The 'anthropic' library is required. Please install it using 'pip install anthropic'."
)
raise ImportError("The 'anthropic' library is required. Please install it using 'pip install anthropic'.")
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
class AnthropicLLM(LLMBase):
"""
A class for interacting with Anthropic's Claude models using the specified configuration.
"""
def __init__(self, config: Optional[BaseLlmConfig] = None):
"""
Initializes the AnthropicLLM instance with the given configuration.
Args:
config (Optional[BaseLlmConfig]): Configuration settings for the language model.
"""
super().__init__(config)
if not self.config.model:
@@ -35,17 +23,23 @@ class AnthropicLLM(LLMBase):
def generate_response(
self,
messages: List[Dict[str, str]],
) -> str:
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
"""
Generates a response using Anthropic's Claude model based on the provided messages.
Generate a response based on the given messages using Anthropic.
Args:
messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format of the response. Defaults to "text".
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
Returns:
str: The generated response from the model.
str: The generated response.
"""
# Extract system message separately
# Separate system message from other messages
system_message = ""
filtered_messages = []
for message in messages:
@@ -62,6 +56,9 @@ class AnthropicLLM(LLMBase):
"max_tokens": self.config.max_tokens,
"top_p": self.config.top_p,
}
if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools
params["tool_choice"] = tool_choice
response = self.client.messages.create(**params)
return response.content[0].text

View File

@@ -4,26 +4,14 @@ from typing import Any, Dict, List, Optional
try:
import boto3
except ImportError:
raise ImportError(
"The 'boto3' library is required. Please install it using 'pip install boto3'."
)
raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.")
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
class AWSBedrockLLM(LLMBase):
"""
A wrapper for AWS Bedrock's language models, integrating them with the LLMBase class.
"""
def __init__(self, config: Optional[BaseLlmConfig] = None):
"""
Initializes the AWS Bedrock LLM with the provided configuration.
Args:
config (Optional[BaseLlmConfig]): Configuration object for the model.
"""
super().__init__(config)
if not self.config.model:
@@ -37,29 +25,49 @@ class AWSBedrockLLM(LLMBase):
def _format_messages(self, messages: List[Dict[str, str]]) -> str:
"""
Formats a list of messages into a structured prompt for the model.
Formats a list of messages into the required prompt structure for the model.
Args:
messages (List[Dict[str, str]]): A list of dictionaries containing 'role' and 'content'.
messages (List[Dict[str, str]]): A list of dictionaries where each dictionary represents a message.
Each dictionary contains 'role' and 'content' keys.
Returns:
str: A formatted string combining all messages, structured with roles capitalized and separated by newlines.
"""
formatted_messages = [
f"\n\n{msg['role'].capitalize()}: {msg['content']}" for msg in messages
]
formatted_messages = []
for message in messages:
role = message["role"].capitalize()
content = message["content"]
formatted_messages.append(f"\n\n{role}: {content}")
return "".join(formatted_messages) + "\n\nAssistant:"
def _parse_response(self, response) -> str:
def _parse_response(self, response, tools) -> str:
"""
Extracts the generated response from the API response.
Process the response based on whether tools are used or not.
Args:
response: The raw response from the AWS Bedrock API.
response: The raw response from API.
tools: The list of tools provided in the request.
Returns:
str: The generated response text.
str or dict: The processed response.
"""
if tools:
processed_response = {"tool_calls": []}
if response["output"]["message"]["content"]:
for item in response["output"]["message"]["content"]:
if "toolUse" in item:
processed_response["tool_calls"].append(
{
"name": item["toolUse"]["name"],
"arguments": item["toolUse"]["input"],
}
)
return processed_response
response_body = json.loads(response["body"].read().decode())
return response_body.get("completion", "")
@@ -68,21 +76,22 @@ class AWSBedrockLLM(LLMBase):
provider: str,
model: str,
prompt: str,
model_kwargs: Optional[Dict[str, Any]] = None,
model_kwargs: Optional[Dict[str, Any]] = {},
) -> Dict[str, Any]:
"""
Prepares the input dictionary for the specified provider's model.
Prepares the input dictionary for the specified provider's model by mapping and renaming
keys in the input based on the provider's requirements.
Args:
provider (str): The model provider (e.g., "meta", "ai21", "mistral", "cohere", "amazon").
model (str): The model identifier.
prompt (str): The input prompt.
model_kwargs (Optional[Dict[str, Any]]): Additional model parameters.
provider (str): The name of the service provider (e.g., "meta", "ai21", "mistral", "cohere", "amazon").
model (str): The name or identifier of the model being used.
prompt (str): The text prompt to be processed by the model.
model_kwargs (Dict[str, Any]): Additional keyword arguments specific to the model's requirements.
Returns:
Dict[str, Any]: The prepared input dictionary.
Dict[str, Any]: The prepared input dictionary with the correct keys and values for the specified provider.
"""
model_kwargs = model_kwargs or {}
input_body = {"prompt": prompt, **model_kwargs}
provider_mappings = {
@@ -110,35 +119,102 @@ class AWSBedrockLLM(LLMBase):
},
}
input_body["textGenerationConfig"] = {
k: v
for k, v in input_body["textGenerationConfig"].items()
if v is not None
k: v for k, v in input_body["textGenerationConfig"].items() if v is not None
}
return input_body
def generate_response(self, messages: List[Dict[str, str]]) -> str:
def _convert_tool_format(self, original_tools):
"""
Generates a response using AWS Bedrock based on the provided messages.
Converts a list of tools from their original format to a new standardized format.
Args:
messages (List[Dict[str, str]]): List of message dictionaries containing 'role' and 'content'.
original_tools (list): A list of dictionaries representing the original tools, each containing a 'type' key and corresponding details.
Returns:
str: The generated response text.
list: A list of dictionaries representing the tools in the new standardized format.
"""
prompt = self._format_messages(messages)
provider = self.config.model.split(".")[0]
input_body = self._prepare_input(
provider, self.config.model, prompt, self.model_kwargs
)
body = json.dumps(input_body)
new_tools = []
response = self.client.invoke_model(
body=body,
modelId=self.config.model,
accept="application/json",
contentType="application/json",
)
for tool in original_tools:
if tool["type"] == "function":
function = tool["function"]
new_tool = {
"toolSpec": {
"name": function["name"],
"description": function["description"],
"inputSchema": {
"json": {
"type": "object",
"properties": {},
"required": function["parameters"].get("required", []),
}
},
}
}
return self._parse_response(response)
for prop, details in function["parameters"].get("properties", {}).items():
new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = {
"type": details.get("type", "string"),
"description": details.get("description", ""),
}
new_tools.append(new_tool)
return new_tools
def generate_response(
self,
messages: List[Dict[str, str]],
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
"""
Generate a response based on the given messages using AWS Bedrock.
Args:
messages (list): List of message dicts containing 'role' and 'content'.
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
Returns:
str: The generated response.
"""
if tools:
# Use converse method when tools are provided
messages = [
{
"role": "user",
"content": [{"text": message["content"]} for message in messages],
}
]
inference_config = {
"temperature": self.model_kwargs["temperature"],
"maxTokens": self.model_kwargs["max_tokens_to_sample"],
"topP": self.model_kwargs["top_p"],
}
tools_config = {"tools": self._convert_tool_format(tools)}
response = self.client.converse(
modelId=self.config.model,
messages=messages,
inferenceConfig=inference_config,
toolConfig=tools_config,
)
else:
# Use invoke_model method when no tools are provided
prompt = self._format_messages(messages)
provider = self.model.split(".")[0]
input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs)
body = json.dumps(input_body)
response = self.client.invoke_model(
body=body,
modelId=self.model,
accept="application/json",
contentType="application/json",
)
return self._parse_response(response, tools)
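
Illustrative note (not part of the diff): _convert_tool_format above maps an OpenAI-style tool definition onto the toolSpec structure that the Bedrock converse call in generate_response expects. For the add_memory tool used by this commit's tests, the mapping looks like the sketch below (standalone, for reference only):

openai_style_tool = {
    "type": "function",
    "function": {
        "name": "add_memory",
        "description": "Add a memory",
        "parameters": {
            "type": "object",
            "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
            "required": ["data"],
        },
    },
}

# Equivalent entry produced by _convert_tool_format([openai_style_tool]),
# per the method body above: name/description are copied and the JSON schema
# is nested under toolSpec.inputSchema.json.
bedrock_tool = {
    "toolSpec": {
        "name": "add_memory",
        "description": "Add a memory",
        "inputSchema": {
            "json": {
                "type": "object",
                "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
                "required": ["data"],
            }
        },
    }
}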

View File

@@ -9,35 +9,17 @@ from mem0.llms.base import LLMBase
class AzureOpenAILLM(LLMBase):
"""
A class for interacting with Azure OpenAI models using the specified configuration.
"""
def __init__(self, config: Optional[BaseLlmConfig] = None):
"""
Initializes the AzureOpenAILLM instance with the given configuration.
Args:
config (Optional[BaseLlmConfig]): Configuration settings for the language model.
"""
super().__init__(config)
# Ensure model name is set; it should match the Azure OpenAI deployment name.
# Model name should match the custom deployment name chosen for it.
if not self.config.model:
self.config.model = "gpt-4o"
api_key = self.config.azure_kwargs.api_key or os.getenv(
"LLM_AZURE_OPENAI_API_KEY"
)
azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv(
"LLM_AZURE_DEPLOYMENT"
)
azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv(
"LLM_AZURE_ENDPOINT"
)
api_version = self.config.azure_kwargs.api_version or os.getenv(
"LLM_AZURE_API_VERSION"
)
api_key = self.config.azure_kwargs.api_key or os.getenv("LLM_AZURE_OPENAI_API_KEY")
azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv("LLM_AZURE_DEPLOYMENT")
azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("LLM_AZURE_ENDPOINT")
api_version = self.config.azure_kwargs.api_version or os.getenv("LLM_AZURE_API_VERSION")
default_headers = self.config.azure_kwargs.default_headers
self.client = AzureOpenAI(
@@ -49,20 +31,54 @@ class AzureOpenAILLM(LLMBase):
default_headers=default_headers,
)
def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.
Args:
response: The raw response from API.
tools: The list of tools provided in the request.
Returns:
str or dict: The processed response.
"""
if tools:
processed_response = {
"content": response.choices[0].message.content,
"tool_calls": [],
}
if response.choices[0].message.tool_calls:
for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append(
{
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments),
}
)
return processed_response
else:
return response.choices[0].message.content
def generate_response(
self,
messages: List[Dict[str, str]],
response_format: Optional[str] = None,
) -> str:
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
"""
Generates a response using Azure OpenAI based on the provided messages.
Generate a response based on the given messages using Azure OpenAI.
Args:
messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
response_format (Optional[str]): The desired format of the response. Defaults to None.
messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format of the response. Defaults to "text".
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
Returns:
str: The generated response from the model.
str: The generated response.
"""
params = {
"model": self.config.model,
@@ -71,8 +87,11 @@ class AzureOpenAILLM(LLMBase):
"max_tokens": self.config.max_tokens,
"top_p": self.config.top_p,
}
if response_format:
params["response_format"] = response_format
if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools
params["tool_choice"] = tool_choice
response = self.client.chat.completions.create(**params)
return response.choices[0].message.content
return self._parse_response(response, tools)
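
Illustrative note (not part of the diff): when tools are supplied, _parse_response above returns a dict rather than a plain string. Its shape, with values mirroring the Azure OpenAI test later in this commit, is:

parsed = {
    "content": "I've added the memory for you.",
    "tool_calls": [
        {"name": "add_memory", "arguments": {"data": "Today is a sunny day."}},
    ],
}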

View File

@@ -9,38 +9,20 @@ from mem0.llms.base import LLMBase
class AzureOpenAIStructuredLLM(LLMBase):
"""
A class for interacting with Azure OpenAI models using the specified configuration.
"""
def __init__(self, config: Optional[BaseLlmConfig] = None):
"""
Initializes the AzureOpenAIStructuredLLM instance with the given configuration.
Args:
config (Optional[BaseLlmConfig]): Configuration settings for the language model.
"""
super().__init__(config)
# Ensure model name is set; it should match the Azure OpenAI deployment name.
# Model name should match the custom deployment name chosen for it.
if not self.config.model:
self.config.model = "gpt-4o-2024-08-06"
api_key = (
os.getenv("LLM_AZURE_OPENAI_API_KEY") or self.config.azure_kwargs.api_key
)
azure_deployment = (
os.getenv("LLM_AZURE_DEPLOYMENT")
or self.config.azure_kwargs.azure_deployment
)
azure_endpoint = (
os.getenv("LLM_AZURE_ENDPOINT") or self.config.azure_kwargs.azure_endpoint
)
api_version = (
os.getenv("LLM_AZURE_API_VERSION") or self.config.azure_kwargs.api_version
)
api_key = os.getenv("LLM_AZURE_OPENAI_API_KEY") or self.config.azure_kwargs.api_key
azure_deployment = os.getenv("LLM_AZURE_DEPLOYMENT") or self.config.azure_kwargs.azure_deployment
azure_endpoint = os.getenv("LLM_AZURE_ENDPOINT") or self.config.azure_kwargs.azure_endpoint
api_version = os.getenv("LLM_AZURE_API_VERSION") or self.config.azure_kwargs.api_version
default_headers = self.config.azure_kwargs.default_headers
# Can display a warning if API version is of model and api-version
self.client = AzureOpenAI(
azure_deployment=azure_deployment,
azure_endpoint=azure_endpoint,
@@ -50,52 +32,20 @@ class AzureOpenAIStructuredLLM(LLMBase):
default_headers=default_headers,
)
def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.
Args:
response: The raw response from API.
tools: The list of tools provided in the request.
Returns:
str or dict: The processed response.
"""
if tools:
processed_response = {
"content": response.choices[0].message.content,
"tool_calls": [],
}
if response.choices[0].message.tool_calls:
for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append(
{
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments),
}
)
return processed_response
else:
return response.choices[0].message.content
def generate_response(
self,
messages: List[Dict[str, str]],
response_format: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
) -> str:
"""
Generates a response using Azure OpenAI based on the provided messages.
Generate a response based on the given messages using Azure OpenAI.
Args:
messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
response_format (Optional[str]): The desired format of the response. Defaults to None.
tools (Optional[List[Dict]]): A list of dictionaries, each containing a 'name' and 'arguments' key.
tool_choice (str): The choice of tool to use. Defaults to "auto".
Returns:
str: The generated response from the model.
str: The generated response.
"""
params = {
"model": self.config.model,
@@ -104,9 +54,11 @@ class AzureOpenAIStructuredLLM(LLMBase):
"max_tokens": self.config.max_tokens,
"top_p": self.config.top_p,
}
if response_format:
params["response_format"] = response_format
if tools:
params["tools"] = tools
params["tool_choice"] = tool_choice
if tools:
params["tools"] = tools

View File

@@ -4,12 +4,8 @@ from pydantic import BaseModel, Field, field_validator
class LlmConfig(BaseModel):
provider: str = Field(
description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai"
)
config: Optional[dict] = Field(
description="Configuration for the specific LLM", default={}
)
provider: str = Field(description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai")
config: Optional[dict] = Field(description="Configuration for the specific LLM", default={})
@field_validator("config")
def validate_config(cls, v, values):

View File

@@ -9,42 +9,64 @@ from mem0.llms.base import LLMBase
class DeepSeekLLM(LLMBase):
"""
A class for interacting with DeepSeek's language models using the specified configuration.
"""
def __init__(self, config: Optional[BaseLlmConfig] = None):
"""
Initializes the DeepSeekLLM instance with the given configuration.
Args:
config (Optional[BaseLlmConfig]): Configuration settings for the language model.
"""
super().__init__(config)
if not self.config.model:
self.config.model = "deepseek-chat"
api_key = self.config.api_key or os.getenv("DEEPSEEK_API_KEY")
base_url = (
self.config.deepseek_base_url
or os.getenv("DEEPSEEK_API_BASE")
or "https://api.deepseek.com"
)
base_url = self.config.deepseek_base_url or os.getenv("DEEPSEEK_API_BASE") or "https://api.deepseek.com"
self.client = OpenAI(api_key=api_key, base_url=base_url)
def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.
Args:
response: The raw response from API.
tools: The list of tools provided in the request.
Returns:
str or dict: The processed response.
"""
if tools:
processed_response = {
"content": response.choices[0].message.content,
"tool_calls": [],
}
if response.choices[0].message.tool_calls:
for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append(
{
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments),
}
)
return processed_response
else:
return response.choices[0].message.content
def generate_response(
self,
messages: List[Dict[str, str]],
) -> str:
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
"""
Generates a response using DeepSeek based on the provided messages.
Generate a response based on the given messages using DeepSeek.
Args:
messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format of the response. Defaults to "text".
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
Returns:
str: The generated response from the model.
str: The generated response.
"""
params = {
"model": self.config.model,
@@ -53,5 +75,10 @@ class DeepSeekLLM(LLMBase):
"max_tokens": self.config.max_tokens,
"top_p": self.config.top_p,
}
if tools:
params["tools"] = tools
params["tool_choice"] = tool_choice
response = self.client.chat.completions.create(**params)
return response.choices[0].message.content
return self._parse_response(response, tools)

View File

@@ -3,7 +3,8 @@ from typing import Dict, List, Optional
try:
import google.generativeai as genai
from google.generativeai import GenerativeModel
from google.generativeai import GenerativeModel, protos
from google.generativeai.types import content_types
except ImportError:
raise ImportError(
"The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'."
@@ -14,17 +15,7 @@ from mem0.llms.base import LLMBase
class GeminiLLM(LLMBase):
"""
A wrapper for Google's Gemini language model, integrating it with the LLMBase class.
"""
def __init__(self, config: Optional[BaseLlmConfig] = None):
"""
Initializes the Gemini LLM with the provided configuration.
Args:
config (Optional[BaseLlmConfig]): Configuration object for the model.
"""
super().__init__(config)
if not self.config.model:
@@ -34,25 +25,51 @@ class GeminiLLM(LLMBase):
genai.configure(api_key=api_key)
self.client = GenerativeModel(model_name=self.config.model)
def _reformat_messages(
self, messages: List[Dict[str, str]]
) -> List[Dict[str, str]]:
def _parse_response(self, response, tools):
"""
Reformats messages to match the Gemini API's expected structure.
Process the response based on whether tools are used or not.
Args:
messages (List[Dict[str, str]]): A list of messages with 'role' and 'content' keys.
response: The raw response from API.
tools: The list of tools provided in the request.
Returns:
List[Dict[str, str]]: Reformatted messages in the required format.
str or dict: The processed response.
"""
if tools:
processed_response = {
"content": (content if (content := response.candidates[0].content.parts[0].text) else None),
"tool_calls": [],
}
for part in response.candidates[0].content.parts:
if fn := part.function_call:
if isinstance(fn, protos.FunctionCall):
fn_call = type(fn).to_dict(fn)
processed_response["tool_calls"].append({"name": fn_call["name"], "arguments": fn_call["args"]})
continue
processed_response["tool_calls"].append({"name": fn.name, "arguments": fn.args})
return processed_response
else:
return response.candidates[0].content.parts[0].text
def _reformat_messages(self, messages: List[Dict[str, str]]):
"""
Reformat messages for Gemini.
Args:
messages: The list of messages provided in the request.
Returns:
list: The list of messages in the required format.
"""
new_messages = []
for message in messages:
if message["role"] == "system":
content = (
"THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"]
)
content = "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"]
else:
content = message["content"]
@@ -65,33 +82,90 @@ class GeminiLLM(LLMBase):
return new_messages
def generate_response(
self, messages: List[Dict[str, str]], response_format: Optional[Dict] = None
) -> str:
def _reformat_tools(self, tools: Optional[List[Dict]]):
"""
Generates a response from Gemini based on the given conversation history.
Reformat tools for Gemini.
Args:
messages (List[Dict[str, str]]): List of message dictionaries containing 'role' and 'content'.
response_format (Optional[Dict]): Specifies the response format (e.g., JSON schema).
tools: The list of tools provided in the request.
Returns:
str: The generated response as text.
list: The list of tools in the required format.
"""
def remove_additional_properties(data):
"""Recursively removes 'additionalProperties' from nested dictionaries."""
if isinstance(data, dict):
filtered_dict = {
key: remove_additional_properties(value)
for key, value in data.items()
if not (key == "additionalProperties")
}
return filtered_dict
else:
return data
new_tools = []
if tools:
for tool in tools:
func = tool["function"].copy()
new_tools.append({"function_declarations": [remove_additional_properties(func)]})
# TODO: temporarily ignore it to pass tests, will come back to update according to standards later.
# return content_types.to_function_library(new_tools)
return new_tools
else:
return None
def generate_response(
self,
messages: List[Dict[str, str]],
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
"""
Generate a response based on the given messages using Gemini.
Args:
messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format for the response. Defaults to "text".
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
Returns:
str: The generated response.
"""
params = {
"temperature": self.config.temperature,
"max_output_tokens": self.config.max_tokens,
"top_p": self.config.top_p,
}
if response_format and response_format.get("type") == "json_object":
if response_format is not None and response_format["type"] == "json_object":
params["response_mime_type"] = "application/json"
if "schema" in response_format:
params["response_schema"] = response_format["schema"]
if tool_choice:
tool_config = content_types.to_tool_config(
{
"function_calling_config": {
"mode": tool_choice,
"allowed_function_names": (
[tool["function"]["name"] for tool in tools] if tool_choice == "any" else None
),
}
}
)
response = self.client.generate_content(
contents=self._reformat_messages(messages),
tools=self._reformat_tools(tools),
generation_config=genai.GenerationConfig(**params),
tool_config=tool_config,
)
return response.candidates[0].content.parts[0].text
return self._parse_response(response, tools)
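
Illustrative note (not part of the diff): _reformat_tools above wraps each OpenAI-style function definition in a function_declarations entry after recursively stripping additionalProperties keys. For the add_memory tool used by this commit's tests, the reformatted structure matches what the Gemini test later in this commit asserts:

reformatted_tools = [
    {
        "function_declarations": [
            {
                "name": "add_memory",
                "description": "Add a memory",
                "parameters": {
                    "type": "object",
                    "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
                    "required": ["data"],
                },
            }
        ]
    }
]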

View File

@@ -5,26 +5,14 @@ from typing import Dict, List, Optional
try:
from groq import Groq
except ImportError:
raise ImportError(
"The 'groq' library is required. Please install it using 'pip install groq'."
)
raise ImportError("The 'groq' library is required. Please install it using 'pip install groq'.")
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
class GroqLLM(LLMBase):
"""
A class for interacting with Groq's language models using the specified configuration.
"""
def __init__(self, config: Optional[BaseLlmConfig] = None):
"""
Initializes the GroqLLM instance with the given configuration.
Args:
config (Optional[BaseLlmConfig]): Configuration settings for the language model.
"""
super().__init__(config)
if not self.config.model:
@@ -33,20 +21,54 @@ class GroqLLM(LLMBase):
api_key = self.config.api_key or os.getenv("GROQ_API_KEY")
self.client = Groq(api_key=api_key)
def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.
Args:
response: The raw response from API.
tools: The list of tools provided in the request.
Returns:
str or dict: The processed response.
"""
if tools:
processed_response = {
"content": response.choices[0].message.content,
"tool_calls": [],
}
if response.choices[0].message.tool_calls:
for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append(
{
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments),
}
)
return processed_response
else:
return response.choices[0].message.content
def generate_response(
self,
messages: List[Dict[str, str]],
response_format: Optional[str] = None,
) -> str:
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
"""
Generates a response using Groq based on the provided messages.
Generate a response based on the given messages using Groq.
Args:
messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
response_format (Optional[str]): The desired format of the response. Defaults to None.
messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format of the response. Defaults to "text".
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
Returns:
str: The generated response from the model.
str: The generated response.
"""
params = {
"model": self.config.model,
@@ -57,5 +79,9 @@ class GroqLLM(LLMBase):
}
if response_format:
params["response_format"] = response_format
if tools:
params["tools"] = tools
params["tool_choice"] = tool_choice
response = self.client.chat.completions.create(**params)
return response.choices[0].message.content
return self._parse_response(response, tools)

View File

@@ -4,50 +4,70 @@ from typing import Dict, List, Optional
try:
import litellm
except ImportError:
raise ImportError(
"The 'litellm' library is required. Please install it using 'pip install litellm'."
)
raise ImportError("The 'litellm' library is required. Please install it using 'pip install litellm'.")
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
class LiteLLM(LLMBase):
"""
A class for interacting with LiteLLM's language models using the specified configuration.
"""
def __init__(self, config: Optional[BaseLlmConfig] = None):
"""
Initializes the LiteLLM instance with the given configuration.
Args:
config (Optional[BaseLlmConfig]): Configuration settings for the language model.
"""
super().__init__(config)
if not self.config.model:
self.config.model = "gpt-4o-mini"
def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.
Args:
response: The raw response from API.
tools: The list of tools provided in the request.
Returns:
str or dict: The processed response.
"""
if tools:
processed_response = {
"content": response.choices[0].message.content,
"tool_calls": [],
}
if response.choices[0].message.tool_calls:
for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append(
{
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments),
}
)
return processed_response
else:
return response.choices[0].message.content
def generate_response(
self,
messages: List[Dict[str, str]],
response_format: Optional[str] = None,
) -> str:
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
"""
Generates a response using LiteLLM based on the provided messages.
Generate a response based on the given messages using Litellm.
Args:
messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
response_format (Optional[str]): The desired format of the response. Defaults to None.
messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format of the response. Defaults to "text".
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
Returns:
str: The generated response from the model.
str: The generated response.
"""
if not litellm.supports_function_calling(self.config.model):
raise ValueError(
f"Model '{self.config.model}' in LiteLLM does not support function calling."
)
raise ValueError(f"Model '{self.config.model}' in litellm does not support function calling.")
params = {
"model": self.config.model,
@@ -58,6 +78,9 @@ class LiteLLM(LLMBase):
}
if response_format:
params["response_format"] = response_format
if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools
params["tool_choice"] = tool_choice
response = litellm.completion(**params)
return response.choices[0].message.content
return self._parse_response(response, tools)

View File

@@ -3,56 +3,77 @@ from typing import Dict, List, Optional
try:
from ollama import Client
except ImportError:
raise ImportError(
"The 'ollama' library is required. Please install it using 'pip install ollama'."
)
raise ImportError("The 'ollama' library is required. Please install it using 'pip install ollama'.")
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
class OllamaLLM(LLMBase):
"""
A class for interacting with Ollama's language models using the specified configuration.
"""
def __init__(self, config: Optional[BaseLlmConfig] = None):
"""
Initializes the OllamaLLM instance with the given configuration.
Args:
config (Optional[BaseLlmConfig]): Configuration settings for the language model.
"""
super().__init__(config)
if not self.config.model:
self.config.model = "llama3.1:70b"
self.client = Client(host=self.config.ollama_base_url)
self._ensure_model_exists()
def _ensure_model_exists(self):
"""
Ensures the specified model exists locally. If not, pulls it from Ollama.
Ensure the specified model exists locally. If not, pull it from Ollama.
"""
local_models = self.client.list()["models"]
if not any(model.get("name") == self.config.model for model in local_models):
self.client.pull(self.config.model)
def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.
Args:
response: The raw response from API.
tools: The list of tools provided in the request.
Returns:
str or dict: The processed response.
"""
if tools:
processed_response = {
"content": response["message"]["content"],
"tool_calls": [],
}
if response["message"].get("tool_calls"):
for tool_call in response["message"]["tool_calls"]:
processed_response["tool_calls"].append(
{
"name": tool_call["function"]["name"],
"arguments": tool_call["function"]["arguments"],
}
)
return processed_response
else:
return response["message"]["content"]
def generate_response(
self,
messages: List[Dict[str, str]],
response_format: Optional[str] = None,
) -> str:
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
"""
Generates a response using Ollama based on the provided messages.
Generate a response based on the given messages using OpenAI.
Args:
messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
response_format (Optional[str]): The desired format of the response. Defaults to None.
messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format of the response. Defaults to "text".
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
Returns:
str: The generated response from the model.
str: The generated response.
"""
params = {
"model": self.config.model,
@@ -66,5 +87,8 @@ class OllamaLLM(LLMBase):
if response_format:
params["format"] = "json"
if tools:
params["tools"] = tools
response = self.client.chat(**params)
return response["message"]["content"]
return self._parse_response(response, tools)

View File

@@ -9,17 +9,7 @@ from mem0.llms.base import LLMBase
class OpenAILLM(LLMBase):
"""
A class to interact with OpenAI or OpenRouter APIs for generating responses using LLMs.
"""
def __init__(self, config: Optional[BaseLlmConfig] = None):
"""
Initializes the OpenAILLM instance.
Args:
config (Optional[BaseLlmConfig]): Configuration for the LLM, including model, API key, and base URLs.
"""
super().__init__(config)
if not self.config.model:
@@ -34,27 +24,57 @@ class OpenAILLM(LLMBase):
)
else:
api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
base_url = (
self.config.openai_base_url
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE") or "https://api.openai.com/v1"
self.client = OpenAI(api_key=api_key, base_url=base_url)
def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.
Args:
response: The raw response from API.
tools: The list of tools provided in the request.
Returns:
str or dict: The processed response.
"""
if tools:
processed_response = {
"content": response.choices[0].message.content,
"tool_calls": [],
}
if response.choices[0].message.tool_calls:
for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append(
{
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments),
}
)
return processed_response
else:
return response.choices[0].message.content
def generate_response(
self,
messages: List[Dict[str, str]],
response_format: Optional[str] = None,
) -> str:
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
"""
Generates a response based on the provided messages using OpenAI or OpenRouter.
Generate a response based on the given messages using OpenAI.
Args:
messages (List[Dict[str, str]]): A list of message dictionaries containing 'role' and 'content'.
response_format (Optional[str]): The format of the response. Defaults to None.
messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format of the response. Defaults to "text".
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
Returns:
str: The generated response from the model.
str: The generated response.
"""
params = {
"model": self.config.model,
@@ -82,6 +102,9 @@ class OpenAILLM(LLMBase):
if response_format:
params["response_format"] = response_format
if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools
params["tool_choice"] = tool_choice
response = self.client.chat.completions.create(**params)
return response.choices[0].message.content
return self._parse_response(response, tools)
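
Illustrative usage sketch (not part of the diff): with the tools and tool_choice parameters restored, OpenAILLM.generate_response can be called as below. The import path and model value are assumptions, the add_memory tool definition is the one used by the tests in this commit, and a valid OPENAI_API_KEY must be available (or passed via config.api_key).

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.openai import OpenAILLM  # assumed import path

# Example configuration; values mirror those used by the tests in this commit.
config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0)
llm = OpenAILLM(config)

tools = [
    {
        "type": "function",
        "function": {
            "name": "add_memory",
            "description": "Add a memory",
            "parameters": {
                "type": "object",
                "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
                "required": ["data"],
            },
        },
    }
]

# With tools supplied, the return value is the dict built by _parse_response above
# ({"content": ..., "tool_calls": [...]}); without tools it is a plain string.
response = llm.generate_response(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Add a new memory: Today is a sunny day."},
    ],
    tools=tools,
    tool_choice="auto",
)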

View File

@@ -9,78 +9,31 @@ from mem0.llms.base import LLMBase
class OpenAIStructuredLLM(LLMBase):
"""
A class for interacting with OpenAI's structured language models using the specified configuration.
"""
def __init__(self, config: Optional[BaseLlmConfig] = None):
"""
Initializes the OpenAIStructuredLLM instance with the given configuration.
Args:
config (Optional[BaseLlmConfig]): Configuration settings for the language model.
"""
super().__init__(config)
if not self.config.model:
self.config.model = "gpt-4o-2024-08-06"
api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
base_url = (
self.config.openai_base_url
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE") or "https://api.openai.com/v1"
self.client = OpenAI(api_key=api_key, base_url=base_url)
def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.
Args:
response: The raw response from API.
tools (list, optional): List of tools that the model can call.
Returns:
str or dict: The processed response.
"""
if tools:
processed_response = {
"content": response.choices[0].message.content,
"tool_calls": [],
}
if response.choices[0].message.tool_calls:
for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append(
{
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments),
}
)
return processed_response
else:
return response.choices[0].message.content
def generate_response(
self,
messages: List[Dict[str, str]],
response_format: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
) -> str:
"""
Generates a response using OpenAI based on the provided messages.
Generate a response based on the given messages using OpenAI.
Args:
messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
response_format (Optional[str]): The desired format of the response. Defaults to None.
tools (Optional[List[Dict]]): A list of dictionaries, each containing a 'name' and 'arguments' key.
tool_choice (str): The choice of tool to use. Defaults to "auto".
Returns:
str: The generated response from the model.
str: The generated response.
"""
params = {
"model": self.config.model,
@@ -95,4 +48,4 @@ class OpenAIStructuredLLM(LLMBase):
params["tool_choice"] = tool_choice
response = self.client.beta.chat.completions.parse(**params)
return self._parse_response(response, tools)
return response.choices[0].message.content

View File

@@ -5,26 +5,14 @@ from typing import Dict, List, Optional
try:
from together import Together
except ImportError:
raise ImportError(
"The 'together' library is required. Please install it using 'pip install together'."
)
raise ImportError("The 'together' library is required. Please install it using 'pip install together'.")
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
class TogetherLLM(LLMBase):
"""
A class for interacting with the TogetherAI language model using the specified configuration.
"""
def __init__(self, config: Optional[BaseLlmConfig] = None):
"""
Initializes the TogetherLLM instance with the given configuration.
Args:
config (Optional[BaseLlmConfig]): Configuration settings for the language model.
"""
super().__init__(config)
if not self.config.model:
@@ -33,20 +21,54 @@ class TogetherLLM(LLMBase):
api_key = self.config.api_key or os.getenv("TOGETHER_API_KEY")
self.client = Together(api_key=api_key)
def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.
Args:
response: The raw response from API.
tools: The list of tools provided in the request.
Returns:
str or dict: The processed response.
"""
if tools:
processed_response = {
"content": response.choices[0].message.content,
"tool_calls": [],
}
if response.choices[0].message.tool_calls:
for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append(
{
"name": tool_call.function.name,
"arguments": json.loads(tool_call.function.arguments),
}
)
return processed_response
else:
return response.choices[0].message.content
def generate_response(
self,
messages: List[Dict[str, str]],
response_format: Optional[str] = None,
) -> str:
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
"""
Generates a response using TogetherAI based on the provided messages.
Generate a response based on the given messages using TogetherAI.
Args:
messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
response_format (Optional[str]): The desired format of the response. Defaults to None.
messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format of the response. Defaults to "text".
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".
Returns:
str: The generated response from the model.
str: The generated response.
"""
params = {
"model": self.config.model,
@@ -57,6 +79,9 @@ class TogetherLLM(LLMBase):
}
if response_format:
params["response_format"] = response_format
if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools
params["tool_choice"] = tool_choice
response = self.client.chat.completions.create(**params)
return response.choices[0].message.content
return self._parse_response(response, tools)

View File

@@ -15,11 +15,7 @@ class XAILLM(LLMBase):
self.config.model = "grok-2-latest"
api_key = self.config.api_key or os.getenv("XAI_API_KEY")
base_url = (
self.config.xai_base_url
or os.getenv("XAI_API_BASE")
or "https://api.x.ai/v1"
)
base_url = self.config.xai_base_url or os.getenv("XAI_API_BASE") or "https://api.x.ai/v1"
self.client = OpenAI(api_key=api_key, base_url=base_url)
def generate_response(self, messages: List[Dict[str, str]], response_format=None):

View File

@@ -20,10 +20,8 @@ def mock_openai_client():
yield mock_client
def test_generate_response(mock_openai_client):
config = BaseLlmConfig(
model=MODEL, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P
)
def test_generate_response_without_tools(mock_openai_client):
config = BaseLlmConfig(model=MODEL, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P)
llm = AzureOpenAILLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
@@ -31,21 +29,67 @@ def test_generate_response(mock_openai_client):
]
mock_response = Mock()
mock_response.choices = [
Mock(message=Mock(content="I'm doing well, thank you for asking!"))
]
mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
mock_openai_client.chat.completions.create.return_value = mock_response
response = llm.generate_response(messages)
mock_openai_client.chat.completions.create.assert_called_once_with(
model=MODEL, messages=messages, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P
)
assert response == "I'm doing well, thank you for asking!"
def test_generate_response_with_tools(mock_openai_client):
config = BaseLlmConfig(model=MODEL, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P)
llm = AzureOpenAILLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add a new memory: Today is a sunny day."},
]
tools = [
{
"type": "function",
"function": {
"name": "add_memory",
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
},
}
]
mock_response = Mock()
mock_message = Mock()
mock_message.content = "I've added the memory for you."
mock_tool_call = Mock()
mock_tool_call.function.name = "add_memory"
mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
mock_message.tool_calls = [mock_tool_call]
mock_response.choices = [Mock(message=mock_message)]
mock_openai_client.chat.completions.create.return_value = mock_response
response = llm.generate_response(messages, tools=tools)
mock_openai_client.chat.completions.create.assert_called_once_with(
model=MODEL,
messages=messages,
temperature=TEMPERATURE,
max_tokens=MAX_TOKENS,
top_p=TOP_P,
tools=tools,
tool_choice="auto",
)
assert response == "I'm doing well, thank you for asking!"
assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
@pytest.mark.parametrize(
@@ -84,6 +128,4 @@ def test_generate_with_http_proxies(default_headers):
api_version=None,
default_headers=default_headers,
)
mock_http_client.assert_called_once_with(
proxies="http://testproxy.mem0.net:8000"
)
mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000")

View File

@@ -16,47 +16,33 @@ def mock_deepseek_client():
def test_deepseek_llm_base_url():
# case1: default config with deepseek official base url
config = BaseLlmConfig(
model="deepseek-chat",
temperature=0.7,
max_tokens=100,
top_p=1.0,
api_key="api_key",
)
config = BaseLlmConfig(model="deepseek-chat", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key")
llm = DeepSeekLLM(config)
assert str(llm.client.base_url) == "https://api.deepseek.com"
# case2: with env variable DEEPSEEK_API_BASE
provider_base_url = "https://api.provider.com/v1/"
os.environ["DEEPSEEK_API_BASE"] = provider_base_url
config = BaseLlmConfig(
model="deepseek-chat",
temperature=0.7,
max_tokens=100,
top_p=1.0,
api_key="api_key",
)
config = BaseLlmConfig(model="deepseek-chat", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key")
llm = DeepSeekLLM(config)
assert str(llm.client.base_url) == provider_base_url
# case3: with config.deepseek_base_url
config_base_url = "https://api.config.com/v1/"
config = BaseLlmConfig(
model="deepseek-chat",
temperature=0.7,
max_tokens=100,
top_p=1.0,
api_key="api_key",
deepseek_base_url=config_base_url,
model="deepseek-chat",
temperature=0.7,
max_tokens=100,
top_p=1.0,
api_key="api_key",
deepseek_base_url=config_base_url
)
llm = DeepSeekLLM(config)
assert str(llm.client.base_url) == config_base_url
def test_generate_response(mock_deepseek_client):
config = BaseLlmConfig(
model="deepseek-chat", temperature=0.7, max_tokens=100, top_p=1.0
)
def test_generate_response_without_tools(mock_deepseek_client):
config = BaseLlmConfig(model="deepseek-chat", temperature=0.7, max_tokens=100, top_p=1.0)
llm = DeepSeekLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
@@ -64,18 +50,64 @@ def test_generate_response(mock_deepseek_client):
]
mock_response = Mock()
mock_response.choices = [
Mock(message=Mock(content="I'm doing well, thank you for asking!"))
]
mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
mock_deepseek_client.chat.completions.create.return_value = mock_response
response = llm.generate_response(messages)
mock_deepseek_client.chat.completions.create.assert_called_once_with(
model="deepseek-chat",
messages=messages,
temperature=0.7,
max_tokens=100,
top_p=1.0,
model="deepseek-chat", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0
)
assert response == "I'm doing well, thank you for asking!"
def test_generate_response_with_tools(mock_deepseek_client):
config = BaseLlmConfig(model="deepseek-chat", temperature=0.7, max_tokens=100, top_p=1.0)
llm = DeepSeekLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add a new memory: Today is a sunny day."},
]
tools = [
{
"type": "function",
"function": {
"name": "add_memory",
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
},
}
]
mock_response = Mock()
mock_message = Mock()
mock_message.content = "I've added the memory for you."
mock_tool_call = Mock()
mock_tool_call.function.name = "add_memory"
mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
mock_message.tool_calls = [mock_tool_call]
mock_response.choices = [Mock(message=mock_message)]
mock_deepseek_client.chat.completions.create.return_value = mock_response
response = llm.generate_response(messages, tools=tools)
mock_deepseek_client.chat.completions.create.assert_called_once_with(
model="deepseek-chat",
messages=messages,
temperature=0.7,
max_tokens=100,
top_p=1.0,
tools=tools,
tool_choice="auto"
)
assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}

View File

@@ -17,9 +17,7 @@ def mock_gemini_client():
def test_generate_response_without_tools(mock_gemini_client: Mock):
config = BaseLlmConfig(
model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0
)
config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0)
llm = GeminiLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
@@ -36,14 +34,86 @@ def test_generate_response_without_tools(mock_gemini_client: Mock):
mock_gemini_client.generate_content.assert_called_once_with(
contents=[
{
"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.",
"role": "user",
},
{"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"},
{"parts": "Hello, how are you?", "role": "user"},
],
generation_config=GenerationConfig(
temperature=0.7, max_output_tokens=100, top_p=1.0
generation_config=GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0),
tools=None,
tool_config=content_types.to_tool_config(
{"function_calling_config": {"mode": "auto", "allowed_function_names": None}}
),
)
assert response == "I'm doing well, thank you for asking!"
def test_generate_response_with_tools(mock_gemini_client: Mock):
config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0)
llm = GeminiLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add a new memory: Today is a sunny day."},
]
tools = [
{
"type": "function",
"function": {
"name": "add_memory",
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
},
}
]
mock_tool_call = Mock()
mock_tool_call.name = "add_memory"
mock_tool_call.args = {"data": "Today is a sunny day."}
mock_part = Mock()
mock_part.function_call = mock_tool_call
mock_part.text = "I've added the memory for you."
mock_content = Mock()
mock_content.parts = [mock_part]
mock_message = Mock()
mock_message.content = mock_content
mock_response = Mock(candidates=[mock_message])
mock_gemini_client.generate_content.return_value = mock_response
response = llm.generate_response(messages, tools=tools)
mock_gemini_client.generate_content.assert_called_once_with(
contents=[
{"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"},
{"parts": "Add a new memory: Today is a sunny day.", "role": "user"},
],
generation_config=GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0),
tools=[
{
"function_declarations": [
{
"name": "add_memory",
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
}
]
}
],
tool_config=content_types.to_tool_config(
{"function_calling_config": {"mode": "auto", "allowed_function_names": None}}
),
)
assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}

View File

@@ -14,10 +14,8 @@ def mock_groq_client():
yield mock_client
def test_generate_response(mock_groq_client):
config = BaseLlmConfig(
model="llama3-70b-8192", temperature=0.7, max_tokens=100, top_p=1.0
)
def test_generate_response_without_tools(mock_groq_client):
config = BaseLlmConfig(model="llama3-70b-8192", temperature=0.7, max_tokens=100, top_p=1.0)
llm = GroqLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
@@ -25,18 +23,64 @@ def test_generate_response(mock_groq_client):
]
mock_response = Mock()
mock_response.choices = [
Mock(message=Mock(content="I'm doing well, thank you for asking!"))
]
mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
mock_groq_client.chat.completions.create.return_value = mock_response
response = llm.generate_response(messages)
mock_groq_client.chat.completions.create.assert_called_once_with(
model="llama3-70b-8192", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0
)
assert response == "I'm doing well, thank you for asking!"
def test_generate_response_with_tools(mock_groq_client):
config = BaseLlmConfig(model="llama3-70b-8192", temperature=0.7, max_tokens=100, top_p=1.0)
llm = GroqLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add a new memory: Today is a sunny day."},
]
tools = [
{
"type": "function",
"function": {
"name": "add_memory",
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
},
}
]
mock_response = Mock()
mock_message = Mock()
mock_message.content = "I've added the memory for you."
mock_tool_call = Mock()
mock_tool_call.function.name = "add_memory"
mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
mock_message.tool_calls = [mock_tool_call]
mock_response.choices = [Mock(message=mock_message)]
mock_groq_client.chat.completions.create.return_value = mock_response
response = llm.generate_response(messages, tools=tools)
mock_groq_client.chat.completions.create.assert_called_once_with(
model="llama3-70b-8192",
messages=messages,
temperature=0.7,
max_tokens=100,
top_p=1.0,
tools=tools,
tool_choice="auto",
)
assert response == "I'm doing well, thank you for asking!"
assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}

View File

@@ -13,22 +13,17 @@ def mock_litellm():
def test_generate_response_with_unsupported_model(mock_litellm):
config = BaseLlmConfig(
model="unsupported-model", temperature=0.7, max_tokens=100, top_p=1
)
config = BaseLlmConfig(model="unsupported-model", temperature=0.7, max_tokens=100, top_p=1)
llm = litellm.LiteLLM(config)
messages = [{"role": "user", "content": "Hello"}]
mock_litellm.supports_function_calling.return_value = False
with pytest.raises(
ValueError,
match="Model 'unsupported-model' in LiteLLM does not support function calling.",
):
with pytest.raises(ValueError, match="Model 'unsupported-model' in litellm does not support function calling."):
llm.generate_response(messages)
def test_generate_response(mock_litellm):
def test_generate_response_without_tools(mock_litellm):
config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1)
llm = litellm.LiteLLM(config)
messages = [
@@ -37,9 +32,7 @@ def test_generate_response(mock_litellm):
]
mock_response = Mock()
mock_response.choices = [
Mock(message=Mock(content="I'm doing well, thank you for asking!"))
]
mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
mock_litellm.completion.return_value = mock_response
mock_litellm.supports_function_calling.return_value = True
@@ -49,3 +42,50 @@ def test_generate_response(mock_litellm):
model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0
)
assert response == "I'm doing well, thank you for asking!"
def test_generate_response_with_tools(mock_litellm):
config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1)
llm = litellm.LiteLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add a new memory: Today is a sunny day."},
]
tools = [
{
"type": "function",
"function": {
"name": "add_memory",
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
},
}
]
mock_response = Mock()
mock_message = Mock()
mock_message.content = "I've added the memory for you."
mock_tool_call = Mock()
mock_tool_call.function.name = "add_memory"
mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
mock_message.tool_calls = [mock_tool_call]
mock_response.choices = [Mock(message=mock_message)]
mock_litellm.completion.return_value = mock_response
mock_litellm.supports_function_calling.return_value = True
response = llm.generate_response(messages, tools=tools)
mock_litellm.completion.assert_called_once_with(
model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1, tools=tools, tool_choice="auto"
)
assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}

View File

@@ -16,9 +16,7 @@ def mock_openai_client():
def test_openai_llm_base_url():
# case 1: default config, using the official OpenAI base URL
config = BaseLlmConfig(
model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key"
)
config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key")
llm = OpenAILLM(config)
# Note: openai client will parse the raw base_url into a URL object, which will have a trailing slash
assert str(llm.client.base_url) == "https://api.openai.com/v1/"
@@ -26,9 +24,7 @@ def test_openai_llm_base_url():
# case2: with env variable OPENAI_API_BASE
provider_base_url = "https://api.provider.com/v1"
os.environ["OPENAI_API_BASE"] = provider_base_url
config = BaseLlmConfig(
model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key"
)
config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key")
llm = OpenAILLM(config)
# Note: openai client will parse the raw base_url into a URL object, which will have a trailing slash
assert str(llm.client.base_url) == provider_base_url + "/"
@@ -36,19 +32,14 @@ def test_openai_llm_base_url():
# case3: with config.openai_base_url
config_base_url = "https://api.config.com/v1"
config = BaseLlmConfig(
model="gpt-4o",
temperature=0.7,
max_tokens=100,
top_p=1.0,
api_key="api_key",
openai_base_url=config_base_url,
model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key", openai_base_url=config_base_url
)
llm = OpenAILLM(config)
# Note: openai client will parse the raw base_url into a URL object, which will have a trailing slash
assert str(llm.client.base_url) == config_base_url + "/"
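Taken together, the three cases above imply a precedence for the client base URL: an explicit openai_base_url in the config wins, then the OPENAI_API_BASE environment variable, then the OpenAI default. A minimal sketch of that resolution order, offered as an illustration of the assumption rather than the wrapper's actual code:

import os
from typing import Optional

def resolve_openai_base_url(openai_base_url: Optional[str] = None) -> str:
    # Hypothetical helper mirroring the order the assertions imply:
    # explicit config value, then environment variable, then the public default.
    return openai_base_url or os.environ.get("OPENAI_API_BASE") or "https://api.openai.com/v1"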
def test_generate_response(mock_openai_client):
def test_generate_response_without_tools(mock_openai_client):
config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0)
llm = OpenAILLM(config)
messages = [
@@ -57,9 +48,7 @@ def test_generate_response(mock_openai_client):
]
mock_response = Mock()
mock_response.choices = [
Mock(message=Mock(content="I'm doing well, thank you for asking!"))
]
mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
mock_openai_client.chat.completions.create.return_value = mock_response
response = llm.generate_response(messages)
@@ -68,3 +57,49 @@ def test_generate_response(mock_openai_client):
model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0
)
assert response == "I'm doing well, thank you for asking!"
def test_generate_response_with_tools(mock_openai_client):
config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0)
llm = OpenAILLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add a new memory: Today is a sunny day."},
]
tools = [
{
"type": "function",
"function": {
"name": "add_memory",
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
},
}
]
mock_response = Mock()
mock_message = Mock()
mock_message.content = "I've added the memory for you."
mock_tool_call = Mock()
mock_tool_call.function.name = "add_memory"
mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
mock_message.tool_calls = [mock_tool_call]
mock_response.choices = [Mock(message=mock_message)]
mock_openai_client.chat.completions.create.return_value = mock_response
response = llm.generate_response(messages, tools=tools)
mock_openai_client.chat.completions.create.assert_called_once_with(
model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0, tools=tools, tool_choice="auto"
)
assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}

View File

@@ -14,13 +14,8 @@ def mock_together_client():
yield mock_client
def test_generate_response(mock_together_client):
config = BaseLlmConfig(
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
temperature=0.7,
max_tokens=100,
top_p=1.0,
)
def test_generate_response_without_tools(mock_together_client):
config = BaseLlmConfig(model="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0.7, max_tokens=100, top_p=1.0)
llm = TogetherLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
@@ -28,18 +23,64 @@ def test_generate_response(mock_together_client):
]
mock_response = Mock()
mock_response.choices = [
Mock(message=Mock(content="I'm doing well, thank you for asking!"))
]
mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
mock_together_client.chat.completions.create.return_value = mock_response
response = llm.generate_response(messages)
mock_together_client.chat.completions.create.assert_called_once_with(
model="mistralai/Mixtral-8x7B-Instruct-v0.1", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0
)
assert response == "I'm doing well, thank you for asking!"
def test_generate_response_with_tools(mock_together_client):
config = BaseLlmConfig(model="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0.7, max_tokens=100, top_p=1.0)
llm = TogetherLLM(config)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Add a new memory: Today is a sunny day."},
]
tools = [
{
"type": "function",
"function": {
"name": "add_memory",
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
},
},
}
]
mock_response = Mock()
mock_message = Mock()
mock_message.content = "I've added the memory for you."
mock_tool_call = Mock()
mock_tool_call.function.name = "add_memory"
mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
mock_message.tool_calls = [mock_tool_call]
mock_response.choices = [Mock(message=mock_message)]
mock_together_client.chat.completions.create.return_value = mock_response
response = llm.generate_response(messages, tools=tools)
mock_together_client.chat.completions.create.assert_called_once_with(
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
messages=messages,
temperature=0.7,
max_tokens=100,
top_p=1.0,
tools=tools,
tool_choice="auto",
)
assert response == "I'm doing well, thank you for asking!"
assert response["content"] == "I've added the memory for you."
assert len(response["tool_calls"]) == 1
assert response["tool_calls"][0]["name"] == "add_memory"
assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}