Support async client (#1980)

This commit is contained in:
Dev Khant
2024-10-22 12:42:55 +05:30
committed by GitHub
parent c5d298eec8
commit fbf1d8c372
11 changed files with 213 additions and 58 deletions

View File

@@ -6,7 +6,9 @@ try:
from google.generativeai import GenerativeModel
from google.generativeai.types import content_types
except ImportError:
raise ImportError("The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'.")
raise ImportError(
"The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'."
)
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
@@ -44,8 +46,8 @@ class GeminiLLM(LLMBase):
if fn := part.function_call:
processed_response["tool_calls"].append(
{
"name": fn.name,
"arguments": {key:val for key, val in fn.args.items()},
"name": fn.name,
"arguments": {key: val for key, val in fn.args.items()},
}
)
@@ -53,7 +55,7 @@ class GeminiLLM(LLMBase):
else:
return response.candidates[0].content.parts[0].text
def _reformat_messages(self, messages : List[Dict[str, str]]):
def _reformat_messages(self, messages: List[Dict[str, str]]):
"""
Reformat messages for Gemini.
@@ -71,9 +73,8 @@ class GeminiLLM(LLMBase):
else:
content = message["content"]
new_messages.append({"parts": content,
"role": "model" if message["role"] == "model" else "user"})
new_messages.append({"parts": content, "role": "model" if message["role"] == "model" else "user"})
return new_messages
@@ -89,24 +90,24 @@ class GeminiLLM(LLMBase):
"""
def remove_additional_properties(data):
"""Recursively removes 'additionalProperties' from nested dictionaries."""
if isinstance(data, dict):
filtered_dict = {
key: remove_additional_properties(value)
for key, value in data.items()
if not (key == "additionalProperties")
}
return filtered_dict
else:
return data
"""Recursively removes 'additionalProperties' from nested dictionaries."""
if isinstance(data, dict):
filtered_dict = {
key: remove_additional_properties(value)
for key, value in data.items()
if not (key == "additionalProperties")
}
return filtered_dict
else:
return data
new_tools = []
if tools:
for tool in tools:
func = tool['function'].copy()
new_tools.append({"function_declarations":[remove_additional_properties(func)]})
func = tool["function"].copy()
new_tools.append({"function_declarations": [remove_additional_properties(func)]})
return new_tools
else:
return None
@@ -142,13 +143,21 @@ class GeminiLLM(LLMBase):
params["response_schema"] = list[response_format]
if tool_choice:
tool_config = content_types.to_tool_config(
{"function_calling_config":
{"mode": tool_choice, "allowed_function_names": [tool['function']['name'] for tool in tools] if tool_choice == "any" else None}
})
{
"function_calling_config": {
"mode": tool_choice,
"allowed_function_names": [tool["function"]["name"] for tool in tools]
if tool_choice == "any"
else None,
}
}
)
response = self.client.generate_content(contents = self._reformat_messages(messages),
tools = self._reformat_tools(tools),
generation_config = genai.GenerationConfig(**params),
tool_config = tool_config)
response = self.client.generate_content(
contents=self._reformat_messages(messages),
tools=self._reformat_tools(tools),
generation_config=genai.GenerationConfig(**params),
tool_config=tool_config,
)
return self._parse_response(response, tools)

View File

@@ -18,7 +18,9 @@ class OpenAILLM(LLMBase):
if os.environ.get("OPENROUTER_API_KEY"): # Use OpenRouter
self.client = OpenAI(
api_key=os.environ.get("OPENROUTER_API_KEY"),
base_url=self.config.openrouter_base_url or os.getenv("OPENROUTER_API_BASE") or "https://openrouter.ai/api/v1",
base_url=self.config.openrouter_base_url
or os.getenv("OPENROUTER_API_BASE")
or "https://openrouter.ai/api/v1",
)
else:
api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")