Fix not working with Gemini models (#2021)

This commit is contained in:
Hieu Lam
2025-01-09 18:49:26 +07:00
committed by GitHub
parent c90f87e657
commit 4c31c65649
4 changed files with 63 additions and 17 deletions

View File

@@ -21,6 +21,7 @@ class LlmConfig(BaseModel):
"azure_openai", "azure_openai",
"openai_structured", "openai_structured",
"azure_openai_structured", "azure_openai_structured",
"gemini",
): ):
return v return v
else: else:

View File

@@ -1,9 +1,10 @@
import os import os
import json
from typing import Dict, List, Optional from typing import Dict, List, Optional
try: try:
import google.generativeai as genai import google.generativeai as genai
from google.generativeai import GenerativeModel from google.generativeai import GenerativeModel, protos
from google.generativeai.types import content_types from google.generativeai.types import content_types
except ImportError: except ImportError:
raise ImportError( raise ImportError(
@@ -38,17 +39,24 @@ class GeminiLLM(LLMBase):
""" """
if tools: if tools:
processed_response = { processed_response = {
"content": content if (content := response.candidates[0].content.parts[0].text) else None, "content": (
content
if (content := response.candidates[0].content.parts[0].text)
else None
),
"tool_calls": [], "tool_calls": [],
} }
for part in response.candidates[0].content.parts: for part in response.candidates[0].content.parts:
if fn := part.function_call: if fn := part.function_call:
if isinstance(fn, protos.FunctionCall):
fn_call = type(fn).to_dict(fn)
processed_response["tool_calls"].append( processed_response["tool_calls"].append(
{ {"name": fn_call["name"], "arguments": fn_call["args"]}
"name": fn.name, )
"arguments": {key: val for key, val in fn.args.items()}, continue
} processed_response["tool_calls"].append(
{"name": fn.name, "arguments": fn.args}
) )
return processed_response return processed_response
@@ -69,12 +77,19 @@ class GeminiLLM(LLMBase):
for message in messages: for message in messages:
if message["role"] == "system": if message["role"] == "system":
content = "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"] content = (
"THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"]
)
else: else:
content = message["content"] content = message["content"]
new_messages.append({"parts": content, "role": "model" if message["role"] == "model" else "user"}) new_messages.append(
{
"parts": content,
"role": "model" if message["role"] == "model" else "user",
}
)
return new_messages return new_messages
@@ -106,7 +121,12 @@ class GeminiLLM(LLMBase):
if tools: if tools:
for tool in tools: for tool in tools:
func = tool["function"].copy() func = tool["function"].copy()
new_tools.append({"function_declarations": [remove_additional_properties(func)]}) new_tools.append(
{"function_declarations": [remove_additional_properties(func)]}
)
# TODO: temporarily ignore it to pass tests, will come back to update according to standards later.
# return content_types.to_function_library(new_tools)
return new_tools return new_tools
else: else:
@@ -138,17 +158,20 @@ class GeminiLLM(LLMBase):
"top_p": self.config.top_p, "top_p": self.config.top_p,
} }
if response_format: if response_format is not None and response_format["type"] == "json_object":
params["response_mime_type"] = "application/json" params["response_mime_type"] = "application/json"
params["response_schema"] = list[response_format] if "schema" in response_format:
params["response_schema"] = response_format["schema"]
if tool_choice: if tool_choice:
tool_config = content_types.to_tool_config( tool_config = content_types.to_tool_config(
{ {
"function_calling_config": { "function_calling_config": {
"mode": tool_choice, "mode": tool_choice,
"allowed_function_names": [tool["function"]["name"] for tool in tools] "allowed_function_names": (
[tool["function"]["name"] for tool in tools]
if tool_choice == "any" if tool_choice == "any"
else None, else None
),
} }
} }
) )

View File

@@ -16,7 +16,11 @@ from mem0.memory.base import MemoryBase
from mem0.memory.setup import setup_config from mem0.memory.setup import setup_config
from mem0.memory.storage import SQLiteManager from mem0.memory.storage import SQLiteManager
from mem0.memory.telemetry import capture_event from mem0.memory.telemetry import capture_event
from mem0.memory.utils import get_fact_retrieval_messages, parse_messages from mem0.memory.utils import (
get_fact_retrieval_messages,
parse_messages,
remove_code_blocks,
)
from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
# Setup user config # Setup user config
@@ -152,6 +156,7 @@ class Memory(MemoryBase):
) )
try: try:
response = remove_code_blocks(response)
new_retrieved_facts = json.loads(response)["facts"] new_retrieved_facts = json.loads(response)["facts"]
except Exception as e: except Exception as e:
logging.error(f"Error in new_retrieved_facts: {e}") logging.error(f"Error in new_retrieved_facts: {e}")
@@ -184,6 +189,8 @@ class Memory(MemoryBase):
messages=[{"role": "user", "content": function_calling_prompt}], messages=[{"role": "user", "content": function_calling_prompt}],
response_format={"type": "json_object"}, response_format={"type": "json_object"},
) )
new_memories_with_actions = remove_code_blocks(new_memories_with_actions)
new_memories_with_actions = json.loads(new_memories_with_actions) new_memories_with_actions = json.loads(new_memories_with_actions)
returned_memories = [] returned_memories = []

View File

@@ -1,3 +1,4 @@
import re
import json import json
from mem0.configs.prompts import FACT_RETRIEVAL_PROMPT from mem0.configs.prompts import FACT_RETRIEVAL_PROMPT
@@ -28,3 +29,17 @@ def format_entities(entities):
formatted_lines.append(simplified) formatted_lines.append(simplified)
return "\n".join(formatted_lines) return "\n".join(formatted_lines)
def remove_code_blocks(content: str) -> str:
    """
    Strip a surrounding Markdown fenced code block from *content*, if present.

    The input is first trimmed of leading/trailing whitespace. When the whole
    trimmed string is wrapped in a fence — an opening ``` optionally followed
    by an alphanumeric language tag, a newline, the inner text, and a closing
    ``` on the final line — only the inner text (also trimmed) is returned.
    When no such fence wraps the input, the trimmed input is returned as-is.
    """
    trimmed = content.strip()
    fenced = re.match(r"^```[a-zA-Z0-9]*\n([\s\S]*?)\n```$", trimmed)
    if fenced is None:
        return trimmed
    return fenced.group(1).strip()