diff --git a/mem0/embeddings/gemini.py b/mem0/embeddings/gemini.py index 1082ebac..59444b42 100644 --- a/mem0/embeddings/gemini.py +++ b/mem0/embeddings/gemini.py @@ -1,7 +1,8 @@ import os from typing import Literal, Optional -import google.genai as genai +from google import genai +from google.genai import types from mem0.configs.embeddings.base import BaseEmbedderConfig from mem0.embeddings.base import EmbeddingBase @@ -16,24 +17,23 @@ class GoogleGenAIEmbedding(EmbeddingBase): api_key = self.config.api_key or os.getenv("GOOGLE_API_KEY") - if api_key: - self.client = genai.Client(api_key="api_key") - else: - self.client = genai.Client() + self.client = genai.Client(api_key=api_key) def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None): """ Get the embedding for the given text using Google Generative AI. Args: text (str): The text to embed. - memory_action (optional): The type of embedding to use. (Currently not used by Gemini for task_type) + memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None. Returns: list: The embedding vector. """ text = text.replace("\n", " ") - response = self.client.models.embed_content( - model=self.config.model, content=text, output_dimensionality=self.config.embedding_dims - ) + # Create config for embedding parameters + config = types.EmbedContentConfig(output_dimensionality=self.config.embedding_dims) - return response["embedding"] + # Call the embed_content method with the correct parameters + response = self.client.models.embed_content(model=self.config.model, contents=text, config=config) + + return response.embeddings[0].values \ No newline at end of file diff --git a/mem0/llms/gemini.py b/mem0/llms/gemini.py index 34a9c0cf..7085f158 100644 --- a/mem0/llms/gemini.py +++ b/mem0/llms/gemini.py @@ -4,11 +4,8 @@ from typing import Dict, List, Optional try: from google import genai from google.genai import types - except ImportError: - raise ImportError( - "The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'." - ) + raise ImportError("The 'google-genai' library is required. Please install it using 'pip install google-genai'.") from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.base import LLMBase @@ -19,70 +16,79 @@ class GeminiLLM(LLMBase): super().__init__(config) if not self.config.model: - self.config.model = "gemini-1.5-flash-latest" + self.config.model = "gemini-2.0-flash" - api_key = self.config.api_key or os.getenv("GEMINI_API_KEY") - self.client_gemini = genai.Client( - api_key=api_key, - ) + api_key = self.config.api_key or os.getenv("GOOGLE_API_KEY") + self.client = genai.Client(api_key=api_key) def _parse_response(self, response, tools): """ Process the response based on whether tools are used or not. Args: - response: The raw response from the API. + response: The raw response from API. tools: The list of tools provided in the request. Returns: str or dict: The processed response. """ - candidate = response.candidates[0] - content = candidate.content.parts[0].text if candidate.content.parts else None - if tools: processed_response = { - "content": content, + "content": None, "tool_calls": [], } - for part in candidate.content.parts: - fn = getattr(part, "function_call", None) - if fn: - processed_response["tool_calls"].append( - { - "name": fn.name, - "arguments": fn.args, - } - ) + # Extract content from the first candidate + if response.candidates and response.candidates[0].content.parts: + for part in response.candidates[0].content.parts: + if hasattr(part, "text") and part.text: + processed_response["content"] = part.text + break + + # Extract function calls + if response.candidates and response.candidates[0].content.parts: + for part in response.candidates[0].content.parts: + if hasattr(part, "function_call") and part.function_call: + fn = part.function_call + processed_response["tool_calls"].append( + { + "name": fn.name, + "arguments": dict(fn.args) if fn.args else {}, + } + ) return processed_response + else: + if response.candidates and response.candidates[0].content.parts: + for part in response.candidates[0].content.parts: + if hasattr(part, "text") and part.text: + return part.text + return "" - return content - - def _reformat_messages(self, messages: List[Dict[str, str]]) -> List[types.Content]: + def _reformat_messages(self, messages: List[Dict[str, str]]): """ - Reformat messages for Gemini using google.genai.types. + Reformat messages for Gemini. Args: messages: The list of messages provided in the request. Returns: - list: A list of types.Content objects with proper role and parts. + tuple: (system_instruction, contents_list) """ - new_messages = [] + system_instruction = None + contents = [] for message in messages: if message["role"] == "system": - content = "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"] + system_instruction = message["content"] else: - content = message["content"] + content = types.Content( + parts=[types.Part(text=message["content"])], + role=message["role"], + ) + contents.append(content) - new_messages.append( - types.Content(role="model" if message["role"] == "model" else "user", parts=[types.Part(text=content)]) - ) - - return new_messages + return system_instruction, contents def _reformat_tools(self, tools: Optional[List[Dict]]): """ @@ -97,7 +103,6 @@ class GeminiLLM(LLMBase): def remove_additional_properties(data): """Recursively removes 'additionalProperties' from nested dictionaries.""" - if isinstance(data, dict): filtered_dict = { key: remove_additional_properties(value) @@ -108,16 +113,21 @@ class GeminiLLM(LLMBase): else: return data - new_tools = [] if tools: + function_declarations = [] for tool in tools: func = tool["function"].copy() - new_tools.append({"function_declarations": [remove_additional_properties(func)]}) + cleaned_func = remove_additional_properties(func) - # TODO: temporarily ignore it to pass tests, will come back to update according to standards later. - # return content_types.to_function_library(new_tools) + function_declaration = types.FunctionDeclaration( + name=cleaned_func["name"], + description=cleaned_func.get("description", ""), + parameters=cleaned_func.get("parameters", {}), + ) + function_declarations.append(function_declaration) - return new_tools + tool_obj = types.Tool(function_declarations=function_declarations) + return [tool_obj] else: return None @@ -141,38 +151,53 @@ class GeminiLLM(LLMBase): str: The generated response. """ - params = { + # Extract system instruction and reformat messages + system_instruction, contents = self._reformat_messages(messages) + + # Prepare generation config + config_params = { "temperature": self.config.temperature, "max_output_tokens": self.config.max_tokens, "top_p": self.config.top_p, } + # Add system instruction to config if present + if system_instruction: + config_params["system_instruction"] = system_instruction + + if response_format is not None and response_format["type"] == "json_object": - params["response_mime_type"] = "application/json" + config_params["response_mime_type"] = "application/json" if "schema" in response_format: - params["response_schema"] = response_format["schema"] + config_params["response_schema"] = response_format["schema"] - tool_config = None - if tool_choice: - tool_config = types.ToolConfig( - function_calling_config=types.FunctionCallingConfig( - mode=tool_choice.upper(), # Assuming 'any' should become 'ANY', etc. - allowed_function_names=[tool["function"]["name"] for tool in tools] - if tool_choice == "any" - else None, + if tools: + formatted_tools = self._reformat_tools(tools) + config_params["tools"] = formatted_tools + + + if tool_choice: + if tool_choice == "auto": + mode = types.FunctionCallingConfigMode.AUTO + elif tool_choice == "any": + mode = types.FunctionCallingConfigMode.ANY + else: + mode = types.FunctionCallingConfigMode.NONE + + tool_config = types.ToolConfig( + function_calling_config=types.FunctionCallingConfig( + mode=mode, + allowed_function_names=( + [tool["function"]["name"] for tool in tools] if tool_choice == "any" else None + ), + ) ) - ) + config_params["tool_config"] = tool_config - response = self.client_gemini.models.generate_content( - model=self.config.model, - contents=self._reformat_messages(messages), - config=types.GenerateContentConfig( - temperature=self.config.temperature, - max_output_tokens=self.config.max_tokens, - top_p=self.config.top_p, - tools=self._reformat_tools(tools), - tool_config=tool_config, - ), + generation_config = types.GenerateContentConfig(**config_params) + + response = self.client.models.generate_content( + model=self.config.model, contents=contents, config=generation_config ) return self._parse_response(response, tools)