Support async client (#1980)
This commit is contained in:
@@ -6,7 +6,9 @@ try:
|
||||
from google.generativeai import GenerativeModel
|
||||
from google.generativeai.types import content_types
|
||||
except ImportError:
|
||||
raise ImportError("The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'.")
|
||||
raise ImportError(
|
||||
"The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'."
|
||||
)
|
||||
|
||||
from mem0.configs.llms.base import BaseLlmConfig
|
||||
from mem0.llms.base import LLMBase
|
||||
@@ -44,8 +46,8 @@ class GeminiLLM(LLMBase):
|
||||
if fn := part.function_call:
|
||||
processed_response["tool_calls"].append(
|
||||
{
|
||||
"name": fn.name,
|
||||
"arguments": {key:val for key, val in fn.args.items()},
|
||||
"name": fn.name,
|
||||
"arguments": {key: val for key, val in fn.args.items()},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -53,7 +55,7 @@ class GeminiLLM(LLMBase):
|
||||
else:
|
||||
return response.candidates[0].content.parts[0].text
|
||||
|
||||
def _reformat_messages(self, messages : List[Dict[str, str]]):
|
||||
def _reformat_messages(self, messages: List[Dict[str, str]]):
|
||||
"""
|
||||
Reformat messages for Gemini.
|
||||
|
||||
@@ -71,9 +73,8 @@ class GeminiLLM(LLMBase):
|
||||
|
||||
else:
|
||||
content = message["content"]
|
||||
|
||||
new_messages.append({"parts": content,
|
||||
"role": "model" if message["role"] == "model" else "user"})
|
||||
|
||||
new_messages.append({"parts": content, "role": "model" if message["role"] == "model" else "user"})
|
||||
|
||||
return new_messages
|
||||
|
||||
@@ -89,24 +90,24 @@ class GeminiLLM(LLMBase):
|
||||
"""
|
||||
|
||||
def remove_additional_properties(data):
|
||||
"""Recursively removes 'additionalProperties' from nested dictionaries."""
|
||||
|
||||
if isinstance(data, dict):
|
||||
filtered_dict = {
|
||||
key: remove_additional_properties(value)
|
||||
for key, value in data.items()
|
||||
if not (key == "additionalProperties")
|
||||
}
|
||||
return filtered_dict
|
||||
else:
|
||||
return data
|
||||
|
||||
"""Recursively removes 'additionalProperties' from nested dictionaries."""
|
||||
|
||||
if isinstance(data, dict):
|
||||
filtered_dict = {
|
||||
key: remove_additional_properties(value)
|
||||
for key, value in data.items()
|
||||
if not (key == "additionalProperties")
|
||||
}
|
||||
return filtered_dict
|
||||
else:
|
||||
return data
|
||||
|
||||
new_tools = []
|
||||
if tools:
|
||||
for tool in tools:
|
||||
func = tool['function'].copy()
|
||||
new_tools.append({"function_declarations":[remove_additional_properties(func)]})
|
||||
|
||||
func = tool["function"].copy()
|
||||
new_tools.append({"function_declarations": [remove_additional_properties(func)]})
|
||||
|
||||
return new_tools
|
||||
else:
|
||||
return None
|
||||
@@ -142,13 +143,21 @@ class GeminiLLM(LLMBase):
|
||||
params["response_schema"] = list[response_format]
|
||||
if tool_choice:
|
||||
tool_config = content_types.to_tool_config(
|
||||
{"function_calling_config":
|
||||
{"mode": tool_choice, "allowed_function_names": [tool['function']['name'] for tool in tools] if tool_choice == "any" else None}
|
||||
})
|
||||
{
|
||||
"function_calling_config": {
|
||||
"mode": tool_choice,
|
||||
"allowed_function_names": [tool["function"]["name"] for tool in tools]
|
||||
if tool_choice == "any"
|
||||
else None,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
response = self.client.generate_content(contents = self._reformat_messages(messages),
|
||||
tools = self._reformat_tools(tools),
|
||||
generation_config = genai.GenerationConfig(**params),
|
||||
tool_config = tool_config)
|
||||
response = self.client.generate_content(
|
||||
contents=self._reformat_messages(messages),
|
||||
tools=self._reformat_tools(tools),
|
||||
generation_config=genai.GenerationConfig(**params),
|
||||
tool_config=tool_config,
|
||||
)
|
||||
|
||||
return self._parse_response(response, tools)
|
||||
|
||||
@@ -18,7 +18,9 @@ class OpenAILLM(LLMBase):
|
||||
if os.environ.get("OPENROUTER_API_KEY"): # Use OpenRouter
|
||||
self.client = OpenAI(
|
||||
api_key=os.environ.get("OPENROUTER_API_KEY"),
|
||||
base_url=self.config.openrouter_base_url or os.getenv("OPENROUTER_API_BASE") or "https://openrouter.ai/api/v1",
|
||||
base_url=self.config.openrouter_base_url
|
||||
or os.getenv("OPENROUTER_API_BASE")
|
||||
or "https://openrouter.ai/api/v1",
|
||||
)
|
||||
else:
|
||||
api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
|
||||
|
||||
Reference in New Issue
Block a user