[Mem0] Update dependencies and make the package lighter (#1708)
Co-authored-by: Dev-Khant <devkhant24@gmail.com>
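The change is largely mechanical: the provider modules are reformatted (double quotes, trailing commas, wrapped calls) and provider SDK imports are treated as optional extras, with litellm now getting the same import guard the other providers already use. For reference, the guard pattern applied across the integrations, as a minimal standalone sketch:

```python
# Minimal sketch of the optional-dependency guard used by each provider module.
try:
    import boto3  # provider SDK, installed separately via the named extra
except ImportError:
    # Fail with an actionable message instead of a bare ModuleNotFoundError.
    raise ImportError(
        "AWS Bedrock requires extra dependencies. Install with `pip install boto3`"
    ) from None
```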
@@ -5,22 +5,30 @@ from typing import Dict, List, Optional, Any
 try:
     import boto3
 except ImportError:
-    raise ImportError("AWS Bedrock requires extra dependencies. Install with `pip install boto3`") from None
+    raise ImportError(
+        "AWS Bedrock requires extra dependencies. Install with `pip install boto3`"
+    ) from None
 
 from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig
 
-class AWSBedrockLLM(LLMBase):
+
+class AWSBedrockLLM(LLMBase):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config)
 
         if not self.config.model:
-            self.config.model="anthropic.claude-3-5-sonnet-20240620-v1:0"
-        self.client = boto3.client("bedrock-runtime", region_name=os.environ.get("AWS_REGION"), aws_access_key_id=os.environ.get("AWS_ACCESS_KEY"), aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"))
+            self.config.model = "anthropic.claude-3-5-sonnet-20240620-v1:0"
+        self.client = boto3.client(
+            "bedrock-runtime",
+            region_name=os.environ.get("AWS_REGION"),
+            aws_access_key_id=os.environ.get("AWS_ACCESS_KEY"),
+            aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"),
+        )
         self.model_kwargs = {
-            "temperature": self.config.temperature,
-            "max_tokens_to_sample": self.config.max_tokens,
-            "top_p": self.config.top_p
+            "temperature": self.config.temperature,
+            "max_tokens_to_sample": self.config.max_tokens,
+            "top_p": self.config.top_p,
         }
 
     def _format_messages(self, messages: List[Dict[str, str]]) -> str:
@@ -28,7 +36,7 @@ class AWSBedrockLLM(LLMBase):
         Formats a list of messages into the required prompt structure for the model.
 
         Args:
-            messages (List[Dict[str, str]]): A list of dictionaries where each dictionary represents a message.
+            messages (List[Dict[str, str]]): A list of dictionaries where each dictionary represents a message.
                 Each dictionary contains 'role' and 'content' keys.
 
         Returns:
@@ -36,12 +44,12 @@ class AWSBedrockLLM(LLMBase):
         """
         formatted_messages = []
         for message in messages:
-            role = message['role'].capitalize()
-            content = message['content']
+            role = message["role"].capitalize()
+            content = message["content"]
             formatted_messages.append(f"\n\n{role}: {content}")
 
         return "".join(formatted_messages) + "\n\nAssistant:"
 
     def _parse_response(self, response, tools) -> str:
         """
         Process the response based on whether tools are used or not.
@@ -54,72 +62,76 @@ class AWSBedrockLLM(LLMBase):
             str or dict: The processed response.
         """
         if tools:
-            processed_response = {
-                "tool_calls": []
-            }
+            processed_response = {"tool_calls": []}
 
             if response["output"]["message"]["content"]:
                 for item in response["output"]["message"]["content"]:
                     if "toolUse" in item:
-                        processed_response["tool_calls"].append({
-                            "name": item["toolUse"]["name"],
-                            "arguments": item["toolUse"]["input"]
-                        })
+                        processed_response["tool_calls"].append(
+                            {
+                                "name": item["toolUse"]["name"],
+                                "arguments": item["toolUse"]["input"],
+                            }
+                        )
 
             return processed_response
 
-        response_body = json.loads(response['body'].read().decode())
-        return response_body.get('completion', '')
+        response_body = json.loads(response["body"].read().decode())
+        return response_body.get("completion", "")
 
     def _prepare_input(
-        self,
-        provider: str,
-        model: str,
-        prompt: str,
-        model_kwargs: Optional[Dict[str, Any]] = {},
-    ) -> Dict[str, Any]:
+        self,
+        provider: str,
+        model: str,
+        prompt: str,
+        model_kwargs: Optional[Dict[str, Any]] = {},
+    ) -> Dict[str, Any]:
         """
-        Prepares the input dictionary for the specified provider's model by mapping and renaming
-        keys in the input based on the provider's requirements.
+        Prepares the input dictionary for the specified provider's model by mapping and renaming
+        keys in the input based on the provider's requirements.
 
-        Args:
-            provider (str): The name of the service provider (e.g., "meta", "ai21", "mistral", "cohere", "amazon").
-            model (str): The name or identifier of the model being used.
-            prompt (str): The text prompt to be processed by the model.
-            model_kwargs (Dict[str, Any]): Additional keyword arguments specific to the model's requirements.
+        Args:
+            provider (str): The name of the service provider (e.g., "meta", "ai21", "mistral", "cohere", "amazon").
+            model (str): The name or identifier of the model being used.
+            prompt (str): The text prompt to be processed by the model.
+            model_kwargs (Dict[str, Any]): Additional keyword arguments specific to the model's requirements.
 
-        Returns:
-            Dict[str, Any]: The prepared input dictionary with the correct keys and values for the specified provider.
+        Returns:
+            Dict[str, Any]: The prepared input dictionary with the correct keys and values for the specified provider.
         """
 
         input_body = {"prompt": prompt, **model_kwargs}
 
         provider_mappings = {
             "meta": {"max_tokens_to_sample": "max_gen_len"},
             "ai21": {"max_tokens_to_sample": "maxTokens", "top_p": "topP"},
             "mistral": {"max_tokens_to_sample": "max_tokens"},
             "cohere": {"max_tokens_to_sample": "max_tokens", "top_p": "p"},
         }
 
         if provider in provider_mappings:
             for old_key, new_key in provider_mappings[provider].items():
                 if old_key in input_body:
                     input_body[new_key] = input_body.pop(old_key)
 
         if provider == "cohere" and "cohere.command-r" in model:
             input_body["message"] = input_body.pop("prompt")
 
         if provider == "amazon":
             input_body = {
                 "inputText": prompt,
                 "textGenerationConfig": {
                     "maxTokenCount": model_kwargs.get("max_tokens_to_sample"),
                     "topP": model_kwargs.get("top_p"),
-                    "temperature": model_kwargs.get("temperature")
-                }
+                    "temperature": model_kwargs.get("temperature"),
+                },
             }
-            input_body["textGenerationConfig"] = {k: v for k, v in input_body["textGenerationConfig"].items() if v is not None}
+
+            input_body["textGenerationConfig"] = {
+                k: v
+                for k, v in input_body["textGenerationConfig"].items()
+                if v is not None
+            }
 
         return input_body
 
     def _convert_tool_format(self, original_tools):
@@ -133,32 +145,34 @@ class AWSBedrockLLM(LLMBase):
             list: A list of dictionaries representing the tools in the new standardized format.
         """
         new_tools = []
 
         for tool in original_tools:
-            if tool['type'] == 'function':
-                function = tool['function']
+            if tool["type"] == "function":
+                function = tool["function"]
                 new_tool = {
                     "toolSpec": {
-                        "name": function['name'],
-                        "description": function['description'],
+                        "name": function["name"],
+                        "description": function["description"],
                         "inputSchema": {
                             "json": {
                                 "type": "object",
                                 "properties": {},
-                                "required": function['parameters'].get('required', [])
+                                "required": function["parameters"].get("required", []),
                             }
-                        }
-                    }
+                        },
+                    }
                 }
 
-                for prop, details in function['parameters'].get('properties', {}).items():
+                for prop, details in (
+                    function["parameters"].get("properties", {}).items()
+                ):
                     new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = {
-                        "type": details.get('type', 'string'),
-                        "description": details.get('description', '')
+                        "type": details.get("type", "string"),
+                        "description": details.get("description", ""),
                     }
 
                 new_tools.append(new_tool)
 
         return new_tools
 
     def generate_response(
@@ -181,28 +195,39 @@ class AWSBedrockLLM(LLMBase):
 
         if tools:
             # Use converse method when tools are provided
-            messages = [{"role": "user", "content": [{"text": message["content"]} for message in messages]}]
-            inference_config = {"temperature": self.model_kwargs["temperature"], "maxTokens": self.model_kwargs["max_tokens_to_sample"], "topP": self.model_kwargs["top_p"]}
+            messages = [
+                {
+                    "role": "user",
+                    "content": [{"text": message["content"]} for message in messages],
+                }
+            ]
+            inference_config = {
+                "temperature": self.model_kwargs["temperature"],
+                "maxTokens": self.model_kwargs["max_tokens_to_sample"],
+                "topP": self.model_kwargs["top_p"],
+            }
             tools_config = {"tools": self._convert_tool_format(tools)}
 
             response = self.client.converse(
                 modelId=self.config.model,
                 messages=messages,
                 inferenceConfig=inference_config,
-                toolConfig=tools_config
+                toolConfig=tools_config,
             )
         else:
             # Use invoke_model method when no tools are provided
             prompt = self._format_messages(messages)
             provider = self.model.split(".")[0]
-            input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs)
+            input_body = self._prepare_input(
+                provider, self.config.model, prompt, **self.model_kwargs
+            )
             body = json.dumps(input_body)
 
             response = self.client.invoke_model(
                 body=body,
                 modelId=self.model,
-                accept='application/json',
-                contentType='application/json'
+                accept="application/json",
+                contentType="application/json",
             )
 
         return self._parse_response(response, tools)
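Review note: the reflowed `_prepare_input` above is behaviourally unchanged; it only renames generic sampling keys to whatever each model family expects. A standalone sketch of that remapping (the sample values are illustrative, not part of the commit):

```python
# Standalone sketch of the key remapping performed in _prepare_input.
provider_mappings = {
    "meta": {"max_tokens_to_sample": "max_gen_len"},
    "ai21": {"max_tokens_to_sample": "maxTokens", "top_p": "topP"},
    "mistral": {"max_tokens_to_sample": "max_tokens"},
    "cohere": {"max_tokens_to_sample": "max_tokens", "top_p": "p"},
}

def remap(provider: str, body: dict) -> dict:
    # Rename generic keys in place to the provider-specific names.
    for old_key, new_key in provider_mappings.get(provider, {}).items():
        if old_key in body:
            body[new_key] = body.pop(old_key)
    return body

print(remap("meta", {"prompt": "Hi", "max_tokens_to_sample": 200, "top_p": 0.9}))
# -> {'prompt': 'Hi', 'top_p': 0.9, 'max_gen_len': 200}
```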
@@ -6,13 +6,14 @@ from openai import AzureOpenAI
 from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig
 
 
 class AzureOpenAILLM(LLMBase):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config)
 
         # Model name should match the custom deployment name chosen for it.
         if not self.config.model:
-            self.config.model="gpt-4o"
+            self.config.model = "gpt-4o"
         self.client = AzureOpenAI()
 
     def _parse_response(self, response, tools):
@@ -29,21 +30,22 @@ class AzureOpenAILLM(LLMBase):
         if tools:
             processed_response = {
                 "content": response.choices[0].message.content,
-                "tool_calls": []
+                "tool_calls": [],
             }
 
             if response.choices[0].message.tool_calls:
                 for tool_call in response.choices[0].message.tool_calls:
-                    processed_response["tool_calls"].append({
-                        "name": tool_call.function.name,
-                        "arguments": json.loads(tool_call.function.arguments)
-                    })
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call.function.name,
+                            "arguments": json.loads(tool_call.function.arguments),
+                        }
+                    )
 
             return processed_response
         else:
             return response.choices[0].message.content
 
     def generate_response(
         self,
         messages: List[Dict[str, str]],
@@ -64,11 +66,11 @@ class AzureOpenAILLM(LLMBase):
             str: The generated response.
         """
         params = {
-            "model": self.config.model,
-            "messages": messages,
-            "temperature": self.config.temperature,
-            "max_tokens": self.config.max_tokens,
-            "top_p": self.config.top_p
+            "model": self.config.model,
+            "messages": messages,
+            "temperature": self.config.temperature,
+            "max_tokens": self.config.max_tokens,
+            "top_p": self.config.top_p,
         }
         if response_format:
             params["response_format"] = response_format
@@ -14,8 +14,15 @@ class LlmConfig(BaseModel):
     @field_validator("config")
     def validate_config(cls, v, values):
         provider = values.data.get("provider")
-        if provider in ("openai", "ollama", "groq", "together", "aws_bedrock", "litellm", "azure_openai"):
+        if provider in (
+            "openai",
+            "ollama",
+            "groq",
+            "together",
+            "aws_bedrock",
+            "litellm",
+            "azure_openai",
+        ):
             return v
         else:
             raise ValueError(f"Unsupported LLM provider: {provider}")
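Review note: the tuple above is the complete provider whitelist after this change; anything else fails validation at config time. A self-contained sketch that mirrors (rather than imports) the validator:

```python
# Self-contained sketch mirroring the provider gate in LlmConfig.validate_config.
SUPPORTED_LLM_PROVIDERS = (
    "openai",
    "ollama",
    "groq",
    "together",
    "aws_bedrock",
    "litellm",
    "azure_openai",
)

def check_provider(provider: str) -> str:
    if provider in SUPPORTED_LLM_PROVIDERS:
        return provider
    raise ValueError(f"Unsupported LLM provider: {provider}")

check_provider("aws_bedrock")   # passes
# check_provider("anthropic")   # would raise ValueError
```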
@@ -4,7 +4,9 @@ from typing import Dict, List, Optional
 try:
     from groq import Groq
 except ImportError:
-    raise ImportError("Groq requires extra dependencies. Install with `pip install groq`") from None
+    raise ImportError(
+        "Groq requires extra dependencies. Install with `pip install groq`"
+    ) from None
 
 from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig
@@ -15,7 +17,7 @@ class GroqLLM(LLMBase):
         super().__init__(config)
 
         if not self.config.model:
-            self.config.model="llama3-70b-8192"
+            self.config.model = "llama3-70b-8192"
         self.client = Groq()
 
     def _parse_response(self, response, tools):
@@ -32,16 +34,18 @@ class GroqLLM(LLMBase):
         if tools:
             processed_response = {
                 "content": response.choices[0].message.content,
-                "tool_calls": []
+                "tool_calls": [],
             }
 
             if response.choices[0].message.tool_calls:
                 for tool_call in response.choices[0].message.tool_calls:
-                    processed_response["tool_calls"].append({
-                        "name": tool_call.function.name,
-                        "arguments": json.loads(tool_call.function.arguments)
-                    })
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call.function.name,
+                            "arguments": json.loads(tool_call.function.arguments),
+                        }
+                    )
 
             return processed_response
         else:
             return response.choices[0].message.content
@@ -66,11 +70,11 @@ class GroqLLM(LLMBase):
             str: The generated response.
         """
         params = {
-            "model": self.config.model,
-            "messages": messages,
-            "temperature": self.config.temperature,
-            "max_tokens": self.config.max_tokens,
-            "top_p": self.config.top_p
+            "model": self.config.model,
+            "messages": messages,
+            "temperature": self.config.temperature,
+            "max_tokens": self.config.max_tokens,
+            "top_p": self.config.top_p,
        }
         if response_format:
             params["response_format"] = response_format
@@ -1,7 +1,12 @@
 import json
 from typing import Dict, List, Optional
 
-import litellm
+try:
+    import litellm
+except ImportError:
+    raise ImportError(
+        "litellm requires extra dependencies. Install with `pip install litellm`"
+    ) from None
 
 from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig
@@ -12,8 +17,8 @@ class LiteLLM(LLMBase):
         super().__init__(config)
 
         if not self.config.model:
-            self.config.model="gpt-4o"
+            self.config.model = "gpt-4o"
 
     def _parse_response(self, response, tools):
         """
         Process the response based on whether tools are used or not.
@@ -28,16 +33,18 @@ class LiteLLM(LLMBase):
         if tools:
             processed_response = {
                 "content": response.choices[0].message.content,
-                "tool_calls": []
+                "tool_calls": [],
             }
 
             if response.choices[0].message.tool_calls:
                 for tool_call in response.choices[0].message.tool_calls:
-                    processed_response["tool_calls"].append({
-                        "name": tool_call.function.name,
-                        "arguments": json.loads(tool_call.function.arguments)
-                    })
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call.function.name,
+                            "arguments": json.loads(tool_call.function.arguments),
+                        }
+                    )
 
             return processed_response
         else:
             return response.choices[0].message.content
@@ -62,14 +69,16 @@ class LiteLLM(LLMBase):
             str: The generated response.
         """
         if not litellm.supports_function_calling(self.config.model):
-            raise ValueError(f"Model '{self.config.model}' in litellm does not support function calling.")
+            raise ValueError(
+                f"Model '{self.config.model}' in litellm does not support function calling."
+            )
 
         params = {
-            "model": self.config.model,
-            "messages": messages,
-            "temperature": self.config.temperature,
-            "max_tokens": self.config.max_tokens,
-            "top_p": self.config.top_p
+            "model": self.config.model,
+            "messages": messages,
+            "temperature": self.config.temperature,
+            "max_tokens": self.config.max_tokens,
+            "top_p": self.config.top_p,
         }
         if response_format:
             params["response_format"] = response_format
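Review note: only the error message formatting changed here; the gate itself is litellm's own `supports_function_calling`. A minimal standalone sketch (the model name is just an example):

```python
# Minimal sketch of the capability guard used by LiteLLM.generate_response.
import litellm

model = "gpt-4o"  # example model name
if not litellm.supports_function_calling(model):
    raise ValueError(f"Model '{model}' in litellm does not support function calling.")
```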
@@ -3,28 +3,31 @@ from typing import Dict, List, Optional
 try:
     from ollama import Client
 except ImportError:
-    raise ImportError("Ollama requires extra dependencies. Install with `pip install ollama`") from None
+    raise ImportError(
+        "Ollama requires extra dependencies. Install with `pip install ollama`"
+    ) from None
 
 from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig
 
 
 class OllamaLLM(LLMBase):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config)
 
         if not self.config.model:
-            self.config.model="llama3.1:70b"
+            self.config.model = "llama3.1:70b"
         self.client = Client(host=self.config.ollama_base_url)
         self._ensure_model_exists()
 
     def _ensure_model_exists(self):
         """
         Ensure the specified model exists locally. If not, pull it from Ollama.
-        """
+        """
         local_models = self.client.list()["models"]
         if not any(model.get("name") == self.config.model for model in local_models):
             self.client.pull(self.config.model)
 
     def _parse_response(self, response, tools):
         """
         Process the response based on whether tools are used or not.
@@ -38,20 +41,22 @@ class OllamaLLM(LLMBase):
         """
         if tools:
             processed_response = {
-                "content": response['message']['content'],
-                "tool_calls": []
+                "content": response["message"]["content"],
+                "tool_calls": [],
             }
 
-            if response['message'].get('tool_calls'):
-                for tool_call in response['message']['tool_calls']:
-                    processed_response["tool_calls"].append({
-                        "name": tool_call["function"]["name"],
-                        "arguments": tool_call["function"]["arguments"]
-                    })
+            if response["message"].get("tool_calls"):
+                for tool_call in response["message"]["tool_calls"]:
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call["function"]["name"],
+                            "arguments": tool_call["function"]["arguments"],
+                        }
+                    )
 
             return processed_response
         else:
-            return response['message']['content']
+            return response["message"]["content"]
 
     def generate_response(
         self,
@@ -73,13 +78,13 @@ class OllamaLLM(LLMBase):
             str: The generated response.
         """
         params = {
-            "model": self.config.model,
-            "messages": messages,
+            "model": self.config.model,
+            "messages": messages,
             "options": {
-                "temperature": self.config.temperature,
-                "num_predict": self.config.max_tokens,
-                "top_p": self.config.top_p
-            }
+                "temperature": self.config.temperature,
+                "num_predict": self.config.max_tokens,
+                "top_p": self.config.top_p,
+            },
         }
         if response_format:
             params["format"] = response_format
@@ -87,4 +92,4 @@ class OllamaLLM(LLMBase):
             params["tools"] = tools
 
         response = self.client.chat(**params)
-        return self._parse_response(response, tools)
+        return self._parse_response(response, tools)
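Review note: `_ensure_model_exists` keeps its pull-if-missing behaviour; only the quoting changed. A standalone sketch of that pattern against the ollama client (the host value is an assumption, not taken from the commit):

```python
# Standalone sketch of the pull-if-missing check from OllamaLLM._ensure_model_exists.
from ollama import Client

client = Client(host="http://localhost:11434")  # assumed local default
model = "llama3.1:70b"

local_models = client.list()["models"]
if not any(m.get("name") == model for m in local_models):
    # Download the model once so later chat() calls do not fail.
    client.pull(model)
```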
@@ -7,19 +7,23 @@ from openai import OpenAI
 from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig
 
 
 class OpenAILLM(LLMBase):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config)
 
         if not self.config.model:
-            self.config.model="gpt-4o"
+            self.config.model = "gpt-4o"
 
-        if os.environ.get("OPENROUTER_API_KEY"): # Use OpenRouter
-            self.client = OpenAI(api_key=os.environ.get("OPENROUTER_API_KEY"), base_url=self.config.openrouter_base_url)
+        if os.environ.get("OPENROUTER_API_KEY"):  # Use OpenRouter
+            self.client = OpenAI(
+                api_key=os.environ.get("OPENROUTER_API_KEY"),
+                base_url=self.config.openrouter_base_url,
+            )
         else:
             api_key = os.getenv("OPENAI_API_KEY") or self.config.api_key
             self.client = OpenAI(api_key=api_key)
 
     def _parse_response(self, response, tools):
         """
         Process the response based on whether tools are used or not.
@@ -34,16 +38,18 @@ class OpenAILLM(LLMBase):
         if tools:
             processed_response = {
                 "content": response.choices[0].message.content,
-                "tool_calls": []
+                "tool_calls": [],
             }
 
             if response.choices[0].message.tool_calls:
                 for tool_call in response.choices[0].message.tool_calls:
-                    processed_response["tool_calls"].append({
-                        "name": tool_call.function.name,
-                        "arguments": json.loads(tool_call.function.arguments)
-                    })
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call.function.name,
+                            "arguments": json.loads(tool_call.function.arguments),
+                        }
+                    )
 
             return processed_response
         else:
             return response.choices[0].message.content
@@ -68,11 +74,11 @@ class OpenAILLM(LLMBase):
             str: The generated response.
         """
         params = {
-            "model": self.config.model,
-            "messages": messages,
-            "temperature": self.config.temperature,
-            "max_tokens": self.config.max_tokens,
-            "top_p": self.config.top_p
+            "model": self.config.model,
+            "messages": messages,
+            "temperature": self.config.temperature,
+            "max_tokens": self.config.max_tokens,
+            "top_p": self.config.top_p,
         }
 
         if os.getenv("OPENROUTER_API_KEY"):
@@ -81,14 +87,14 @@ class OpenAILLM(LLMBase):
             openrouter_params["models"] = self.config.models
             openrouter_params["route"] = self.config.route
             params.pop("model")
 
             if self.config.site_url and self.config.app_name:
-                extra_headers={
-                    "HTTP-Referer": self.config.site_url,
-                    "X-Title": self.config.app_name
-                }
-                openrouter_params["extra_headers"] = extra_headers
+                extra_headers = {
+                    "HTTP-Referer": self.config.site_url,
+                    "X-Title": self.config.app_name,
+                }
+                openrouter_params["extra_headers"] = extra_headers
 
             params.update(**openrouter_params)
 
         if response_format:
@@ -4,19 +4,22 @@ from typing import Dict, List, Optional
 try:
     from together import Together
 except ImportError:
-    raise ImportError("Together requires extra dependencies. Install with `pip install together`") from None
+    raise ImportError(
+        "Together requires extra dependencies. Install with `pip install together`"
+    ) from None
 
 from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig
 
 
 class TogetherLLM(LLMBase):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config)
 
         if not self.config.model:
-            self.config.model="mistralai/Mixtral-8x7B-Instruct-v0.1"
+            self.config.model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
         self.client = Together()
 
     def _parse_response(self, response, tools):
         """
         Process the response based on whether tools are used or not.
@@ -31,16 +34,18 @@ class TogetherLLM(LLMBase):
         if tools:
             processed_response = {
                 "content": response.choices[0].message.content,
-                "tool_calls": []
+                "tool_calls": [],
             }
 
             if response.choices[0].message.tool_calls:
                 for tool_call in response.choices[0].message.tool_calls:
-                    processed_response["tool_calls"].append({
-                        "name": tool_call.function.name,
-                        "arguments": json.loads(tool_call.function.arguments)
-                    })
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call.function.name,
+                            "arguments": json.loads(tool_call.function.arguments),
+                        }
+                    )
 
             return processed_response
         else:
             return response.choices[0].message.content
@@ -65,11 +70,11 @@ class TogetherLLM(LLMBase):
             str: The generated response.
         """
         params = {
-            "model": self.config.model,
-            "messages": messages,
-            "temperature": self.config.temperature,
-            "max_tokens": self.config.max_tokens,
-            "top_p": self.config.top_p
+            "model": self.config.model,
+            "messages": messages,
+            "temperature": self.config.temperature,
+            "max_tokens": self.config.max_tokens,
+            "top_p": self.config.top_p,
         }
         if response_format:
             params["response_format"] = response_format
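Review note: with the provider SDKs now optional, a typical setup installs mem0 plus only the SDK for the provider named in the config. A hedged end-to-end sketch (provider and model are examples; assumes mem0's `Memory.from_config` API):

```python
# Hedged usage sketch: install mem0 plus one provider SDK, e.g. `pip install groq`.
from mem0 import Memory

config = {
    "llm": {
        "provider": "groq",              # must be in the validator's whitelist above
        "config": {
            "model": "llama3-70b-8192",  # GroqLLM's default when no model is set
            "temperature": 0.1,
            "max_tokens": 1000,
        },
    }
}

m = Memory.from_config(config)
m.add("I prefer vegetarian food", user_id="alice")
```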