[Misc] Lint code and fix code smells (#1871)

This commit is contained in:
Deshraj Yadav
2024-09-16 17:39:54 -07:00
committed by GitHub
parent 0a78cb9f7a
commit 55c54beeab
57 changed files with 1178 additions and 1357 deletions

View File

@@ -4,7 +4,7 @@ from typing import Dict, List, Optional
try:
import anthropic
except ImportError:
raise ImportError("The 'anthropic' library is required. Please install it using 'pip install anthropic'.")
raise ImportError("The 'anthropic' library is required. Please install it using 'pip install anthropic'.")
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
@@ -43,8 +43,8 @@ class AnthropicLLM(LLMBase):
system_message = ""
filtered_messages = []
for message in messages:
if message['role'] == 'system':
system_message = message['content']
if message["role"] == "system":
system_message = message["content"]
else:
filtered_messages.append(message)
@@ -56,7 +56,7 @@ class AnthropicLLM(LLMBase):
"max_tokens": self.config.max_tokens,
"top_p": self.config.top_p,
}
if tools: # TODO: Remove tools if no issues found with new memory addition logic
if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools
params["tool_choice"] = tool_choice

View File

@@ -125,9 +125,7 @@ class AWSBedrockLLM(LLMBase):
},
}
input_body["textGenerationConfig"] = {
k: v
for k, v in input_body["textGenerationConfig"].items()
if v is not None
k: v for k, v in input_body["textGenerationConfig"].items() if v is not None
}
return input_body
@@ -161,9 +159,7 @@ class AWSBedrockLLM(LLMBase):
}
}
for prop, details in (
function["parameters"].get("properties", {}).items()
):
for prop, details in function["parameters"].get("properties", {}).items():
new_tool["toolSpec"]["inputSchema"]["json"]["properties"][prop] = {
"type": details.get("type", "string"),
"description": details.get("description", ""),
@@ -216,9 +212,7 @@ class AWSBedrockLLM(LLMBase):
# Use invoke_model method when no tools are provided
prompt = self._format_messages(messages)
provider = self.model.split(".")[0]
input_body = self._prepare_input(
provider, self.config.model, prompt, **self.model_kwargs
)
input_body = self._prepare_input(provider, self.config.model, prompt, **self.model_kwargs)
body = json.dumps(input_body)
response = self.client.invoke_model(

View File

@@ -15,20 +15,20 @@ class AzureOpenAILLM(LLMBase):
# Model name should match the custom deployment name chosen for it.
if not self.config.model:
self.config.model = "gpt-4o"
api_key = self.config.azure_kwargs.api_key or os.getenv("LLM_AZURE_OPENAI_API_KEY")
azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv("LLM_AZURE_DEPLOYMENT")
azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("LLM_AZURE_ENDPOINT")
api_version = self.config.azure_kwargs.api_version or os.getenv("LLM_AZURE_API_VERSION")
self.client = AzureOpenAI(
azure_deployment=azure_deployment,
azure_deployment=azure_deployment,
azure_endpoint=azure_endpoint,
api_version=api_version,
api_key=api_key,
http_client=self.config.http_client
)
http_client=self.config.http_client,
)
def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.
@@ -87,7 +87,7 @@ class AzureOpenAILLM(LLMBase):
}
if response_format:
params["response_format"] = response_format
if tools: # TODO: Remove tools if no issues found with new memory addition logic
if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools
params["tool_choice"] = tool_choice

View File

@@ -1,11 +1,11 @@
import os
import json
import os
from typing import Dict, List, Optional
from openai import AzureOpenAI
from mem0.llms.base import LLMBase
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
class AzureOpenAIStructuredLLM(LLMBase):
@@ -15,21 +15,21 @@ class AzureOpenAIStructuredLLM(LLMBase):
# Model name should match the custom deployment name chosen for it.
if not self.config.model:
self.config.model = "gpt-4o-2024-08-06"
api_key = os.getenv("LLM_AZURE_OPENAI_API_KEY") or self.config.azure_kwargs.api_key
azure_deployment = os.getenv("LLM_AZURE_DEPLOYMENT") or self.config.azure_kwargs.azure_deployment
azure_endpoint = os.getenv("LLM_AZURE_ENDPOINT") or self.config.azure_kwargs.azure_endpoint
api_version = os.getenv("LLM_AZURE_API_VERSION") or self.config.azure_kwargs.api_version
# Can display a warning if API version is of model and api-version
self.client = AzureOpenAI(
azure_deployment=azure_deployment,
azure_deployment=azure_deployment,
azure_endpoint=azure_endpoint,
api_version=api_version,
api_key=api_key,
http_client=self.config.http_client
)
http_client=self.config.http_client,
)
def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.

View File

@@ -4,12 +4,8 @@ from pydantic import BaseModel, Field, field_validator
class LlmConfig(BaseModel):
provider: str = Field(
description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai"
)
config: Optional[dict] = Field(
description="Configuration for the specific LLM", default={}
)
provider: str = Field(description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai")
config: Optional[dict] = Field(description="Configuration for the specific LLM", default={})
@field_validator("config")
def validate_config(cls, v, values):
@@ -23,7 +19,7 @@ class LlmConfig(BaseModel):
"litellm",
"azure_openai",
"openai_structured",
"azure_openai_structured"
"azure_openai_structured",
):
return v
else:

View File

@@ -67,9 +67,7 @@ class LiteLLM(LLMBase):
str: The generated response.
"""
if not litellm.supports_function_calling(self.config.model):
raise ValueError(
f"Model '{self.config.model}' in litellm does not support function calling."
)
raise ValueError(f"Model '{self.config.model}' in litellm does not support function calling.")
params = {
"model": self.config.model,
@@ -80,7 +78,7 @@ class LiteLLM(LLMBase):
}
if response_format:
params["response_format"] = response_format
if tools: # TODO: Remove tools if no issues found with new memory addition logic
if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools
params["tool_choice"] = tool_choice

View File

@@ -100,7 +100,7 @@ class OpenAILLM(LLMBase):
if response_format:
params["response_format"] = response_format
if tools: # TODO: Remove tools if no issues found with new memory addition logic
if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools
params["tool_choice"] = tool_choice

View File

@@ -1,6 +1,5 @@
import os
import json
import os
from typing import Dict, List, Optional
from openai import OpenAI
@@ -20,7 +19,6 @@ class OpenAIStructuredLLM(LLMBase):
base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE")
self.client = OpenAI(api_key=api_key, base_url=base_url)
def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.
@@ -31,8 +29,8 @@ class OpenAIStructuredLLM(LLMBase):
Returns:
str or dict: The processed response.
"""
"""
if tools:
processed_response = {
"content": response.choices[0].message.content,
@@ -52,7 +50,6 @@ class OpenAIStructuredLLM(LLMBase):
else:
return response.choices[0].message.content
def generate_response(
self,
@@ -87,4 +84,4 @@ class OpenAIStructuredLLM(LLMBase):
response = self.client.beta.chat.completions.parse(**params)
return self._parse_response(response, tools)
return self._parse_response(response, tools)

View File

@@ -20,7 +20,7 @@ class TogetherLLM(LLMBase):
api_key = self.config.api_key or os.getenv("TOGETHER_API_KEY")
self.client = Together(api_key=api_key)
def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.
@@ -79,7 +79,7 @@ class TogetherLLM(LLMBase):
}
if response_format:
params["response_format"] = response_format
if tools: # TODO: Remove tools if no issues found with new memory addition logic
if tools: # TODO: Remove tools if no issues found with new memory addition logic
params["tools"] = tools
params["tool_choice"] = tool_choice

View File

@@ -7,11 +7,9 @@ ADD_MEMORY_TOOL = {
"description": "Add a memory",
"parameters": {
"type": "object",
"properties": {
"data": {"type": "string", "description": "Data to add to memory"}
},
"properties": {"data": {"type": "string", "description": "Data to add to memory"}},
"required": ["data"],
"additionalProperties": False
"additionalProperties": False,
},
},
}
@@ -34,7 +32,7 @@ UPDATE_MEMORY_TOOL = {
},
},
"required": ["memory_id", "data"],
"additionalProperties": False
"additionalProperties": False,
},
},
}
@@ -53,7 +51,7 @@ DELETE_MEMORY_TOOL = {
}
},
"required": ["memory_id"],
"additionalProperties": False
"additionalProperties": False,
},
},
}