[Misc] Lint code and fix code smells (#1871)
This commit is contained in:
@@ -4,7 +4,7 @@ from typing import Dict, List, Optional
|
||||
try:
|
||||
import anthropic
|
||||
except ImportError:
|
||||
raise ImportError("The 'anthropic' library is required. Please install it using 'pip install anthropic'.")
|
||||
raise ImportError("The 'anthropic' library is required. Please install it using 'pip install anthropic'.")
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
@@ -43,8 +43,8 @@ class AnthropicLLM(LLMBase):
|
||||
system_message = ""
|
||||
filtered_messages = []
|
||||
for message in messages:
|
||||
if message['role'] == 'system':
|
||||
system_message = message['content']
|
||||
if message["role"] == "system":
|
||||
system_message = message["content"]
|
||||
else:
|
||||
filtered_messages.append(message)
|
||||
|
||||
@@ -56,7 +56,7 @@ class AnthropicLLM(LLMBase):
|
||||
"max_tokens": self.config.max_tokens,
|
||||
"top_p": self.config.top_p,
|
||||
}
|
||||
if tools: # TODO: Remove tools if no issues found with new memory addition logic
|
||||
if tools: # TODO: Remove tools if no issues found with new memory addition logic
|
||||
params["tools"] = tools
|
||||
params["tool_choice"] = tool_choice
|
||||
|
||||
|
||||
@@ -125,9 +125,7 @@ class AWSBedrockLLM(LLMBase):
|
||||
},
|
||||
}
|
||||
input_body["textGenerationConfig"] = {
|
||||
k: v
|
||||
for k, v in input_body["textGenerationConfig"].items()
|
||||
if v is not None
|
||||
k: v for k, v in input_body["textGenerationConfig"].items() if v is not None
|
||||
}
|
||||
|
||||
return input_body
|
||||
@@ -161,9 +159,7 @@ class AWSBedrockLLM(LLMBase):
|
||||
}
|
||||
}
|
||||
|
||||
for prop, details in (
|
||||
function["parameters"].get("properties", {}).items()
|
||||
):
|
||||
for prop, details in function["parameters"].get("properties", {}).items():
|
||||
new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = {
|
||||
"type": details.get("type", "string"),
|
||||
"description": details.get("description", ""),
|
||||
@@ -216,9 +212,7 @@ class AWSBedrockLLM(LLMBase):
|
||||
# Use invoke_model method when no tools are provided
|
||||
prompt = self._format_messages(messages)
|
||||
provider = self.model.split(".")[0]
|
||||
input_body = self._prepare_input(
|
||||
provider, self.config.model, prompt, **self.model_kwargs
|
||||
)
|
||||
input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs)
|
||||
body = json.dumps(input_body)
|
||||
|
||||
response = self.client.invoke_model(
|
||||
|
||||
@@ -15,20 +15,20 @@ class AzureOpenAILLM(LLMBase):
|
||||
# Model name should match the custom deployment name chosen for it.
|
||||
if not self.config.model:
|
||||
self.config.model = "gpt-4o"
|
||||
|
||||
|
||||
api_key = self.config.azure_kwargs.api_key or os.getenv("LLM_AZURE_OPENAI_API_KEY")
|
||||
azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv("LLM_AZURE_DEPLOYMENT")
|
||||
azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("LLM_AZURE_ENDPOINT")
|
||||
api_version = self.config.azure_kwargs.api_version or os.getenv("LLM_AZURE_API_VERSION")
|
||||
|
||||
self.client = AzureOpenAI(
|
||||
azure_deployment=azure_deployment,
|
||||
azure_deployment=azure_deployment,
|
||||
azure_endpoint=azure_endpoint,
|
||||
api_version=api_version,
|
||||
api_key=api_key,
|
||||
http_client=self.config.http_client
|
||||
)
|
||||
|
||||
http_client=self.config.http_client,
|
||||
)
|
||||
|
||||
def _parse_response(self, response, tools):
|
||||
"""
|
||||
Process the response based on whether tools are used or not.
|
||||
@@ -87,7 +87,7 @@ class AzureOpenAILLM(LLMBase):
|
||||
}
|
||||
if response_format:
|
||||
params["response_format"] = response_format
|
||||
if tools: # TODO: Remove tools if no issues found with new memory addition logic
|
||||
if tools: # TODO: Remove tools if no issues found with new memory addition logic
|
||||
params["tools"] = tools
|
||||
params["tool_choice"] = tool_choice
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from openai import AzureOpenAI
|
||||
|
||||
from mem0.llms.base import LLMBase
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
|
||||
|
||||
class AzureOpenAIStructuredLLM(LLMBase):
|
||||
@@ -15,21 +15,21 @@ class AzureOpenAIStructuredLLM(LLMBase):
|
||||
# Model name should match the custom deployment name chosen for it.
|
||||
if not self.config.model:
|
||||
self.config.model = "gpt-4o-2024-08-06"
|
||||
|
||||
|
||||
api_key = os.getenv("LLM_AZURE_OPENAI_API_KEY") or self.config.azure_kwargs.api_key
|
||||
azure_deployment = os.getenv("LLM_AZURE_DEPLOYMENT") or self.config.azure_kwargs.azure_deployment
|
||||
azure_endpoint = os.getenv("LLM_AZURE_ENDPOINT") or self.config.azure_kwargs.azure_endpoint
|
||||
api_version = os.getenv("LLM_AZURE_API_VERSION") or self.config.azure_kwargs.api_version
|
||||
# Can display a warning if API version is of model and api-version
|
||||
|
||||
|
||||
self.client = AzureOpenAI(
|
||||
azure_deployment=azure_deployment,
|
||||
azure_deployment=azure_deployment,
|
||||
azure_endpoint=azure_endpoint,
|
||||
api_version=api_version,
|
||||
api_key=api_key,
|
||||
http_client=self.config.http_client
|
||||
)
|
||||
|
||||
http_client=self.config.http_client,
|
||||
)
|
||||
|
||||
def _parse_response(self, response, tools):
|
||||
"""
|
||||
Process the response based on whether tools are used or not.
|
||||
|
||||
@@ -4,12 +4,8 @@ from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class LlmConfig(BaseModel):
|
||||
provider: str = Field(
|
||||
description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai"
|
||||
)
|
||||
config: Optional[dict] = Field(
|
||||
description="Configuration for the specific LLM", default={}
|
||||
)
|
||||
provider: str = Field(description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai")
|
||||
config: Optional[dict] = Field(description="Configuration for the specific LLM", default={})
|
||||
|
||||
@field_validator("config")
|
||||
def validate_config(cls, v, values):
|
||||
@@ -23,7 +19,7 @@ class LlmConfig(BaseModel):
|
||||
"litellm",
|
||||
"azure_openai",
|
||||
"openai_structured",
|
||||
"azure_openai_structured"
|
||||
"azure_openai_structured",
|
||||
):
|
||||
return v
|
||||
else:
|
||||
|
||||
@@ -67,9 +67,7 @@ class LiteLLM(LLMBase):
|
||||
str: The generated response.
|
||||
"""
|
||||
if not litellm.supports_function_calling(self.config.model):
|
||||
raise ValueError(
|
||||
f"Model '{self.config.model}' in litellm does not support function calling."
|
||||
)
|
||||
raise ValueError(f"Model '{self.config.model}' in litellm does not support function calling.")
|
||||
|
||||
params = {
|
||||
"model": self.config.model,
|
||||
@@ -80,7 +78,7 @@ class LiteLLM(LLMBase):
|
||||
}
|
||||
if response_format:
|
||||
params["response_format"] = response_format
|
||||
if tools: # TODO: Remove tools if no issues found with new memory addition logic
|
||||
if tools: # TODO: Remove tools if no issues found with new memory addition logic
|
||||
params["tools"] = tools
|
||||
params["tool_choice"] = tool_choice
|
||||
|
||||
|
||||
@@ -100,7 +100,7 @@ class OpenAILLM(LLMBase):
|
||||
|
||||
if response_format:
|
||||
params["response_format"] = response_format
|
||||
if tools: # TODO: Remove tools if no issues found with new memory addition logic
|
||||
if tools: # TODO: Remove tools if no issues found with new memory addition logic
|
||||
params["tools"] = tools
|
||||
params["tool_choice"] = tool_choice
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import os
|
||||
import json
|
||||
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from openai import OpenAI
|
||||
@@ -20,7 +19,6 @@ class OpenAIStructuredLLM(LLMBase):
|
||||
base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE")
|
||||
self.client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
|
||||
|
||||
def _parse_response(self, response, tools):
|
||||
"""
|
||||
Process the response based on whether tools are used or not.
|
||||
@@ -31,8 +29,8 @@ class OpenAIStructuredLLM(LLMBase):
|
||||
|
||||
Returns:
|
||||
str or dict: The processed response.
|
||||
"""
|
||||
|
||||
"""
|
||||
|
||||
if tools:
|
||||
processed_response = {
|
||||
"content": response.choices[0].message.content,
|
||||
@@ -52,7 +50,6 @@ class OpenAIStructuredLLM(LLMBase):
|
||||
|
||||
else:
|
||||
return response.choices[0].message.content
|
||||
|
||||
|
||||
def generate_response(
|
||||
self,
|
||||
@@ -87,4 +84,4 @@ class OpenAIStructuredLLM(LLMBase):
|
||||
|
||||
response = self.client.beta.chat.completions.parse(**params)
|
||||
|
||||
return self._parse_response(response, tools)
|
||||
return self._parse_response(response, tools)
|
||||
|
||||
@@ -20,7 +20,7 @@ class TogetherLLM(LLMBase):
|
||||
|
||||
api_key = self.config.api_key or os.getenv("TOGETHER_API_KEY")
|
||||
self.client = Together(api_key=api_key)
|
||||
|
||||
|
||||
def _parse_response(self, response, tools):
|
||||
"""
|
||||
Process the response based on whether tools are used or not.
|
||||
@@ -79,7 +79,7 @@ class TogetherLLM(LLMBase):
|
||||
}
|
||||
if response_format:
|
||||
params["response_format"] = response_format
|
||||
if tools: # TODO: Remove tools if no issues found with new memory addition logic
|
||||
if tools: # TODO: Remove tools if no issues found with new memory addition logic
|
||||
params["tools"] = tools
|
||||
params["tool_choice"] = tool_choice
|
||||
|
||||
|
||||
@@ -7,11 +7,9 @@ ADD_MEMORY_TOOL = {
|
||||
"description": "Add a memory",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"data": {"type": "string", "description": "Data to add to memory"}
|
||||
},
|
||||
"properties": {"data": {"type": "string", "description": "Data to add to memory"}},
|
||||
"required": ["data"],
|
||||
"additionalProperties": False
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -34,7 +32,7 @@ UPDATE_MEMORY_TOOL = {
|
||||
},
|
||||
},
|
||||
"required": ["memory_id", "data"],
|
||||
"additionalProperties": False
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -53,7 +51,7 @@ DELETE_MEMORY_TOOL = {
|
||||
}
|
||||
},
|
||||
"required": ["memory_id"],
|
||||
"additionalProperties": False
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user