Reverting the tools commit (#2404)
@@ -4,26 +4,14 @@ from typing import Dict, List, Optional
try:
    import anthropic
except ImportError:
-    raise ImportError(
-        "The 'anthropic' library is required. Please install it using 'pip install anthropic'."
-    )
+    raise ImportError("The 'anthropic' library is required. Please install it using 'pip install anthropic'.")

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase


class AnthropicLLM(LLMBase):
-    """
-    A class for interacting with Anthropic's Claude models using the specified configuration.
-    """

    def __init__(self, config: Optional[BaseLlmConfig] = None):
-        """
-        Initializes the AnthropicLLM instance with the given configuration.
-
-        Args:
-            config (Optional[BaseLlmConfig]): Configuration settings for the language model.
-        """
        super().__init__(config)

        if not self.config.model:

@@ -35,17 +23,23 @@ class AnthropicLLM(LLMBase):
    def generate_response(
        self,
        messages: List[Dict[str, str]],
-    ) -> str:
+        response_format=None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
+    ):
        """
-        Generates a response using Anthropic's Claude model based on the provided messages.
+        Generate a response based on the given messages using Anthropic.

        Args:
-            messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
+            messages (list): List of message dicts containing 'role' and 'content'.
+            response_format (str or object, optional): Format of the response. Defaults to "text".
+            tools (list, optional): List of tools that the model can call. Defaults to None.
+            tool_choice (str, optional): Tool choice method. Defaults to "auto".

        Returns:
-            str: The generated response from the model.
+            str: The generated response.
        """
-        # Extract system message separately
+        # Separate system message from other messages
        system_message = ""
        filtered_messages = []
        for message in messages:

@@ -62,6 +56,9 @@ class AnthropicLLM(LLMBase):
            "max_tokens": self.config.max_tokens,
            "top_p": self.config.top_p,
        }
+        if tools:  # TODO: Remove tools if no issues found with new memory addition logic
+            params["tools"] = tools
+            params["tool_choice"] = tool_choice

        response = self.client.messages.create(**params)
        return response.content[0].text
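For context (this sketch is not part of the diff): the restored `tools` parameter in these `generate_response` signatures takes OpenAI-style function definitions. A minimal, hypothetical example of how a caller might exercise the restored path; the `add_memory` tool name and its schema are illustrative only, not something defined in this commit.

# Sketch only; "add_memory" and its schema are hypothetical.
add_memory_tool = {
    "type": "function",
    "function": {
        "name": "add_memory",
        "description": "Store a fact extracted from the conversation.",
        "parameters": {
            "type": "object",
            "properties": {"data": {"type": "string", "description": "Fact to store"}},
            "required": ["data"],
        },
    },
}

llm = AnthropicLLM()  # assumes an Anthropic API key is configured
answer = llm.generate_response(
    messages=[{"role": "user", "content": "Remember that I prefer tea over coffee."}],
    tools=[add_memory_tool],
    tool_choice="auto",
)

Note that, unlike the OpenAI-compatible classes below, this Anthropic implementation still returns `response.content[0].text` even when tools are passed.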
@@ -4,26 +4,14 @@ from typing import Any, Dict, List, Optional
try:
    import boto3
except ImportError:
-    raise ImportError(
-        "The 'boto3' library is required. Please install it using 'pip install boto3'."
-    )
+    raise ImportError("The 'boto3' library is required. Please install it using 'pip install boto3'.")

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase


class AWSBedrockLLM(LLMBase):
-    """
-    A wrapper for AWS Bedrock's language models, integrating them with the LLMBase class.
-    """

    def __init__(self, config: Optional[BaseLlmConfig] = None):
-        """
-        Initializes the AWS Bedrock LLM with the provided configuration.
-
-        Args:
-            config (Optional[BaseLlmConfig]): Configuration object for the model.
-        """
        super().__init__(config)

        if not self.config.model:

@@ -37,29 +25,49 @@ class AWSBedrockLLM(LLMBase):

    def _format_messages(self, messages: List[Dict[str, str]]) -> str:
        """
-        Formats a list of messages into a structured prompt for the model.
+        Formats a list of messages into the required prompt structure for the model.

        Args:
-            messages (List[Dict[str, str]]): A list of dictionaries containing 'role' and 'content'.
+            messages (List[Dict[str, str]]): A list of dictionaries where each dictionary represents a message.
+                Each dictionary contains 'role' and 'content' keys.

        Returns:
            str: A formatted string combining all messages, structured with roles capitalized and separated by newlines.
        """
-        formatted_messages = [
-            f"\n\n{msg['role'].capitalize()}: {msg['content']}" for msg in messages
-        ]
+        formatted_messages = []
+        for message in messages:
+            role = message["role"].capitalize()
+            content = message["content"]
+            formatted_messages.append(f"\n\n{role}: {content}")

        return "".join(formatted_messages) + "\n\nAssistant:"

-    def _parse_response(self, response) -> str:
+    def _parse_response(self, response, tools) -> str:
        """
-        Extracts the generated response from the API response.
+        Process the response based on whether tools are used or not.

        Args:
-            response: The raw response from the AWS Bedrock API.
+            response: The raw response from API.
+            tools: The list of tools provided in the request.

        Returns:
-            str: The generated response text.
+            str or dict: The processed response.
        """
+        if tools:
+            processed_response = {"tool_calls": []}

+            if response["output"]["message"]["content"]:
+                for item in response["output"]["message"]["content"]:
+                    if "toolUse" in item:
+                        processed_response["tool_calls"].append(
+                            {
+                                "name": item["toolUse"]["name"],
+                                "arguments": item["toolUse"]["input"],
+                            }
+                        )

+            return processed_response

        response_body = json.loads(response["body"].read().decode())
        return response_body.get("completion", "")

@@ -68,21 +76,22 @@ class AWSBedrockLLM(LLMBase):
        provider: str,
        model: str,
        prompt: str,
-        model_kwargs: Optional[Dict[str, Any]] = None,
+        model_kwargs: Optional[Dict[str, Any]] = {},
    ) -> Dict[str, Any]:
        """
-        Prepares the input dictionary for the specified provider's model.
+        Prepares the input dictionary for the specified provider's model by mapping and renaming
+        keys in the input based on the provider's requirements.

        Args:
-            provider (str): The model provider (e.g., "meta", "ai21", "mistral", "cohere", "amazon").
-            model (str): The model identifier.
-            prompt (str): The input prompt.
-            model_kwargs (Optional[Dict[str, Any]]): Additional model parameters.
+            provider (str): The name of the service provider (e.g., "meta", "ai21", "mistral", "cohere", "amazon").
+            model (str): The name or identifier of the model being used.
+            prompt (str): The text prompt to be processed by the model.
+            model_kwargs (Dict[str, Any]): Additional keyword arguments specific to the model's requirements.

        Returns:
-            Dict[str, Any]: The prepared input dictionary.
+            Dict[str, Any]: The prepared input dictionary with the correct keys and values for the specified provider.
        """
-        model_kwargs = model_kwargs or {}
        input_body = {"prompt": prompt, **model_kwargs}

        provider_mappings = {

@@ -110,35 +119,102 @@ class AWSBedrockLLM(LLMBase):
            },
        }
        input_body["textGenerationConfig"] = {
-            k: v
-            for k, v in input_body["textGenerationConfig"].items()
-            if v is not None
+            k: v for k, v in input_body["textGenerationConfig"].items() if v is not None
        }

        return input_body

-    def generate_response(self, messages: List[Dict[str, str]]) -> str:
-        """
-        Generates a response using AWS Bedrock based on the provided messages.
-
-        Args:
-            messages (List[Dict[str, str]]): List of message dictionaries containing 'role' and 'content'.
-
-        Returns:
-            str: The generated response text.
-        """
-        prompt = self._format_messages(messages)
-        provider = self.config.model.split(".")[0]
-        input_body = self._prepare_input(
-            provider, self.config.model, prompt, self.model_kwargs
-        )
-        body = json.dumps(input_body)
-
-        response = self.client.invoke_model(
-            body=body,
-            modelId=self.config.model,
-            accept="application/json",
-            contentType="application/json",
-        )
-
-        return self._parse_response(response)
+    def _convert_tool_format(self, original_tools):
+        """
+        Converts a list of tools from their original format to a new standardized format.
+
+        Args:
+            original_tools (list): A list of dictionaries representing the original tools, each containing a 'type' key and corresponding details.
+
+        Returns:
+            list: A list of dictionaries representing the tools in the new standardized format.
+        """
+        new_tools = []
+
+        for tool in original_tools:
+            if tool["type"] == "function":
+                function = tool["function"]
+                new_tool = {
+                    "toolSpec": {
+                        "name": function["name"],
+                        "description": function["description"],
+                        "inputSchema": {
+                            "json": {
+                                "type": "object",
+                                "properties": {},
+                                "required": function["parameters"].get("required", []),
+                            }
+                        },
+                    }
+                }
+
+                for prop, details in function["parameters"].get("properties", {}).items():
+                    new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = {
+                        "type": details.get("type", "string"),
+                        "description": details.get("description", ""),
+                    }
+
+                new_tools.append(new_tool)
+
+        return new_tools
+
+    def generate_response(
+        self,
+        messages: List[Dict[str, str]],
+        response_format=None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
+    ):
+        """
+        Generate a response based on the given messages using AWS Bedrock.
+
+        Args:
+            messages (list): List of message dicts containing 'role' and 'content'.
+            tools (list, optional): List of tools that the model can call. Defaults to None.
+            tool_choice (str, optional): Tool choice method. Defaults to "auto".
+
+        Returns:
+            str: The generated response.
+        """
+
+        if tools:
+            # Use converse method when tools are provided
+            messages = [
+                {
+                    "role": "user",
+                    "content": [{"text": message["content"]} for message in messages],
+                }
+            ]
+            inference_config = {
+                "temperature": self.model_kwargs["temperature"],
+                "maxTokens": self.model_kwargs["max_tokens_to_sample"],
+                "topP": self.model_kwargs["top_p"],
+            }
+            tools_config = {"tools": self._convert_tool_format(tools)}
+
+            response = self.client.converse(
+                modelId=self.config.model,
+                messages=messages,
+                inferenceConfig=inference_config,
+                toolConfig=tools_config,
+            )
+        else:
+            # Use invoke_model method when no tools are provided
+            prompt = self._format_messages(messages)
+            provider = self.model.split(".")[0]
+            input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs)
+            body = json.dumps(input_body)
+
+            response = self.client.invoke_model(
+                body=body,
+                modelId=self.model,
+                accept="application/json",
+                contentType="application/json",
+            )
+
+        return self._parse_response(response, tools)
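For context (not part of the diff): `_convert_tool_format` rewraps an OpenAI-style function tool into the `toolSpec` shape that `client.converse` expects. A rough sketch of that transformation, reusing the hypothetical `add_memory_tool` defined earlier; the exact output is inferred from the added code above.

bedrock_llm = AWSBedrockLLM()  # assumes AWS credentials and region are configured
converted = bedrock_llm._convert_tool_format([add_memory_tool])
# converted is roughly:
# [{"toolSpec": {"name": "add_memory",
#                "description": "Store a fact extracted from the conversation.",
#                "inputSchema": {"json": {"type": "object",
#                                         "properties": {"data": {"type": "string",
#                                                                 "description": "Fact to store"}},
#                                         "required": ["data"]}}}}]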
@@ -9,35 +9,17 @@ from mem0.llms.base import LLMBase


class AzureOpenAILLM(LLMBase):
-    """
-    A class for interacting with Azure OpenAI models using the specified configuration.
-    """

    def __init__(self, config: Optional[BaseLlmConfig] = None):
-        """
-        Initializes the AzureOpenAILLM instance with the given configuration.
-
-        Args:
-            config (Optional[BaseLlmConfig]): Configuration settings for the language model.
-        """
        super().__init__(config)

-        # Ensure model name is set; it should match the Azure OpenAI deployment name.
+        # Model name should match the custom deployment name chosen for it.
        if not self.config.model:
            self.config.model = "gpt-4o"

-        api_key = self.config.azure_kwargs.api_key or os.getenv(
-            "LLM_AZURE_OPENAI_API_KEY"
-        )
-        azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv(
-            "LLM_AZURE_DEPLOYMENT"
-        )
-        azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv(
-            "LLM_AZURE_ENDPOINT"
-        )
-        api_version = self.config.azure_kwargs.api_version or os.getenv(
-            "LLM_AZURE_API_VERSION"
-        )
+        api_key = self.config.azure_kwargs.api_key or os.getenv("LLM_AZURE_OPENAI_API_KEY")
+        azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv("LLM_AZURE_DEPLOYMENT")
+        azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("LLM_AZURE_ENDPOINT")
+        api_version = self.config.azure_kwargs.api_version or os.getenv("LLM_AZURE_API_VERSION")
        default_headers = self.config.azure_kwargs.default_headers

        self.client = AzureOpenAI(

@@ -49,20 +31,54 @@ class AzureOpenAILLM(LLMBase):
            default_headers=default_headers,
        )

+    def _parse_response(self, response, tools):
+        """
+        Process the response based on whether tools are used or not.
+
+        Args:
+            response: The raw response from API.
+            tools: The list of tools provided in the request.
+
+        Returns:
+            str or dict: The processed response.
+        """
+        if tools:
+            processed_response = {
+                "content": response.choices[0].message.content,
+                "tool_calls": [],
+            }
+
+            if response.choices[0].message.tool_calls:
+                for tool_call in response.choices[0].message.tool_calls:
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call.function.name,
+                            "arguments": json.loads(tool_call.function.arguments),
+                        }
+                    )
+
+            return processed_response
+        else:
+            return response.choices[0].message.content

    def generate_response(
        self,
        messages: List[Dict[str, str]],
-        response_format: Optional[str] = None,
-    ) -> str:
+        response_format=None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
+    ):
        """
-        Generates a response using Azure OpenAI based on the provided messages.
+        Generate a response based on the given messages using Azure OpenAI.

        Args:
-            messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
-            response_format (Optional[str]): The desired format of the response. Defaults to None.
+            messages (list): List of message dicts containing 'role' and 'content'.
+            response_format (str or object, optional): Format of the response. Defaults to "text".
+            tools (list, optional): List of tools that the model can call. Defaults to None.
+            tool_choice (str, optional): Tool choice method. Defaults to "auto".

        Returns:
-            str: The generated response from the model.
+            str: The generated response.
        """
        params = {
            "model": self.config.model,

@@ -71,8 +87,11 @@ class AzureOpenAILLM(LLMBase):
            "max_tokens": self.config.max_tokens,
            "top_p": self.config.top_p,
        }

        if response_format:
            params["response_format"] = response_format
+        if tools:  # TODO: Remove tools if no issues found with new memory addition logic
+            params["tools"] = tools
+            params["tool_choice"] = tool_choice

        response = self.client.chat.completions.create(**params)
-        return response.choices[0].message.content
+        return self._parse_response(response, tools)
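For context (not part of the diff): the `_parse_response` helpers restored in the OpenAI-compatible classes return plain text when no tools are supplied and a dict when they are. A sketch of both return shapes, using, say, `AzureOpenAILLM` and the hypothetical tool above; the response text and arguments shown are illustrative.

plain = llm.generate_response(messages=[{"role": "user", "content": "Hi"}])
# -> "Hello! How can I help?"  (a str; without tools, _parse_response returns message.content)

with_tools = llm.generate_response(
    messages=[{"role": "user", "content": "Remember that I prefer tea over coffee."}],
    tools=[add_memory_tool],
)
# -> {"content": None,
#     "tool_calls": [{"name": "add_memory", "arguments": {"data": "Prefers tea over coffee"}}]}
# "arguments" is already decoded from JSON into a dict by _parse_response.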
@@ -9,38 +9,20 @@ from mem0.llms.base import LLMBase


class AzureOpenAIStructuredLLM(LLMBase):
-    """
-    A class for interacting with Azure OpenAI models using the specified configuration.
-    """

    def __init__(self, config: Optional[BaseLlmConfig] = None):
-        """
-        Initializes the AzureOpenAIStructuredLLM instance with the given configuration.
-
-        Args:
-            config (Optional[BaseLlmConfig]): Configuration settings for the language model.
-        """
        super().__init__(config)

-        # Ensure model name is set; it should match the Azure OpenAI deployment name.
+        # Model name should match the custom deployment name chosen for it.
        if not self.config.model:
            self.config.model = "gpt-4o-2024-08-06"

-        api_key = (
-            os.getenv("LLM_AZURE_OPENAI_API_KEY") or self.config.azure_kwargs.api_key
-        )
-        azure_deployment = (
-            os.getenv("LLM_AZURE_DEPLOYMENT")
-            or self.config.azure_kwargs.azure_deployment
-        )
-        azure_endpoint = (
-            os.getenv("LLM_AZURE_ENDPOINT") or self.config.azure_kwargs.azure_endpoint
-        )
-        api_version = (
-            os.getenv("LLM_AZURE_API_VERSION") or self.config.azure_kwargs.api_version
-        )
+        api_key = os.getenv("LLM_AZURE_OPENAI_API_KEY") or self.config.azure_kwargs.api_key
+        azure_deployment = os.getenv("LLM_AZURE_DEPLOYMENT") or self.config.azure_kwargs.azure_deployment
+        azure_endpoint = os.getenv("LLM_AZURE_ENDPOINT") or self.config.azure_kwargs.azure_endpoint
+        api_version = os.getenv("LLM_AZURE_API_VERSION") or self.config.azure_kwargs.api_version
        default_headers = self.config.azure_kwargs.default_headers

+        # Can display a warning if API version is of model and api-version
        self.client = AzureOpenAI(
            azure_deployment=azure_deployment,
            azure_endpoint=azure_endpoint,

@@ -50,52 +32,20 @@ class AzureOpenAIStructuredLLM(LLMBase):
            default_headers=default_headers,
        )

-    def _parse_response(self, response, tools):
-        """
-        Process the response based on whether tools are used or not.
-        Args:
-            response: The raw response from API.
-            tools: The list of tools provided in the request.
-        Returns:
-            str or dict: The processed response.
-        """
-        if tools:
-            processed_response = {
-                "content": response.choices[0].message.content,
-                "tool_calls": [],
-            }
-
-            if response.choices[0].message.tool_calls:
-                for tool_call in response.choices[0].message.tool_calls:
-                    processed_response["tool_calls"].append(
-                        {
-                            "name": tool_call.function.name,
-                            "arguments": json.loads(tool_call.function.arguments),
-                        }
-                    )
-
-            return processed_response
-        else:
-            return response.choices[0].message.content

    def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format: Optional[str] = None,
-        tools: Optional[List[Dict]] = None,
-        tool_choice: str = "auto",
    ) -> str:
        """
-        Generates a response using Azure OpenAI based on the provided messages.
+        Generate a response based on the given messages using Azure OpenAI.

        Args:
            messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
            response_format (Optional[str]): The desired format of the response. Defaults to None.
-            tools (Optional[List[Dict]]): A list of dictionaries, each containing a 'name' and 'arguments' key.
-            tool_choice (str): The choice of tool to use. Defaults to "auto".

        Returns:
-            str: The generated response from the model.
+            str: The generated response.
        """
        params = {
            "model": self.config.model,

@@ -104,9 +54,11 @@ class AzureOpenAIStructuredLLM(LLMBase):
            "max_tokens": self.config.max_tokens,
            "top_p": self.config.top_p,
        }

        if response_format:
            params["response_format"] = response_format
+        if tools:
+            params["tools"] = tools
+            params["tool_choice"] = tool_choice

        if tools:
            params["tools"] = tools
@@ -4,12 +4,8 @@ from pydantic import BaseModel, Field, field_validator


class LlmConfig(BaseModel):
-    provider: str = Field(
-        description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai"
-    )
-    config: Optional[dict] = Field(
-        description="Configuration for the specific LLM", default={}
-    )
+    provider: str = Field(description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai")
+    config: Optional[dict] = Field(description="Configuration for the specific LLM", default={})

    @field_validator("config")
    def validate_config(cls, v, values):
@@ -9,42 +9,64 @@ from mem0.llms.base import LLMBase


class DeepSeekLLM(LLMBase):
-    """
-    A class for interacting with DeepSeek's language models using the specified configuration.
-    """

    def __init__(self, config: Optional[BaseLlmConfig] = None):
-        """
-        Initializes the DeepSeekLLM instance with the given configuration.
-
-        Args:
-            config (Optional[BaseLlmConfig]): Configuration settings for the language model.
-        """
        super().__init__(config)

        if not self.config.model:
            self.config.model = "deepseek-chat"

        api_key = self.config.api_key or os.getenv("DEEPSEEK_API_KEY")
-        base_url = (
-            self.config.deepseek_base_url
-            or os.getenv("DEEPSEEK_API_BASE")
-            or "https://api.deepseek.com"
-        )
+        base_url = self.config.deepseek_base_url or os.getenv("DEEPSEEK_API_BASE") or "https://api.deepseek.com"
        self.client = OpenAI(api_key=api_key, base_url=base_url)

+    def _parse_response(self, response, tools):
+        """
+        Process the response based on whether tools are used or not.
+
+        Args:
+            response: The raw response from API.
+            tools: The list of tools provided in the request.
+
+        Returns:
+            str or dict: The processed response.
+        """
+        if tools:
+            processed_response = {
+                "content": response.choices[0].message.content,
+                "tool_calls": [],
+            }
+
+            if response.choices[0].message.tool_calls:
+                for tool_call in response.choices[0].message.tool_calls:
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call.function.name,
+                            "arguments": json.loads(tool_call.function.arguments),
+                        }
+                    )
+
+            return processed_response
+        else:
+            return response.choices[0].message.content

    def generate_response(
        self,
        messages: List[Dict[str, str]],
-    ) -> str:
+        response_format=None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
+    ):
        """
-        Generates a response using DeepSeek based on the provided messages.
+        Generate a response based on the given messages using DeepSeek.

        Args:
-            messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
+            messages (list): List of message dicts containing 'role' and 'content'.
+            response_format (str or object, optional): Format of the response. Defaults to "text".
+            tools (list, optional): List of tools that the model can call. Defaults to None.
+            tool_choice (str, optional): Tool choice method. Defaults to "auto".

        Returns:
-            str: The generated response from the model.
+            str: The generated response.
        """
        params = {
            "model": self.config.model,

@@ -53,5 +75,10 @@ class DeepSeekLLM(LLMBase):
            "max_tokens": self.config.max_tokens,
            "top_p": self.config.top_p,
        }

+        if tools:
+            params["tools"] = tools
+            params["tool_choice"] = tool_choice

        response = self.client.chat.completions.create(**params)
-        return response.choices[0].message.content
+        return self._parse_response(response, tools)
@@ -3,7 +3,8 @@ from typing import Dict, List, Optional

try:
    import google.generativeai as genai
-    from google.generativeai import GenerativeModel
+    from google.generativeai import GenerativeModel, protos
+    from google.generativeai.types import content_types
except ImportError:
    raise ImportError(
        "The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'."

@@ -14,17 +15,7 @@ from mem0.llms.base import LLMBase


class GeminiLLM(LLMBase):
-    """
-    A wrapper for Google's Gemini language model, integrating it with the LLMBase class.
-    """

    def __init__(self, config: Optional[BaseLlmConfig] = None):
-        """
-        Initializes the Gemini LLM with the provided configuration.
-
-        Args:
-            config (Optional[BaseLlmConfig]): Configuration object for the model.
-        """
        super().__init__(config)

        if not self.config.model:

@@ -34,25 +25,51 @@ class GeminiLLM(LLMBase):
        genai.configure(api_key=api_key)
        self.client = GenerativeModel(model_name=self.config.model)

-    def _reformat_messages(
-        self, messages: List[Dict[str, str]]
-    ) -> List[Dict[str, str]]:
+    def _parse_response(self, response, tools):
        """
-        Reformats messages to match the Gemini API's expected structure.
+        Process the response based on whether tools are used or not.

        Args:
-            messages (List[Dict[str, str]]): A list of messages with 'role' and 'content' keys.
+            response: The raw response from API.
+            tools: The list of tools provided in the request.

        Returns:
-            List[Dict[str, str]]: Reformatted messages in the required format.
+            str or dict: The processed response.
+        """
+        if tools:
+            processed_response = {
+                "content": (content if (content := response.candidates[0].content.parts[0].text) else None),
+                "tool_calls": [],
+            }
+
+            for part in response.candidates[0].content.parts:
+                if fn := part.function_call:
+                    if isinstance(fn, protos.FunctionCall):
+                        fn_call = type(fn).to_dict(fn)
+                        processed_response["tool_calls"].append({"name": fn_call["name"], "arguments": fn_call["args"]})
+                        continue
+                    processed_response["tool_calls"].append({"name": fn.name, "arguments": fn.args})
+
+            return processed_response
+        else:
+            return response.candidates[0].content.parts[0].text
+
+    def _reformat_messages(self, messages: List[Dict[str, str]]):
+        """
+        Reformat messages for Gemini.
+
+        Args:
+            messages: The list of messages provided in the request.
+
+        Returns:
+            list: The list of messages in the required format.
        """
        new_messages = []

        for message in messages:
            if message["role"] == "system":
-                content = (
-                    "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"]
-                )
+                content = "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"]
            else:
                content = message["content"]

@@ -65,33 +82,90 @@ class GeminiLLM(LLMBase):

        return new_messages

-    def generate_response(
-        self, messages: List[Dict[str, str]], response_format: Optional[Dict] = None
-    ) -> str:
+    def _reformat_tools(self, tools: Optional[List[Dict]]):
        """
-        Generates a response from Gemini based on the given conversation history.
+        Reformat tools for Gemini.

        Args:
-            messages (List[Dict[str, str]]): List of message dictionaries containing 'role' and 'content'.
-            response_format (Optional[Dict]): Specifies the response format (e.g., JSON schema).
+            tools: The list of tools provided in the request.

        Returns:
-            str: The generated response as text.
+            list: The list of tools in the required format.
        """

+        def remove_additional_properties(data):
+            """Recursively removes 'additionalProperties' from nested dictionaries."""
+
+            if isinstance(data, dict):
+                filtered_dict = {
+                    key: remove_additional_properties(value)
+                    for key, value in data.items()
+                    if not (key == "additionalProperties")
+                }
+                return filtered_dict
+            else:
+                return data
+
+        new_tools = []
+        if tools:
+            for tool in tools:
+                func = tool["function"].copy()
+                new_tools.append({"function_declarations": [remove_additional_properties(func)]})
+
+            # TODO: temporarily ignore it to pass tests, will come back to update according to standards later.
+            # return content_types.to_function_library(new_tools)
+
+            return new_tools
+        else:
+            return None
+
+    def generate_response(
+        self,
+        messages: List[Dict[str, str]],
+        response_format=None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
+    ):
+        """
+        Generate a response based on the given messages using Gemini.
+
+        Args:
+            messages (list): List of message dicts containing 'role' and 'content'.
+            response_format (str or object, optional): Format for the response. Defaults to "text".
+            tools (list, optional): List of tools that the model can call. Defaults to None.
+            tool_choice (str, optional): Tool choice method. Defaults to "auto".
+
+        Returns:
+            str: The generated response.
+        """

        params = {
            "temperature": self.config.temperature,
            "max_output_tokens": self.config.max_tokens,
            "top_p": self.config.top_p,
        }

-        if response_format and response_format.get("type") == "json_object":
+        if response_format is not None and response_format["type"] == "json_object":
            params["response_mime_type"] = "application/json"
            if "schema" in response_format:
                params["response_schema"] = response_format["schema"]
+        if tool_choice:
+            tool_config = content_types.to_tool_config(
+                {
+                    "function_calling_config": {
+                        "mode": tool_choice,
+                        "allowed_function_names": (
+                            [tool["function"]["name"] for tool in tools] if tool_choice == "any" else None
+                        ),
+                    }
+                }
+            )

        response = self.client.generate_content(
            contents=self._reformat_messages(messages),
+            tools=self._reformat_tools(tools),
            generation_config=genai.GenerationConfig(**params),
+            tool_config=tool_config,
        )

-        return response.candidates[0].content.parts[0].text
+        return self._parse_response(response, tools)
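For context (not part of the diff): on the Gemini path the same OpenAI-style tools are rewrapped as `function_declarations` (with any `additionalProperties` keys stripped), and `tool_choice` is translated into a `function_calling_config`. A sketch under those assumptions, reusing the hypothetical tool from above.

gemini_llm = GeminiLLM()  # assumes a Google Generative AI API key is configured
declarations = gemini_llm._reformat_tools([add_memory_tool])
# -> [{"function_declarations": [{"name": "add_memory",
#                                 "description": "Store a fact extracted from the conversation.",
#                                 "parameters": {"type": "object",
#                                                "properties": {"data": {"type": "string",
#                                                                        "description": "Fact to store"}},
#                                                "required": ["data"]}}]}]
# With tool_choice="auto", the added code builds
# content_types.to_tool_config({"function_calling_config": {"mode": "auto",
#                                                           "allowed_function_names": None}}).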
@@ -5,26 +5,14 @@ from typing import Dict, List, Optional
try:
    from groq import Groq
except ImportError:
-    raise ImportError(
-        "The 'groq' library is required. Please install it using 'pip install groq'."
-    )
+    raise ImportError("The 'groq' library is required. Please install it using 'pip install groq'.")

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase


class GroqLLM(LLMBase):
-    """
-    A class for interacting with Groq's language models using the specified configuration.
-    """

    def __init__(self, config: Optional[BaseLlmConfig] = None):
-        """
-        Initializes the GroqLLM instance with the given configuration.
-
-        Args:
-            config (Optional[BaseLlmConfig]): Configuration settings for the language model.
-        """
        super().__init__(config)

        if not self.config.model:

@@ -33,20 +21,54 @@ class GroqLLM(LLMBase):
        api_key = self.config.api_key or os.getenv("GROQ_API_KEY")
        self.client = Groq(api_key=api_key)

+    def _parse_response(self, response, tools):
+        """
+        Process the response based on whether tools are used or not.
+
+        Args:
+            response: The raw response from API.
+            tools: The list of tools provided in the request.
+
+        Returns:
+            str or dict: The processed response.
+        """
+        if tools:
+            processed_response = {
+                "content": response.choices[0].message.content,
+                "tool_calls": [],
+            }
+
+            if response.choices[0].message.tool_calls:
+                for tool_call in response.choices[0].message.tool_calls:
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call.function.name,
+                            "arguments": json.loads(tool_call.function.arguments),
+                        }
+                    )
+
+            return processed_response
+        else:
+            return response.choices[0].message.content

    def generate_response(
        self,
        messages: List[Dict[str, str]],
-        response_format: Optional[str] = None,
-    ) -> str:
+        response_format=None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
+    ):
        """
-        Generates a response using Groq based on the provided messages.
+        Generate a response based on the given messages using Groq.

        Args:
-            messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
-            response_format (Optional[str]): The desired format of the response. Defaults to None.
+            messages (list): List of message dicts containing 'role' and 'content'.
+            response_format (str or object, optional): Format of the response. Defaults to "text".
+            tools (list, optional): List of tools that the model can call. Defaults to None.
+            tool_choice (str, optional): Tool choice method. Defaults to "auto".

        Returns:
-            str: The generated response from the model.
+            str: The generated response.
        """
        params = {
            "model": self.config.model,

@@ -57,5 +79,9 @@ class GroqLLM(LLMBase):
        }
        if response_format:
            params["response_format"] = response_format
+        if tools:
+            params["tools"] = tools
+            params["tool_choice"] = tool_choice

        response = self.client.chat.completions.create(**params)
-        return response.choices[0].message.content
+        return self._parse_response(response, tools)
@@ -4,50 +4,70 @@ from typing import Dict, List, Optional
try:
    import litellm
except ImportError:
-    raise ImportError(
-        "The 'litellm' library is required. Please install it using 'pip install litellm'."
-    )
+    raise ImportError("The 'litellm' library is required. Please install it using 'pip install litellm'.")

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase


class LiteLLM(LLMBase):
-    """
-    A class for interacting with LiteLLM's language models using the specified configuration.
-    """

    def __init__(self, config: Optional[BaseLlmConfig] = None):
-        """
-        Initializes the LiteLLM instance with the given configuration.
-
-        Args:
-            config (Optional[BaseLlmConfig]): Configuration settings for the language model.
-        """
        super().__init__(config)

        if not self.config.model:
            self.config.model = "gpt-4o-mini"

+    def _parse_response(self, response, tools):
+        """
+        Process the response based on whether tools are used or not.
+
+        Args:
+            response: The raw response from API.
+            tools: The list of tools provided in the request.
+
+        Returns:
+            str or dict: The processed response.
+        """
+        if tools:
+            processed_response = {
+                "content": response.choices[0].message.content,
+                "tool_calls": [],
+            }
+
+            if response.choices[0].message.tool_calls:
+                for tool_call in response.choices[0].message.tool_calls:
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call.function.name,
+                            "arguments": json.loads(tool_call.function.arguments),
+                        }
+                    )
+
+            return processed_response
+        else:
+            return response.choices[0].message.content

    def generate_response(
        self,
        messages: List[Dict[str, str]],
-        response_format: Optional[str] = None,
-    ) -> str:
+        response_format=None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
+    ):
        """
-        Generates a response using LiteLLM based on the provided messages.
+        Generate a response based on the given messages using Litellm.

        Args:
-            messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
-            response_format (Optional[str]): The desired format of the response. Defaults to None.
+            messages (list): List of message dicts containing 'role' and 'content'.
+            response_format (str or object, optional): Format of the response. Defaults to "text".
+            tools (list, optional): List of tools that the model can call. Defaults to None.
+            tool_choice (str, optional): Tool choice method. Defaults to "auto".

        Returns:
-            str: The generated response from the model.
+            str: The generated response.
        """
        if not litellm.supports_function_calling(self.config.model):
-            raise ValueError(
-                f"Model '{self.config.model}' in LiteLLM does not support function calling."
-            )
+            raise ValueError(f"Model '{self.config.model}' in litellm does not support function calling.")

        params = {
            "model": self.config.model,

@@ -58,6 +78,9 @@ class LiteLLM(LLMBase):
        }
        if response_format:
            params["response_format"] = response_format
+        if tools:  # TODO: Remove tools if no issues found with new memory addition logic
+            params["tools"] = tools
+            params["tool_choice"] = tool_choice

        response = litellm.completion(**params)
-        return response.choices[0].message.content
+        return self._parse_response(response, tools)
@@ -3,56 +3,77 @@ from typing import Dict, List, Optional
try:
    from ollama import Client
except ImportError:
-    raise ImportError(
-        "The 'ollama' library is required. Please install it using 'pip install ollama'."
-    )
+    raise ImportError("The 'ollama' library is required. Please install it using 'pip install ollama'.")

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase


class OllamaLLM(LLMBase):
-    """
-    A class for interacting with Ollama's language models using the specified configuration.
-    """

    def __init__(self, config: Optional[BaseLlmConfig] = None):
-        """
-        Initializes the OllamaLLM instance with the given configuration.
-
-        Args:
-            config (Optional[BaseLlmConfig]): Configuration settings for the language model.
-        """
        super().__init__(config)

        if not self.config.model:
            self.config.model = "llama3.1:70b"

        self.client = Client(host=self.config.ollama_base_url)
        self._ensure_model_exists()

    def _ensure_model_exists(self):
        """
-        Ensures the specified model exists locally. If not, pulls it from Ollama.
+        Ensure the specified model exists locally. If not, pull it from Ollama.
        """
        local_models = self.client.list()["models"]
        if not any(model.get("name") == self.config.model for model in local_models):
            self.client.pull(self.config.model)

+    def _parse_response(self, response, tools):
+        """
+        Process the response based on whether tools are used or not.
+
+        Args:
+            response: The raw response from API.
+            tools: The list of tools provided in the request.
+
+        Returns:
+            str or dict: The processed response.
+        """
+        if tools:
+            processed_response = {
+                "content": response["message"]["content"],
+                "tool_calls": [],
+            }
+
+            if response["message"].get("tool_calls"):
+                for tool_call in response["message"]["tool_calls"]:
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call["function"]["name"],
+                            "arguments": tool_call["function"]["arguments"],
+                        }
+                    )
+
+            return processed_response
+        else:
+            return response["message"]["content"]

    def generate_response(
        self,
        messages: List[Dict[str, str]],
-        response_format: Optional[str] = None,
-    ) -> str:
+        response_format=None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
+    ):
        """
-        Generates a response using Ollama based on the provided messages.
+        Generate a response based on the given messages using OpenAI.

        Args:
-            messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
-            response_format (Optional[str]): The desired format of the response. Defaults to None.
+            messages (list): List of message dicts containing 'role' and 'content'.
+            response_format (str or object, optional): Format of the response. Defaults to "text".
+            tools (list, optional): List of tools that the model can call. Defaults to None.
+            tool_choice (str, optional): Tool choice method. Defaults to "auto".

        Returns:
-            str: The generated response from the model.
+            str: The generated response.
        """
        params = {
            "model": self.config.model,

@@ -66,5 +87,8 @@ class OllamaLLM(LLMBase):
        if response_format:
            params["format"] = "json"

+        if tools:
+            params["tools"] = tools

        response = self.client.chat(**params)
-        return response["message"]["content"]
+        return self._parse_response(response, tools)
@@ -9,17 +9,7 @@ from mem0.llms.base import LLMBase


class OpenAILLM(LLMBase):
-    """
-    A class to interact with OpenAI or OpenRouter APIs for generating responses using LLMs.
-    """

    def __init__(self, config: Optional[BaseLlmConfig] = None):
-        """
-        Initializes the OpenAILLM instance.
-
-        Args:
-            config (Optional[BaseLlmConfig]): Configuration for the LLM, including model, API key, and base URLs.
-        """
        super().__init__(config)

        if not self.config.model:

@@ -34,27 +24,57 @@ class OpenAILLM(LLMBase):
            )
        else:
            api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
-            base_url = (
-                self.config.openai_base_url
-                or os.getenv("OPENAI_API_BASE")
-                or "https://api.openai.com/v1"
-            )
+            base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE") or "https://api.openai.com/v1"
            self.client = OpenAI(api_key=api_key, base_url=base_url)

+    def _parse_response(self, response, tools):
+        """
+        Process the response based on whether tools are used or not.
+
+        Args:
+            response: The raw response from API.
+            tools: The list of tools provided in the request.
+
+        Returns:
+            str or dict: The processed response.
+        """
+        if tools:
+            processed_response = {
+                "content": response.choices[0].message.content,
+                "tool_calls": [],
+            }
+
+            if response.choices[0].message.tool_calls:
+                for tool_call in response.choices[0].message.tool_calls:
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call.function.name,
+                            "arguments": json.loads(tool_call.function.arguments),
+                        }
+                    )
+
+            return processed_response
+        else:
+            return response.choices[0].message.content

    def generate_response(
        self,
        messages: List[Dict[str, str]],
-        response_format: Optional[str] = None,
-    ) -> str:
+        response_format=None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
+    ):
        """
-        Generates a response based on the provided messages using OpenAI or OpenRouter.
+        Generate a response based on the given messages using OpenAI.

        Args:
-            messages (List[Dict[str, str]]): A list of message dictionaries containing 'role' and 'content'.
-            response_format (Optional[str]): The format of the response. Defaults to None.
+            messages (list): List of message dicts containing 'role' and 'content'.
+            response_format (str or object, optional): Format of the response. Defaults to "text".
+            tools (list, optional): List of tools that the model can call. Defaults to None.
+            tool_choice (str, optional): Tool choice method. Defaults to "auto".

        Returns:
-            str: The generated response from the model.
+            str: The generated response.
        """
        params = {
            "model": self.config.model,

@@ -82,6 +102,9 @@ class OpenAILLM(LLMBase):

        if response_format:
            params["response_format"] = response_format
+        if tools:  # TODO: Remove tools if no issues found with new memory addition logic
+            params["tools"] = tools
+            params["tool_choice"] = tool_choice

        response = self.client.chat.completions.create(**params)
-        return response.choices[0].message.content
+        return self._parse_response(response, tools)
@@ -9,78 +9,31 @@ from mem0.llms.base import LLMBase


 class OpenAIStructuredLLM(LLMBase):
-    """
-    A class for interacting with OpenAI's structured language models using the specified configuration.
-    """
-
     def __init__(self, config: Optional[BaseLlmConfig] = None):
-        """
-        Initializes the OpenAIStructuredLLM instance with the given configuration.
-
-        Args:
-            config (Optional[BaseLlmConfig]): Configuration settings for the language model.
-        """
         super().__init__(config)

         if not self.config.model:
             self.config.model = "gpt-4o-2024-08-06"

         api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
-        base_url = (
-            self.config.openai_base_url
-            or os.getenv("OPENAI_API_BASE")
-            or "https://api.openai.com/v1"
-        )
+        base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE") or "https://api.openai.com/v1"
         self.client = OpenAI(api_key=api_key, base_url=base_url)

-    def _parse_response(self, response, tools):
-        """
-        Process the response based on whether tools are used or not.
-        Args:
-            response: The raw response from API.
-            tools (list, optional): List of tools that the model can call.
-        Returns:
-            str or dict: The processed response.
-        """
-
-        if tools:
-            processed_response = {
-                "content": response.choices[0].message.content,
-                "tool_calls": [],
-            }
-
-            if response.choices[0].message.tool_calls:
-                for tool_call in response.choices[0].message.tool_calls:
-                    processed_response["tool_calls"].append(
-                        {
-                            "name": tool_call.function.name,
-                            "arguments": json.loads(tool_call.function.arguments),
-                        }
-                    )
-
-            return processed_response
-
-        else:
-            return response.choices[0].message.content
-
     def generate_response(
         self,
         messages: List[Dict[str, str]],
         response_format: Optional[str] = None,
-        tools: Optional[List[Dict]] = None,
-        tool_choice: str = "auto",
     ) -> str:
         """
-        Generates a response using OpenAI based on the provided messages.
+        Generate a response based on the given messages using OpenAI.

         Args:
             messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
             response_format (Optional[str]): The desired format of the response. Defaults to None.
-            tools (Optional[List[Dict]]): A list of dictionaries, each containing a 'name' and 'arguments' key.
-            tool_choice (str): The choice of tool to use. Defaults to "auto".

         Returns:
-            str: The generated response from the model.
+            str: The generated response.
         """
         params = {
             "model": self.config.model,
@@ -95,4 +48,4 @@ class OpenAIStructuredLLM(LLMBase):
         params["tool_choice"] = tool_choice

         response = self.client.beta.chat.completions.parse(**params)
-        return self._parse_response(response, tools)
+        return response.choices[0].message.content
@@ -5,26 +5,14 @@ from typing import Dict, List, Optional
 try:
     from together import Together
 except ImportError:
-    raise ImportError(
-        "The 'together' library is required. Please install it using 'pip install together'."
-    )
+    raise ImportError("The 'together' library is required. Please install it using 'pip install together'.")

 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.base import LLMBase


 class TogetherLLM(LLMBase):
-    """
-    A class for interacting with the TogetherAI language model using the specified configuration.
-    """
-
     def __init__(self, config: Optional[BaseLlmConfig] = None):
-        """
-        Initializes the TogetherLLM instance with the given configuration.
-
-        Args:
-            config (Optional[BaseLlmConfig]): Configuration settings for the language model.
-        """
         super().__init__(config)

         if not self.config.model:
@@ -33,20 +21,54 @@ class TogetherLLM(LLMBase):
         api_key = self.config.api_key or os.getenv("TOGETHER_API_KEY")
         self.client = Together(api_key=api_key)

+    def _parse_response(self, response, tools):
+        """
+        Process the response based on whether tools are used or not.
+
+        Args:
+            response: The raw response from API.
+            tools: The list of tools provided in the request.
+
+        Returns:
+            str or dict: The processed response.
+        """
+        if tools:
+            processed_response = {
+                "content": response.choices[0].message.content,
+                "tool_calls": [],
+            }
+
+            if response.choices[0].message.tool_calls:
+                for tool_call in response.choices[0].message.tool_calls:
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call.function.name,
+                            "arguments": json.loads(tool_call.function.arguments),
+                        }
+                    )
+
+            return processed_response
+        else:
+            return response.choices[0].message.content
+
     def generate_response(
         self,
         messages: List[Dict[str, str]],
-        response_format: Optional[str] = None,
-    ) -> str:
+        response_format=None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
+    ):
         """
-        Generates a response using TogetherAI based on the provided messages.
+        Generate a response based on the given messages using TogetherAI.

         Args:
-            messages (List[Dict[str, str]]): A list of dictionaries, each containing a 'role' and 'content' key.
-            response_format (Optional[str]): The desired format of the response. Defaults to None.
+            messages (list): List of message dicts containing 'role' and 'content'.
+            response_format (str or object, optional): Format of the response. Defaults to "text".
+            tools (list, optional): List of tools that the model can call. Defaults to None.
+            tool_choice (str, optional): Tool choice method. Defaults to "auto".

         Returns:
-            str: The generated response from the model.
+            str: The generated response.
         """
         params = {
             "model": self.config.model,
@@ -57,6 +79,9 @@ class TogetherLLM(LLMBase):
         }
         if response_format:
             params["response_format"] = response_format
+        if tools:  # TODO: Remove tools if no issues found with new memory addition logic
+            params["tools"] = tools
+            params["tool_choice"] = tool_choice

         response = self.client.chat.completions.create(**params)
-        return response.choices[0].message.content
+        return self._parse_response(response, tools)
@@ -15,11 +15,7 @@ class XAILLM(LLMBase):
             self.config.model = "grok-2-latest"

         api_key = self.config.api_key or os.getenv("XAI_API_KEY")
-        base_url = (
-            self.config.xai_base_url
-            or os.getenv("XAI_API_BASE")
-            or "https://api.x.ai/v1"
-        )
+        base_url = self.config.xai_base_url or os.getenv("XAI_API_BASE") or "https://api.x.ai/v1"
         self.client = OpenAI(api_key=api_key, base_url=base_url)

     def generate_response(self, messages: List[Dict[str, str]], response_format=None):
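Across these clients the base-URL fallback is now collapsed into a single expression. A sketch of the resolution order it encodes (illustrative; config_value stands for the provider-specific config attribute):

    # Resolution order: explicit config value > environment variable > provider default.
    base_url = config_value or os.getenv("XAI_API_BASE") or "https://api.x.ai/v1"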
@@ -20,10 +20,8 @@ def mock_openai_client():
     yield mock_client


-def test_generate_response(mock_openai_client):
-    config = BaseLlmConfig(
-        model=MODEL, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P
-    )
+def test_generate_response_without_tools(mock_openai_client):
+    config = BaseLlmConfig(model=MODEL, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P)
     llm = AzureOpenAILLM(config)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
@@ -31,21 +29,67 @@ def test_generate_response(mock_openai_client):
     ]

     mock_response = Mock()
-    mock_response.choices = [
-        Mock(message=Mock(content="I'm doing well, thank you for asking!"))
-    ]
+    mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
     mock_openai_client.chat.completions.create.return_value = mock_response

     response = llm.generate_response(messages)

+    mock_openai_client.chat.completions.create.assert_called_once_with(
+        model=MODEL, messages=messages, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P
+    )
+    assert response == "I'm doing well, thank you for asking!"
+
+
+def test_generate_response_with_tools(mock_openai_client):
+    config = BaseLlmConfig(model=MODEL, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P)
+    llm = AzureOpenAILLM(config)
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Add a new memory: Today is a sunny day."},
+    ]
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "add_memory",
+                "description": "Add a memory",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
+                    "required": ["data"],
+                },
+            },
+        }
+    ]
+
+    mock_response = Mock()
+    mock_message = Mock()
+    mock_message.content = "I've added the memory for you."
+
+    mock_tool_call = Mock()
+    mock_tool_call.function.name = "add_memory"
+    mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
+
+    mock_message.tool_calls = [mock_tool_call]
+    mock_response.choices = [Mock(message=mock_message)]
+    mock_openai_client.chat.completions.create.return_value = mock_response
+
+    response = llm.generate_response(messages, tools=tools)
+
     mock_openai_client.chat.completions.create.assert_called_once_with(
         model=MODEL,
         messages=messages,
         temperature=TEMPERATURE,
         max_tokens=MAX_TOKENS,
         top_p=TOP_P,
+        tools=tools,
+        tool_choice="auto",
     )
-    assert response == "I'm doing well, thank you for asking!"
+
+    assert response["content"] == "I've added the memory for you."
+    assert len(response["tool_calls"]) == 1
+    assert response["tool_calls"][0]["name"] == "add_memory"
+    assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}


 @pytest.mark.parametrize(
@@ -84,6 +128,4 @@ def test_generate_with_http_proxies(default_headers):
         api_version=None,
         default_headers=default_headers,
     )
-    mock_http_client.assert_called_once_with(
-        proxies="http://testproxy.mem0.net:8000"
-    )
+    mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000")
@@ -16,47 +16,33 @@ def mock_deepseek_client():

 def test_deepseek_llm_base_url():
     # case1: default config with deepseek official base url
-    config = BaseLlmConfig(
-        model="deepseek-chat",
-        temperature=0.7,
-        max_tokens=100,
-        top_p=1.0,
-        api_key="api_key",
-    )
+    config = BaseLlmConfig(model="deepseek-chat", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key")
     llm = DeepSeekLLM(config)
     assert str(llm.client.base_url) == "https://api.deepseek.com"

     # case2: with env variable DEEPSEEK_API_BASE
     provider_base_url = "https://api.provider.com/v1/"
     os.environ["DEEPSEEK_API_BASE"] = provider_base_url
-    config = BaseLlmConfig(
-        model="deepseek-chat",
-        temperature=0.7,
-        max_tokens=100,
-        top_p=1.0,
-        api_key="api_key",
-    )
+    config = BaseLlmConfig(model="deepseek-chat", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key")
     llm = DeepSeekLLM(config)
     assert str(llm.client.base_url) == provider_base_url

     # case3: with config.deepseek_base_url
     config_base_url = "https://api.config.com/v1/"
     config = BaseLlmConfig(
         model="deepseek-chat",
         temperature=0.7,
         max_tokens=100,
         top_p=1.0,
         api_key="api_key",
-        deepseek_base_url=config_base_url,
+        deepseek_base_url=config_base_url
     )
     llm = DeepSeekLLM(config)
     assert str(llm.client.base_url) == config_base_url


-def test_generate_response(mock_deepseek_client):
-    config = BaseLlmConfig(
-        model="deepseek-chat", temperature=0.7, max_tokens=100, top_p=1.0
-    )
+def test_generate_response_without_tools(mock_deepseek_client):
+    config = BaseLlmConfig(model="deepseek-chat", temperature=0.7, max_tokens=100, top_p=1.0)
     llm = DeepSeekLLM(config)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
@@ -64,18 +50,64 @@ def test_generate_response(mock_deepseek_client):
     ]

     mock_response = Mock()
-    mock_response.choices = [
-        Mock(message=Mock(content="I'm doing well, thank you for asking!"))
-    ]
+    mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
     mock_deepseek_client.chat.completions.create.return_value = mock_response

     response = llm.generate_response(messages)

     mock_deepseek_client.chat.completions.create.assert_called_once_with(
-        model="deepseek-chat",
-        messages=messages,
-        temperature=0.7,
-        max_tokens=100,
-        top_p=1.0,
+        model="deepseek-chat", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0
     )
     assert response == "I'm doing well, thank you for asking!"

+
+def test_generate_response_with_tools(mock_deepseek_client):
+    config = BaseLlmConfig(model="deepseek-chat", temperature=0.7, max_tokens=100, top_p=1.0)
+    llm = DeepSeekLLM(config)
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Add a new memory: Today is a sunny day."},
+    ]
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "add_memory",
+                "description": "Add a memory",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
+                    "required": ["data"],
+                },
+            },
+        }
+    ]
+
+    mock_response = Mock()
+    mock_message = Mock()
+    mock_message.content = "I've added the memory for you."
+
+    mock_tool_call = Mock()
+    mock_tool_call.function.name = "add_memory"
+    mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
+
+    mock_message.tool_calls = [mock_tool_call]
+    mock_response.choices = [Mock(message=mock_message)]
+    mock_deepseek_client.chat.completions.create.return_value = mock_response
+
+    response = llm.generate_response(messages, tools=tools)
+
+    mock_deepseek_client.chat.completions.create.assert_called_once_with(
+        model="deepseek-chat",
+        messages=messages,
+        temperature=0.7,
+        max_tokens=100,
+        top_p=1.0,
+        tools=tools,
+        tool_choice="auto"
+    )
+
+    assert response["content"] == "I've added the memory for you."
+    assert len(response["tool_calls"]) == 1
+    assert response["tool_calls"][0]["name"] == "add_memory"
+    assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
@@ -17,9 +17,7 @@ def mock_gemini_client():


 def test_generate_response_without_tools(mock_gemini_client: Mock):
-    config = BaseLlmConfig(
-        model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0
-    )
+    config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0)
     llm = GeminiLLM(config)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
@@ -36,14 +34,86 @@ def test_generate_response_without_tools(mock_gemini_client: Mock):

     mock_gemini_client.generate_content.assert_called_once_with(
         contents=[
-            {
-                "parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.",
-                "role": "user",
-            },
+            {"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"},
             {"parts": "Hello, how are you?", "role": "user"},
         ],
-        generation_config=GenerationConfig(
-            temperature=0.7, max_output_tokens=100, top_p=1.0
+        generation_config=GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0),
+        tools=None,
+        tool_config=content_types.to_tool_config(
+            {"function_calling_config": {"mode": "auto", "allowed_function_names": None}}
         ),
     )
     assert response == "I'm doing well, thank you for asking!"

+
+def test_generate_response_with_tools(mock_gemini_client: Mock):
+    config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0)
+    llm = GeminiLLM(config)
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Add a new memory: Today is a sunny day."},
+    ]
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "add_memory",
+                "description": "Add a memory",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
+                    "required": ["data"],
+                },
+            },
+        }
+    ]
+
+    mock_tool_call = Mock()
+    mock_tool_call.name = "add_memory"
+    mock_tool_call.args = {"data": "Today is a sunny day."}
+
+    mock_part = Mock()
+    mock_part.function_call = mock_tool_call
+    mock_part.text = "I've added the memory for you."
+
+    mock_content = Mock()
+    mock_content.parts = [mock_part]
+
+    mock_message = Mock()
+    mock_message.content = mock_content
+
+    mock_response = Mock(candidates=[mock_message])
+    mock_gemini_client.generate_content.return_value = mock_response
+
+    response = llm.generate_response(messages, tools=tools)
+
+    mock_gemini_client.generate_content.assert_called_once_with(
+        contents=[
+            {"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"},
+            {"parts": "Add a new memory: Today is a sunny day.", "role": "user"},
+        ],
+        generation_config=GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0),
+        tools=[
+            {
+                "function_declarations": [
+                    {
+                        "name": "add_memory",
+                        "description": "Add a memory",
+                        "parameters": {
+                            "type": "object",
+                            "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
+                            "required": ["data"],
+                        },
+                    }
+                ]
+            }
+        ],
+        tool_config=content_types.to_tool_config(
+            {"function_calling_config": {"mode": "auto", "allowed_function_names": None}}
+        ),
+    )
+
+    assert response["content"] == "I've added the memory for you."
+    assert len(response["tool_calls"]) == 1
+    assert response["tool_calls"][0]["name"] == "add_memory"
+    assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
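The Gemini test above asserts that OpenAI-style tool definitions are translated into Gemini "function_declarations" plus a tool_config before generate_content is called. A hedged sketch of that mapping (the helper name is assumed for illustration, not part of this diff):

    # Assumed helper name, shown only to illustrate the translation the test asserts.
    def to_gemini_tools(openai_tools):
        # Each OpenAI-style {"type": "function", "function": {...}} entry becomes a
        # function declaration; Gemini groups them under "function_declarations".
        return [{"function_declarations": [t["function"] for t in openai_tools]}]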
@@ -14,10 +14,8 @@ def mock_groq_client():
     yield mock_client


-def test_generate_response(mock_groq_client):
-    config = BaseLlmConfig(
-        model="llama3-70b-8192", temperature=0.7, max_tokens=100, top_p=1.0
-    )
+def test_generate_response_without_tools(mock_groq_client):
+    config = BaseLlmConfig(model="llama3-70b-8192", temperature=0.7, max_tokens=100, top_p=1.0)
     llm = GroqLLM(config)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
@@ -25,18 +23,64 @@ def test_generate_response(mock_groq_client):
     ]

     mock_response = Mock()
-    mock_response.choices = [
-        Mock(message=Mock(content="I'm doing well, thank you for asking!"))
-    ]
+    mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
     mock_groq_client.chat.completions.create.return_value = mock_response

     response = llm.generate_response(messages)

+    mock_groq_client.chat.completions.create.assert_called_once_with(
+        model="llama3-70b-8192", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0
+    )
+    assert response == "I'm doing well, thank you for asking!"
+
+
+def test_generate_response_with_tools(mock_groq_client):
+    config = BaseLlmConfig(model="llama3-70b-8192", temperature=0.7, max_tokens=100, top_p=1.0)
+    llm = GroqLLM(config)
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Add a new memory: Today is a sunny day."},
+    ]
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "add_memory",
+                "description": "Add a memory",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
+                    "required": ["data"],
+                },
+            },
+        }
+    ]
+
+    mock_response = Mock()
+    mock_message = Mock()
+    mock_message.content = "I've added the memory for you."
+
+    mock_tool_call = Mock()
+    mock_tool_call.function.name = "add_memory"
+    mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
+
+    mock_message.tool_calls = [mock_tool_call]
+    mock_response.choices = [Mock(message=mock_message)]
+    mock_groq_client.chat.completions.create.return_value = mock_response
+
+    response = llm.generate_response(messages, tools=tools)
+
     mock_groq_client.chat.completions.create.assert_called_once_with(
         model="llama3-70b-8192",
         messages=messages,
         temperature=0.7,
         max_tokens=100,
         top_p=1.0,
+        tools=tools,
+        tool_choice="auto",
     )
-    assert response == "I'm doing well, thank you for asking!"
+
+    assert response["content"] == "I've added the memory for you."
+    assert len(response["tool_calls"]) == 1
+    assert response["tool_calls"][0]["name"] == "add_memory"
+    assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
@@ -13,22 +13,17 @@ def mock_litellm():


 def test_generate_response_with_unsupported_model(mock_litellm):
-    config = BaseLlmConfig(
-        model="unsupported-model", temperature=0.7, max_tokens=100, top_p=1
-    )
+    config = BaseLlmConfig(model="unsupported-model", temperature=0.7, max_tokens=100, top_p=1)
     llm = litellm.LiteLLM(config)
     messages = [{"role": "user", "content": "Hello"}]

     mock_litellm.supports_function_calling.return_value = False

-    with pytest.raises(
-        ValueError,
-        match="Model 'unsupported-model' in LiteLLM does not support function calling.",
-    ):
+    with pytest.raises(ValueError, match="Model 'unsupported-model' in litellm does not support function calling."):
         llm.generate_response(messages)


-def test_generate_response(mock_litellm):
+def test_generate_response_without_tools(mock_litellm):
     config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1)
     llm = litellm.LiteLLM(config)
     messages = [
@@ -37,9 +32,7 @@ def test_generate_response(mock_litellm):
     ]

     mock_response = Mock()
-    mock_response.choices = [
-        Mock(message=Mock(content="I'm doing well, thank you for asking!"))
-    ]
+    mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
     mock_litellm.completion.return_value = mock_response
     mock_litellm.supports_function_calling.return_value = True

@@ -49,3 +42,50 @@ def test_generate_response(mock_litellm):
         model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0
     )
     assert response == "I'm doing well, thank you for asking!"
+
+
+def test_generate_response_with_tools(mock_litellm):
+    config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1)
+    llm = litellm.LiteLLM(config)
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Add a new memory: Today is a sunny day."},
+    ]
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "add_memory",
+                "description": "Add a memory",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
+                    "required": ["data"],
+                },
+            },
+        }
+    ]
+
+    mock_response = Mock()
+    mock_message = Mock()
+    mock_message.content = "I've added the memory for you."
+
+    mock_tool_call = Mock()
+    mock_tool_call.function.name = "add_memory"
+    mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
+
+    mock_message.tool_calls = [mock_tool_call]
+    mock_response.choices = [Mock(message=mock_message)]
+    mock_litellm.completion.return_value = mock_response
+    mock_litellm.supports_function_calling.return_value = True
+
+    response = llm.generate_response(messages, tools=tools)
+
+    mock_litellm.completion.assert_called_once_with(
+        model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1, tools=tools, tool_choice="auto"
+    )
+
+    assert response["content"] == "I've added the memory for you."
+    assert len(response["tool_calls"]) == 1
+    assert response["tool_calls"][0]["name"] == "add_memory"
+    assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
@@ -16,9 +16,7 @@ def mock_openai_client():

 def test_openai_llm_base_url():
     # case1: default config: with openai official base url
-    config = BaseLlmConfig(
-        model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key"
-    )
+    config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key")
     llm = OpenAILLM(config)
     # Note: openai client will parse the raw base_url into a URL object, which will have a trailing slash
     assert str(llm.client.base_url) == "https://api.openai.com/v1/"
@@ -26,9 +24,7 @@ def test_openai_llm_base_url():
     # case2: with env variable OPENAI_API_BASE
     provider_base_url = "https://api.provider.com/v1"
     os.environ["OPENAI_API_BASE"] = provider_base_url
-    config = BaseLlmConfig(
-        model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key"
-    )
+    config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key")
     llm = OpenAILLM(config)
     # Note: openai client will parse the raw base_url into a URL object, which will have a trailing slash
     assert str(llm.client.base_url) == provider_base_url + "/"
@@ -36,19 +32,14 @@ def test_openai_llm_base_url():
     # case3: with config.openai_base_url
     config_base_url = "https://api.config.com/v1"
     config = BaseLlmConfig(
-        model="gpt-4o",
-        temperature=0.7,
-        max_tokens=100,
-        top_p=1.0,
-        api_key="api_key",
-        openai_base_url=config_base_url,
+        model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0, api_key="api_key", openai_base_url=config_base_url
     )
     llm = OpenAILLM(config)
     # Note: openai client will parse the raw base_url into a URL object, which will have a trailing slash
     assert str(llm.client.base_url) == config_base_url + "/"


-def test_generate_response(mock_openai_client):
+def test_generate_response_without_tools(mock_openai_client):
     config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0)
     llm = OpenAILLM(config)
     messages = [
@@ -57,9 +48,7 @@ def test_generate_response(mock_openai_client):
     ]

     mock_response = Mock()
-    mock_response.choices = [
-        Mock(message=Mock(content="I'm doing well, thank you for asking!"))
-    ]
+    mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
     mock_openai_client.chat.completions.create.return_value = mock_response

     response = llm.generate_response(messages)
@@ -68,3 +57,49 @@ def test_generate_response(mock_openai_client):
         model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0
     )
     assert response == "I'm doing well, thank you for asking!"
+
+
+def test_generate_response_with_tools(mock_openai_client):
+    config = BaseLlmConfig(model="gpt-4o", temperature=0.7, max_tokens=100, top_p=1.0)
+    llm = OpenAILLM(config)
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Add a new memory: Today is a sunny day."},
+    ]
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "add_memory",
+                "description": "Add a memory",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
+                    "required": ["data"],
+                },
+            },
+        }
+    ]
+
+    mock_response = Mock()
+    mock_message = Mock()
+    mock_message.content = "I've added the memory for you."
+
+    mock_tool_call = Mock()
+    mock_tool_call.function.name = "add_memory"
+    mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
+
+    mock_message.tool_calls = [mock_tool_call]
+    mock_response.choices = [Mock(message=mock_message)]
+    mock_openai_client.chat.completions.create.return_value = mock_response
+
+    response = llm.generate_response(messages, tools=tools)
+
+    mock_openai_client.chat.completions.create.assert_called_once_with(
+        model="gpt-4o", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0, tools=tools, tool_choice="auto"
+    )
+
+    assert response["content"] == "I've added the memory for you."
+    assert len(response["tool_calls"]) == 1
+    assert response["tool_calls"][0]["name"] == "add_memory"
+    assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
@@ -14,13 +14,8 @@ def mock_together_client():
     yield mock_client


-def test_generate_response(mock_together_client):
-    config = BaseLlmConfig(
-        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
-        temperature=0.7,
-        max_tokens=100,
-        top_p=1.0,
-    )
+def test_generate_response_without_tools(mock_together_client):
+    config = BaseLlmConfig(model="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0.7, max_tokens=100, top_p=1.0)
     llm = TogetherLLM(config)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
@@ -28,18 +23,64 @@ def test_generate_response(mock_together_client):
     ]

     mock_response = Mock()
-    mock_response.choices = [
-        Mock(message=Mock(content="I'm doing well, thank you for asking!"))
-    ]
+    mock_response.choices = [Mock(message=Mock(content="I'm doing well, thank you for asking!"))]
     mock_together_client.chat.completions.create.return_value = mock_response

     response = llm.generate_response(messages)

+    mock_together_client.chat.completions.create.assert_called_once_with(
+        model="mistralai/Mixtral-8x7B-Instruct-v0.1", messages=messages, temperature=0.7, max_tokens=100, top_p=1.0
+    )
+    assert response == "I'm doing well, thank you for asking!"
+
+
+def test_generate_response_with_tools(mock_together_client):
+    config = BaseLlmConfig(model="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0.7, max_tokens=100, top_p=1.0)
+    llm = TogetherLLM(config)
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Add a new memory: Today is a sunny day."},
+    ]
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "add_memory",
+                "description": "Add a memory",
+                "parameters": {
+                    "type": "object",
+                    "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
+                    "required": ["data"],
+                },
+            },
+        }
+    ]
+
+    mock_response = Mock()
+    mock_message = Mock()
+    mock_message.content = "I've added the memory for you."
+
+    mock_tool_call = Mock()
+    mock_tool_call.function.name = "add_memory"
+    mock_tool_call.function.arguments = '{"data": "Today is a sunny day."}'
+
+    mock_message.tool_calls = [mock_tool_call]
+    mock_response.choices = [Mock(message=mock_message)]
+    mock_together_client.chat.completions.create.return_value = mock_response
+
+    response = llm.generate_response(messages, tools=tools)
+
     mock_together_client.chat.completions.create.assert_called_once_with(
         model="mistralai/Mixtral-8x7B-Instruct-v0.1",
         messages=messages,
         temperature=0.7,
         max_tokens=100,
         top_p=1.0,
+        tools=tools,
+        tool_choice="auto",
     )
-    assert response == "I'm doing well, thank you for asking!"
+
+    assert response["content"] == "I've added the memory for you."
+    assert len(response["tool_calls"]) == 1
+    assert response["tool_calls"][0]["name"] == "add_memory"
+    assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}