[Mem0] Update dependencies and make the package lighter (#1708)

commit a8ba7abb7d (parent e35786e567)
Author: Deshraj Yadav
Co-authored-by: Dev-Khant <devkhant24@gmail.com>
Date: 2024-08-14 23:28:07 -07:00, committed by GitHub

35 changed files with 634 additions and 1594 deletions

View File

@@ -5,22 +5,30 @@ from typing import Dict, List, Optional, Any
 try:
     import boto3
 except ImportError:
-    raise ImportError("AWS Bedrock requires extra dependencies. Install with `pip install boto3`") from None
+    raise ImportError(
+        "AWS Bedrock requires extra dependencies. Install with `pip install boto3`"
+    ) from None
 from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig
 class AWSBedrockLLM(LLMBase):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config)
         if not self.config.model:
-            self.config.model="anthropic.claude-3-5-sonnet-20240620-v1:0"
-        self.client = boto3.client("bedrock-runtime", region_name=os.environ.get("AWS_REGION"), aws_access_key_id=os.environ.get("AWS_ACCESS_KEY"), aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"))
+            self.config.model = "anthropic.claude-3-5-sonnet-20240620-v1:0"
+        self.client = boto3.client(
+            "bedrock-runtime",
+            region_name=os.environ.get("AWS_REGION"),
+            aws_access_key_id=os.environ.get("AWS_ACCESS_KEY"),
+            aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"),
+        )
         self.model_kwargs = {
             "temperature": self.config.temperature,
             "max_tokens_to_sample": self.config.max_tokens,
-            "top_p": self.config.top_p
+            "top_p": self.config.top_p,
         }
     def _format_messages(self, messages: List[Dict[str, str]]) -> str:
@@ -28,7 +36,7 @@ class AWSBedrockLLM(LLMBase):
         Formats a list of messages into the required prompt structure for the model.
         Args:
             messages (List[Dict[str, str]]): A list of dictionaries where each dictionary represents a message.
                 Each dictionary contains 'role' and 'content' keys.
         Returns:
@@ -36,12 +44,12 @@ class AWSBedrockLLM(LLMBase):
         """
         formatted_messages = []
         for message in messages:
-            role = message['role'].capitalize()
-            content = message['content']
+            role = message["role"].capitalize()
+            content = message["content"]
             formatted_messages.append(f"\n\n{role}: {content}")
         return "".join(formatted_messages) + "\n\nAssistant:"
     def _parse_response(self, response, tools) -> str:
         """
         Process the response based on whether tools are used or not.
@@ -54,72 +62,76 @@ class AWSBedrockLLM(LLMBase):
             str or dict: The processed response.
         """
         if tools:
-            processed_response = {
-                "tool_calls": []
-            }
+            processed_response = {"tool_calls": []}
             if response["output"]["message"]["content"]:
                 for item in response["output"]["message"]["content"]:
                     if "toolUse" in item:
-                        processed_response["tool_calls"].append({
-                            "name": item["toolUse"]["name"],
-                            "arguments": item["toolUse"]["input"]
-                        })
+                        processed_response["tool_calls"].append(
+                            {
+                                "name": item["toolUse"]["name"],
+                                "arguments": item["toolUse"]["input"],
+                            }
+                        )
             return processed_response
-        response_body = json.loads(response['body'].read().decode())
-        return response_body.get('completion', '')
+        response_body = json.loads(response["body"].read().decode())
+        return response_body.get("completion", "")
     def _prepare_input(
         self,
         provider: str,
         model: str,
         prompt: str,
         model_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Dict[str, Any]:
         """
         Prepares the input dictionary for the specified provider's model by mapping and renaming
         keys in the input based on the provider's requirements.
         Args:
             provider (str): The name of the service provider (e.g., "meta", "ai21", "mistral", "cohere", "amazon").
             model (str): The name or identifier of the model being used.
             prompt (str): The text prompt to be processed by the model.
             model_kwargs (Dict[str, Any]): Additional keyword arguments specific to the model's requirements.
         Returns:
             Dict[str, Any]: The prepared input dictionary with the correct keys and values for the specified provider.
         """
         input_body = {"prompt": prompt, **model_kwargs}
         provider_mappings = {
             "meta": {"max_tokens_to_sample": "max_gen_len"},
             "ai21": {"max_tokens_to_sample": "maxTokens", "top_p": "topP"},
             "mistral": {"max_tokens_to_sample": "max_tokens"},
             "cohere": {"max_tokens_to_sample": "max_tokens", "top_p": "p"},
         }
         if provider in provider_mappings:
             for old_key, new_key in provider_mappings[provider].items():
                 if old_key in input_body:
                     input_body[new_key] = input_body.pop(old_key)
         if provider == "cohere" and "cohere.command-r" in model:
             input_body["message"] = input_body.pop("prompt")
         if provider == "amazon":
             input_body = {
                 "inputText": prompt,
                 "textGenerationConfig": {
                     "maxTokenCount": model_kwargs.get("max_tokens_to_sample"),
                     "topP": model_kwargs.get("top_p"),
-                    "temperature": model_kwargs.get("temperature")
-                }
+                    "temperature": model_kwargs.get("temperature"),
+                },
             }
-            input_body["textGenerationConfig"] = {k: v for k, v in input_body["textGenerationConfig"].items() if v is not None}
+            input_body["textGenerationConfig"] = {
+                k: v
+                for k, v in input_body["textGenerationConfig"].items()
+                if v is not None
+            }
         return input_body
     def _convert_tool_format(self, original_tools):
@@ -133,32 +145,34 @@ class AWSBedrockLLM(LLMBase):
             list: A list of dictionaries representing the tools in the new standardized format.
         """
         new_tools = []
         for tool in original_tools:
-            if tool['type'] == 'function':
-                function = tool['function']
+            if tool["type"] == "function":
+                function = tool["function"]
                 new_tool = {
                     "toolSpec": {
-                        "name": function['name'],
-                        "description": function['description'],
+                        "name": function["name"],
+                        "description": function["description"],
                         "inputSchema": {
                             "json": {
                                 "type": "object",
                                 "properties": {},
-                                "required": function['parameters'].get('required', [])
+                                "required": function["parameters"].get("required", []),
                             }
-                        }
+                        },
                     }
                 }
-                for prop, details in function['parameters'].get('properties', {}).items():
+                for prop, details in (
+                    function["parameters"].get("properties", {}).items()
+                ):
                     new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = {
-                        "type": details.get('type', 'string'),
-                        "description": details.get('description', '')
+                        "type": details.get("type", "string"),
+                        "description": details.get("description", ""),
                     }
                 new_tools.append(new_tool)
         return new_tools
     def generate_response(
@@ -181,28 +195,39 @@ class AWSBedrockLLM(LLMBase):
         if tools:
             # Use converse method when tools are provided
-            messages = [{"role": "user", "content": [{"text": message["content"]} for message in messages]}]
-            inference_config = {"temperature": self.model_kwargs["temperature"], "maxTokens": self.model_kwargs["max_tokens_to_sample"], "topP": self.model_kwargs["top_p"]}
+            messages = [
+                {
+                    "role": "user",
+                    "content": [{"text": message["content"]} for message in messages],
+                }
+            ]
+            inference_config = {
+                "temperature": self.model_kwargs["temperature"],
+                "maxTokens": self.model_kwargs["max_tokens_to_sample"],
+                "topP": self.model_kwargs["top_p"],
+            }
             tools_config = {"tools": self._convert_tool_format(tools)}
             response = self.client.converse(
                 modelId=self.config.model,
                 messages=messages,
                 inferenceConfig=inference_config,
-                toolConfig=tools_config
+                toolConfig=tools_config,
             )
         else:
             # Use invoke_model method when no tools are provided
             prompt = self._format_messages(messages)
             provider = self.model.split(".")[0]
-            input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs)
+            input_body = self._prepare_input(
+                provider, self.config.model, prompt, **self.model_kwargs
+            )
             body = json.dumps(input_body)
             response = self.client.invoke_model(
                 body=body,
                 modelId=self.model,
-                accept='application/json',
-                contentType='application/json'
+                accept="application/json",
+                contentType="application/json",
             )
         return self._parse_response(response, tools)
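
The `_prepare_input` hunk above is the part most likely to surprise callers, so here is a minimal, self-contained sketch of the same key-mapping step, re-stated outside the class for illustration (the model id and keyword values are examples, not taken from this diff):

def prepare_input(provider, model, prompt, **model_kwargs):
    # Mirrors _prepare_input above: start from a generic body, then rename
    # keys to whatever the chosen Bedrock provider expects.
    input_body = {"prompt": prompt, **model_kwargs}
    provider_mappings = {
        "meta": {"max_tokens_to_sample": "max_gen_len"},
        "ai21": {"max_tokens_to_sample": "maxTokens", "top_p": "topP"},
        "mistral": {"max_tokens_to_sample": "max_tokens"},
        "cohere": {"max_tokens_to_sample": "max_tokens", "top_p": "p"},
    }
    for old_key, new_key in provider_mappings.get(provider, {}).items():
        if old_key in input_body:
            input_body[new_key] = input_body.pop(old_key)
    if provider == "cohere" and "cohere.command-r" in model:
        input_body["message"] = input_body.pop("prompt")
    return input_body

print(prepare_input("cohere", "cohere.command-text-v14", "Hello",
                    temperature=0.1, max_tokens_to_sample=256, top_p=0.9))
# {'prompt': 'Hello', 'temperature': 0.1, 'max_tokens': 256, 'p': 0.9}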

View File

@@ -6,13 +6,14 @@ from openai import AzureOpenAI
 from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig
 class AzureOpenAILLM(LLMBase):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config)
         # Model name should match the custom deployment name chosen for it.
         if not self.config.model:
-            self.config.model="gpt-4o"
+            self.config.model = "gpt-4o"
         self.client = AzureOpenAI()
     def _parse_response(self, response, tools):
@@ -29,21 +30,22 @@ class AzureOpenAILLM(LLMBase):
         if tools:
             processed_response = {
                 "content": response.choices[0].message.content,
-                "tool_calls": []
+                "tool_calls": [],
             }
             if response.choices[0].message.tool_calls:
                 for tool_call in response.choices[0].message.tool_calls:
-                    processed_response["tool_calls"].append({
-                        "name": tool_call.function.name,
-                        "arguments": json.loads(tool_call.function.arguments)
-                    })
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call.function.name,
+                            "arguments": json.loads(tool_call.function.arguments),
+                        }
+                    )
             return processed_response
         else:
             return response.choices[0].message.content
     def generate_response(
         self,
         messages: List[Dict[str, str]],
@@ -64,11 +66,11 @@ class AzureOpenAILLM(LLMBase):
             str: The generated response.
         """
         params = {
             "model": self.config.model,
             "messages": messages,
             "temperature": self.config.temperature,
             "max_tokens": self.config.max_tokens,
-            "top_p": self.config.top_p
+            "top_p": self.config.top_p,
         }
         if response_format:
             params["response_format"] = response_format

View File

@@ -14,8 +14,15 @@ class LlmConfig(BaseModel):
     @field_validator("config")
     def validate_config(cls, v, values):
         provider = values.data.get("provider")
-        if provider in ("openai", "ollama", "groq", "together", "aws_bedrock", "litellm", "azure_openai"):
+        if provider in (
+            "openai",
+            "ollama",
+            "groq",
+            "together",
+            "aws_bedrock",
+            "litellm",
+            "azure_openai",
+        ):
             return v
         else:
             raise ValueError(f"Unsupported LLM provider: {provider}")

View File

@@ -4,7 +4,9 @@ from typing import Dict, List, Optional
 try:
     from groq import Groq
 except ImportError:
-    raise ImportError("Groq requires extra dependencies. Install with `pip install groq`") from None
+    raise ImportError(
+        "Groq requires extra dependencies. Install with `pip install groq`"
+    ) from None
 from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig
@@ -15,7 +17,7 @@ class GroqLLM(LLMBase):
         super().__init__(config)
         if not self.config.model:
-            self.config.model="llama3-70b-8192"
+            self.config.model = "llama3-70b-8192"
         self.client = Groq()
     def _parse_response(self, response, tools):
@@ -32,16 +34,18 @@ class GroqLLM(LLMBase):
         if tools:
             processed_response = {
                 "content": response.choices[0].message.content,
-                "tool_calls": []
+                "tool_calls": [],
             }
             if response.choices[0].message.tool_calls:
                 for tool_call in response.choices[0].message.tool_calls:
-                    processed_response["tool_calls"].append({
-                        "name": tool_call.function.name,
-                        "arguments": json.loads(tool_call.function.arguments)
-                    })
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call.function.name,
+                            "arguments": json.loads(tool_call.function.arguments),
+                        }
+                    )
             return processed_response
         else:
             return response.choices[0].message.content
@@ -66,11 +70,11 @@ class GroqLLM(LLMBase):
             str: The generated response.
         """
         params = {
             "model": self.config.model,
             "messages": messages,
             "temperature": self.config.temperature,
             "max_tokens": self.config.max_tokens,
-            "top_p": self.config.top_p
+            "top_p": self.config.top_p,
        }
         if response_format:
             params["response_format"] = response_format

View File

@@ -1,7 +1,12 @@
 import json
 from typing import Dict, List, Optional
-import litellm
+try:
+    import litellm
+except ImportError:
+    raise ImportError(
+        "litellm requires extra dependencies. Install with `pip install litellm`"
+    ) from None
 from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig
@@ -12,8 +17,8 @@ class LiteLLM(LLMBase):
         super().__init__(config)
         if not self.config.model:
-            self.config.model="gpt-4o"
+            self.config.model = "gpt-4o"
     def _parse_response(self, response, tools):
         """
         Process the response based on whether tools are used or not.
@@ -28,16 +33,18 @@ class LiteLLM(LLMBase):
         if tools:
             processed_response = {
                 "content": response.choices[0].message.content,
-                "tool_calls": []
+                "tool_calls": [],
             }
             if response.choices[0].message.tool_calls:
                 for tool_call in response.choices[0].message.tool_calls:
-                    processed_response["tool_calls"].append({
-                        "name": tool_call.function.name,
-                        "arguments": json.loads(tool_call.function.arguments)
-                    })
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call.function.name,
+                            "arguments": json.loads(tool_call.function.arguments),
+                        }
+                    )
             return processed_response
         else:
             return response.choices[0].message.content
@@ -62,14 +69,16 @@ class LiteLLM(LLMBase):
             str: The generated response.
         """
         if not litellm.supports_function_calling(self.config.model):
-            raise ValueError(f"Model '{self.config.model}' in litellm does not support function calling.")
+            raise ValueError(
+                f"Model '{self.config.model}' in litellm does not support function calling."
+            )
         params = {
             "model": self.config.model,
             "messages": messages,
             "temperature": self.config.temperature,
             "max_tokens": self.config.max_tokens,
-            "top_p": self.config.top_p
+            "top_p": self.config.top_p,
         }
         if response_format:
             params["response_format"] = response_format

View File

@@ -3,28 +3,31 @@ from typing import Dict, List, Optional
 try:
     from ollama import Client
 except ImportError:
-    raise ImportError("Ollama requires extra dependencies. Install with `pip install ollama`") from None
+    raise ImportError(
+        "Ollama requires extra dependencies. Install with `pip install ollama`"
+    ) from None
 from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig
 class OllamaLLM(LLMBase):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config)
         if not self.config.model:
-            self.config.model="llama3.1:70b"
+            self.config.model = "llama3.1:70b"
         self.client = Client(host=self.config.ollama_base_url)
         self._ensure_model_exists()
     def _ensure_model_exists(self):
         """
         Ensure the specified model exists locally. If not, pull it from Ollama.
         """
         local_models = self.client.list()["models"]
         if not any(model.get("name") == self.config.model for model in local_models):
             self.client.pull(self.config.model)
     def _parse_response(self, response, tools):
         """
         Process the response based on whether tools are used or not.
@@ -38,20 +41,22 @@ class OllamaLLM(LLMBase):
         """
         if tools:
             processed_response = {
-                "content": response['message']['content'],
-                "tool_calls": []
+                "content": response["message"]["content"],
+                "tool_calls": [],
             }
-            if response['message'].get('tool_calls'):
-                for tool_call in response['message']['tool_calls']:
-                    processed_response["tool_calls"].append({
-                        "name": tool_call["function"]["name"],
-                        "arguments": tool_call["function"]["arguments"]
-                    })
+            if response["message"].get("tool_calls"):
+                for tool_call in response["message"]["tool_calls"]:
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call["function"]["name"],
+                            "arguments": tool_call["function"]["arguments"],
+                        }
+                    )
             return processed_response
         else:
-            return response['message']['content']
+            return response["message"]["content"]
     def generate_response(
         self,
@@ -73,13 +78,13 @@ class OllamaLLM(LLMBase):
             str: The generated response.
         """
         params = {
             "model": self.config.model,
             "messages": messages,
             "options": {
                 "temperature": self.config.temperature,
                 "num_predict": self.config.max_tokens,
-                "top_p": self.config.top_p
-            }
+                "top_p": self.config.top_p,
+            },
         }
         if response_format:
             params["format"] = response_format
@@ -87,4 +92,4 @@ class OllamaLLM(LLMBase):
             params["tools"] = tools
         response = self.client.chat(**params)
-        return self._parse_response(response, tools)
+        return self._parse_response(response, tools)
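
The Ollama adapter is the one place where `max_tokens` is renamed (to `num_predict`) and the sampling knobs move into an `options` dict. A direct client call equivalent to the params built above, assuming a local Ollama server and an already-pulled model (host, model, and values are illustrative):

from ollama import Client

client = Client(host="http://localhost:11434")  # default local host, assumed
response = client.chat(
    model="llama3.1:70b",
    messages=[{"role": "user", "content": "Say hi in one word."}],
    # mem0's config.max_tokens becomes num_predict here
    options={"temperature": 0.1, "num_predict": 20, "top_p": 0.9},
)
print(response["message"]["content"])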

View File

@@ -7,19 +7,23 @@ from openai import OpenAI
 from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig
 class OpenAILLM(LLMBase):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config)
         if not self.config.model:
-            self.config.model="gpt-4o"
+            self.config.model = "gpt-4o"
-        if os.environ.get("OPENROUTER_API_KEY"): # Use OpenRouter
-            self.client = OpenAI(api_key=os.environ.get("OPENROUTER_API_KEY"), base_url=self.config.openrouter_base_url)
+        if os.environ.get("OPENROUTER_API_KEY"):  # Use OpenRouter
+            self.client = OpenAI(
+                api_key=os.environ.get("OPENROUTER_API_KEY"),
+                base_url=self.config.openrouter_base_url,
+            )
         else:
             api_key = os.getenv("OPENAI_API_KEY") or self.config.api_key
             self.client = OpenAI(api_key=api_key)
     def _parse_response(self, response, tools):
         """
         Process the response based on whether tools are used or not.
@@ -34,16 +38,18 @@ class OpenAILLM(LLMBase):
         if tools:
             processed_response = {
                 "content": response.choices[0].message.content,
-                "tool_calls": []
+                "tool_calls": [],
             }
             if response.choices[0].message.tool_calls:
                 for tool_call in response.choices[0].message.tool_calls:
-                    processed_response["tool_calls"].append({
-                        "name": tool_call.function.name,
-                        "arguments": json.loads(tool_call.function.arguments)
-                    })
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call.function.name,
+                            "arguments": json.loads(tool_call.function.arguments),
+                        }
+                    )
             return processed_response
         else:
             return response.choices[0].message.content
@@ -68,11 +74,11 @@ class OpenAILLM(LLMBase):
             str: The generated response.
         """
         params = {
             "model": self.config.model,
             "messages": messages,
             "temperature": self.config.temperature,
             "max_tokens": self.config.max_tokens,
-            "top_p": self.config.top_p
+            "top_p": self.config.top_p,
         }
         if os.getenv("OPENROUTER_API_KEY"):
@@ -81,14 +87,14 @@ class OpenAILLM(LLMBase):
                 openrouter_params["models"] = self.config.models
                 openrouter_params["route"] = self.config.route
                 params.pop("model")
             if self.config.site_url and self.config.app_name:
-                extra_headers={
+                extra_headers = {
                     "HTTP-Referer": self.config.site_url,
-                    "X-Title": self.config.app_name
+                    "X-Title": self.config.app_name,
                 }
                 openrouter_params["extra_headers"] = extra_headers
             params.update(**openrouter_params)
         if response_format:
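
The OpenRouter branch above amounts to pointing the stock OpenAI client at OpenRouter's base URL and attaching optional attribution headers per request. A hedged stand-alone equivalent (model id and header values are illustrative; the base URL is OpenRouter's documented endpoint):

import os
from openai import OpenAI

client = OpenAI(
    api_key=os.environ["OPENROUTER_API_KEY"],
    base_url="https://openrouter.ai/api/v1",
)
completion = client.chat.completions.create(
    model="openai/gpt-4o",  # OpenRouter-style model id, for illustration
    messages=[{"role": "user", "content": "Hello"}],
    extra_headers={"HTTP-Referer": "https://example.com", "X-Title": "my-app"},
)
print(completion.choices[0].message.content)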

View File

@@ -4,19 +4,22 @@ from typing import Dict, List, Optional
 try:
     from together import Together
 except ImportError:
-    raise ImportError("Together requires extra dependencies. Install with `pip install together`") from None
+    raise ImportError(
+        "Together requires extra dependencies. Install with `pip install together`"
+    ) from None
 from mem0.llms.base import LLMBase
 from mem0.configs.llms.base import BaseLlmConfig
 class TogetherLLM(LLMBase):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config)
         if not self.config.model:
-            self.config.model="mistralai/Mixtral-8x7B-Instruct-v0.1"
+            self.config.model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
         self.client = Together()
     def _parse_response(self, response, tools):
         """
         Process the response based on whether tools are used or not.
@@ -31,16 +34,18 @@ class TogetherLLM(LLMBase):
         if tools:
             processed_response = {
                 "content": response.choices[0].message.content,
-                "tool_calls": []
+                "tool_calls": [],
             }
             if response.choices[0].message.tool_calls:
                 for tool_call in response.choices[0].message.tool_calls:
-                    processed_response["tool_calls"].append({
-                        "name": tool_call.function.name,
-                        "arguments": json.loads(tool_call.function.arguments)
-                    })
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": tool_call.function.name,
+                            "arguments": json.loads(tool_call.function.arguments),
+                        }
+                    )
             return processed_response
         else:
             return response.choices[0].message.content
@@ -65,11 +70,11 @@ class TogetherLLM(LLMBase):
             str: The generated response.
         """
         params = {
             "model": self.config.model,
             "messages": messages,
             "temperature": self.config.temperature,
             "max_tokens": self.config.max_tokens,
-            "top_p": self.config.top_p
+            "top_p": self.config.top_p,
         }
         if response_format:
             params["response_format"] = response_format