Fix: Gemini Embeddings and LLM (#3050)

This commit is contained in:
Dev Khant
2025-06-26 21:05:00 +05:30
committed by GitHub
parent acf7a30d32
commit e3e2da6d45
2 changed files with 99 additions and 74 deletions

View File

@@ -1,7 +1,8 @@
import os import os
from typing import Literal, Optional from typing import Literal, Optional
import google.genai as genai from google import genai
from google.genai import types
from mem0.configs.embeddings.base import BaseEmbedderConfig from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase from mem0.embeddings.base import EmbeddingBase
@@ -16,24 +17,23 @@ class GoogleGenAIEmbedding(EmbeddingBase):
api_key = self.config.api_key or os.getenv("GOOGLE_API_KEY") api_key = self.config.api_key or os.getenv("GOOGLE_API_KEY")
if api_key: self.client = genai.Client(api_key=api_key)
self.client = genai.Client(api_key="api_key")
else:
self.client = genai.Client()
def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None): def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
""" """
Get the embedding for the given text using Google Generative AI. Get the embedding for the given text using Google Generative AI.
Args: Args:
text (str): The text to embed. text (str): The text to embed.
memory_action (optional): The type of embedding to use. (Currently not used by Gemini for task_type) memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
Returns: Returns:
list: The embedding vector. list: The embedding vector.
""" """
text = text.replace("\n", " ") text = text.replace("\n", " ")
response = self.client.models.embed_content( # Create config for embedding parameters
model=self.config.model, content=text, output_dimensionality=self.config.embedding_dims config = types.EmbedContentConfig(output_dimensionality=self.config.embedding_dims)
)
return response["embedding"] # Call the embed_content method with the correct parameters
response = self.client.models.embed_content(model=self.config.model, contents=text, config=config)
return response.embeddings[0].values

View File

@@ -4,11 +4,8 @@ from typing import Dict, List, Optional
try: try:
from google import genai from google import genai
from google.genai import types from google.genai import types
except ImportError: except ImportError:
raise ImportError( raise ImportError("The 'google-genai' library is required. Please install it using 'pip install google-genai'.")
"The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'."
)
from mem0.configs.llms.base import BaseLlmConfig from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase from mem0.llms.base import LLMBase
@@ -19,70 +16,79 @@ class GeminiLLM(LLMBase):
super().__init__(config) super().__init__(config)
if not self.config.model: if not self.config.model:
self.config.model = "gemini-1.5-flash-latest" self.config.model = "gemini-2.0-flash"
api_key = self.config.api_key or os.getenv("GEMINI_API_KEY") api_key = self.config.api_key or os.getenv("GOOGLE_API_KEY")
self.client_gemini = genai.Client( self.client = genai.Client(api_key=api_key)
api_key=api_key,
)
def _parse_response(self, response, tools): def _parse_response(self, response, tools):
""" """
Process the response based on whether tools are used or not. Process the response based on whether tools are used or not.
Args: Args:
response: The raw response from the API. response: The raw response from API.
tools: The list of tools provided in the request. tools: The list of tools provided in the request.
Returns: Returns:
str or dict: The processed response. str or dict: The processed response.
""" """
candidate = response.candidates[0]
content = candidate.content.parts[0].text if candidate.content.parts else None
if tools: if tools:
processed_response = { processed_response = {
"content": content, "content": None,
"tool_calls": [], "tool_calls": [],
} }
for part in candidate.content.parts: # Extract content from the first candidate
fn = getattr(part, "function_call", None) if response.candidates and response.candidates[0].content.parts:
if fn: for part in response.candidates[0].content.parts:
processed_response["tool_calls"].append( if hasattr(part, "text") and part.text:
{ processed_response["content"] = part.text
"name": fn.name, break
"arguments": fn.args,
} # Extract function calls
) if response.candidates and response.candidates[0].content.parts:
for part in response.candidates[0].content.parts:
if hasattr(part, "function_call") and part.function_call:
fn = part.function_call
processed_response["tool_calls"].append(
{
"name": fn.name,
"arguments": dict(fn.args) if fn.args else {},
}
)
return processed_response return processed_response
else:
if response.candidates and response.candidates[0].content.parts:
for part in response.candidates[0].content.parts:
if hasattr(part, "text") and part.text:
return part.text
return ""
return content def _reformat_messages(self, messages: List[Dict[str, str]]):
def _reformat_messages(self, messages: List[Dict[str, str]]) -> List[types.Content]:
""" """
Reformat messages for Gemini using google.genai.types. Reformat messages for Gemini.
Args: Args:
messages: The list of messages provided in the request. messages: The list of messages provided in the request.
Returns: Returns:
list: A list of types.Content objects with proper role and parts. tuple: (system_instruction, contents_list)
""" """
new_messages = [] system_instruction = None
contents = []
for message in messages: for message in messages:
if message["role"] == "system": if message["role"] == "system":
content = "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"] system_instruction = message["content"]
else: else:
content = message["content"] content = types.Content(
parts=[types.Part(text=message["content"])],
role=message["role"],
)
contents.append(content)
new_messages.append( return system_instruction, contents
types.Content(role="model" if message["role"] == "model" else "user", parts=[types.Part(text=content)])
)
return new_messages
def _reformat_tools(self, tools: Optional[List[Dict]]): def _reformat_tools(self, tools: Optional[List[Dict]]):
""" """
@@ -97,7 +103,6 @@ class GeminiLLM(LLMBase):
def remove_additional_properties(data): def remove_additional_properties(data):
"""Recursively removes 'additionalProperties' from nested dictionaries.""" """Recursively removes 'additionalProperties' from nested dictionaries."""
if isinstance(data, dict): if isinstance(data, dict):
filtered_dict = { filtered_dict = {
key: remove_additional_properties(value) key: remove_additional_properties(value)
@@ -108,16 +113,21 @@ class GeminiLLM(LLMBase):
else: else:
return data return data
new_tools = []
if tools: if tools:
function_declarations = []
for tool in tools: for tool in tools:
func = tool["function"].copy() func = tool["function"].copy()
new_tools.append({"function_declarations": [remove_additional_properties(func)]}) cleaned_func = remove_additional_properties(func)
# TODO: temporarily ignore it to pass tests, will come back to update according to standards later. function_declaration = types.FunctionDeclaration(
# return content_types.to_function_library(new_tools) name=cleaned_func["name"],
description=cleaned_func.get("description", ""),
parameters=cleaned_func.get("parameters", {}),
)
function_declarations.append(function_declaration)
return new_tools tool_obj = types.Tool(function_declarations=function_declarations)
return [tool_obj]
else: else:
return None return None
@@ -141,38 +151,53 @@ class GeminiLLM(LLMBase):
str: The generated response. str: The generated response.
""" """
params = { # Extract system instruction and reformat messages
system_instruction, contents = self._reformat_messages(messages)
# Prepare generation config
config_params = {
"temperature": self.config.temperature, "temperature": self.config.temperature,
"max_output_tokens": self.config.max_tokens, "max_output_tokens": self.config.max_tokens,
"top_p": self.config.top_p, "top_p": self.config.top_p,
} }
# Add system instruction to config if present
if system_instruction:
config_params["system_instruction"] = system_instruction
if response_format is not None and response_format["type"] == "json_object": if response_format is not None and response_format["type"] == "json_object":
params["response_mime_type"] = "application/json" config_params["response_mime_type"] = "application/json"
if "schema" in response_format: if "schema" in response_format:
params["response_schema"] = response_format["schema"] config_params["response_schema"] = response_format["schema"]
tool_config = None if tools:
if tool_choice: formatted_tools = self._reformat_tools(tools)
tool_config = types.ToolConfig( config_params["tools"] = formatted_tools
function_calling_config=types.FunctionCallingConfig(
mode=tool_choice.upper(), # Assuming 'any' should become 'ANY', etc.
allowed_function_names=[tool["function"]["name"] for tool in tools] if tool_choice:
if tool_choice == "any" if tool_choice == "auto":
else None, mode = types.FunctionCallingConfigMode.AUTO
elif tool_choice == "any":
mode = types.FunctionCallingConfigMode.ANY
else:
mode = types.FunctionCallingConfigMode.NONE
tool_config = types.ToolConfig(
function_calling_config=types.FunctionCallingConfig(
mode=mode,
allowed_function_names=(
[tool["function"]["name"] for tool in tools] if tool_choice == "any" else None
),
)
) )
) config_params["tool_config"] = tool_config
response = self.client_gemini.models.generate_content( generation_config = types.GenerateContentConfig(**config_params)
model=self.config.model,
contents=self._reformat_messages(messages), response = self.client.models.generate_content(
config=types.GenerateContentConfig( model=self.config.model, contents=contents, config=generation_config
temperature=self.config.temperature,
max_output_tokens=self.config.max_tokens,
top_p=self.config.top_p,
tools=self._reformat_tools(tools),
tool_config=tool_config,
),
) )
return self._parse_response(response, tools) return self._parse_response(response, tools)