Fix not working with Gemini models (#2021)
This commit is contained in:
@@ -21,6 +21,7 @@ class LlmConfig(BaseModel):
|
||||
"azure_openai",
|
||||
"openai_structured",
|
||||
"azure_openai_structured",
|
||||
"gemini",
|
||||
):
|
||||
return v
|
||||
else:
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import os
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
try:
|
||||
import google.generativeai as genai
|
||||
from google.generativeai import GenerativeModel
|
||||
from google.generativeai import GenerativeModel, protos
|
||||
from google.generativeai.types import content_types
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
@@ -38,17 +39,24 @@ class GeminiLLM(LLMBase):
|
||||
"""
|
||||
if tools:
|
||||
processed_response = {
|
||||
"content": content if (content := response.candidates[0].content.parts[0].text) else None,
|
||||
"content": (
|
||||
content
|
||||
if (content := response.candidates[0].content.parts[0].text)
|
||||
else None
|
||||
),
|
||||
"tool_calls": [],
|
||||
}
|
||||
|
||||
for part in response.candidates[0].content.parts:
|
||||
if fn := part.function_call:
|
||||
if isinstance(fn, protos.FunctionCall):
|
||||
fn_call = type(fn).to_dict(fn)
|
||||
processed_response["tool_calls"].append(
|
||||
{"name": fn_call["name"], "arguments": fn_call["args"]}
|
||||
)
|
||||
continue
|
||||
processed_response["tool_calls"].append(
|
||||
{
|
||||
"name": fn.name,
|
||||
"arguments": {key: val for key, val in fn.args.items()},
|
||||
}
|
||||
{"name": fn.name, "arguments": fn.args}
|
||||
)
|
||||
|
||||
return processed_response
|
||||
@@ -69,12 +77,19 @@ class GeminiLLM(LLMBase):
|
||||
|
||||
for message in messages:
|
||||
if message["role"] == "system":
|
||||
content = "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"]
|
||||
content = (
|
||||
"THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"]
|
||||
)
|
||||
|
||||
else:
|
||||
content = message["content"]
|
||||
|
||||
new_messages.append({"parts": content, "role": "model" if message["role"] == "model" else "user"})
|
||||
new_messages.append(
|
||||
{
|
||||
"parts": content,
|
||||
"role": "model" if message["role"] == "model" else "user",
|
||||
}
|
||||
)
|
||||
|
||||
return new_messages
|
||||
|
||||
@@ -106,7 +121,12 @@ class GeminiLLM(LLMBase):
|
||||
if tools:
|
||||
for tool in tools:
|
||||
func = tool["function"].copy()
|
||||
new_tools.append({"function_declarations": [remove_additional_properties(func)]})
|
||||
new_tools.append(
|
||||
{"function_declarations": [remove_additional_properties(func)]}
|
||||
)
|
||||
|
||||
# TODO: temporarily ignore it to pass tests, will come back to update according to standards later.
|
||||
# return content_types.to_function_library(new_tools)
|
||||
|
||||
return new_tools
|
||||
else:
|
||||
@@ -138,17 +158,20 @@ class GeminiLLM(LLMBase):
|
||||
"top_p": self.config.top_p,
|
||||
}
|
||||
|
||||
if response_format:
|
||||
if response_format is not None and response_format["type"] == "json_object":
|
||||
params["response_mime_type"] = "application/json"
|
||||
params["response_schema"] = list[response_format]
|
||||
if "schema" in response_format:
|
||||
params["response_schema"] = response_format["schema"]
|
||||
if tool_choice:
|
||||
tool_config = content_types.to_tool_config(
|
||||
{
|
||||
"function_calling_config": {
|
||||
"mode": tool_choice,
|
||||
"allowed_function_names": [tool["function"]["name"] for tool in tools]
|
||||
if tool_choice == "any"
|
||||
else None,
|
||||
"allowed_function_names": (
|
||||
[tool["function"]["name"] for tool in tools]
|
||||
if tool_choice == "any"
|
||||
else None
|
||||
),
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user