From 386d8b87ae613b23ddeead3f7989d2ab19fec484 Mon Sep 17 00:00:00 2001
From: Akshat Jain <125379408+akshat1423@users.noreply.github.com>
Date: Mon, 23 Jun 2025 13:16:10 +0530
Subject: [PATCH] Fix: Migrate Gemini Embeddings (#3002)

Co-authored-by: Dev-Khant
---
 docs/components/embedders/models/gemini.mdx    |  2 +-
 mem0/embeddings/gemini.py                      | 15 ++-
 mem0/llms/gemini.py                            | 55 ++++------
 ...st_gemini.py => test_gemini_embeddings.py}  | 26 +++++
 .../{test_gemini_llm.py => test_gemini.py}     | 97 +++++++++++++------
 5 files changed, 124 insertions(+), 71 deletions(-)
 rename tests/embeddings/{test_gemini.py => test_gemini_embeddings.py} (53%)
 rename tests/llms/{test_gemini_llm.py => test_gemini.py} (54%)

diff --git a/docs/components/embedders/models/gemini.mdx b/docs/components/embedders/models/gemini.mdx
index 077f9ca5..90ee5553 100644
--- a/docs/components/embedders/models/gemini.mdx
+++ b/docs/components/embedders/models/gemini.mdx
@@ -39,5 +39,5 @@ Here are the parameters available for configuring Gemini embedder:
 | Parameter | Description | Default Value |
 | --- | --- | --- |
 | `model` | The name of the embedding model to use | `models/text-embedding-004` |
-| `embedding_dims` | Dimensions of the embedding model | `768` |
+| `embedding_dims` | Dimensions of the embedding model (sent to the API as `output_dimensionality`, so set `embedding_dims` accordingly) | `768` |
 | `api_key` | The Gemini API key | `None` |
diff --git a/mem0/embeddings/gemini.py b/mem0/embeddings/gemini.py
index 5bc87e20..1082ebac 100644
--- a/mem0/embeddings/gemini.py
+++ b/mem0/embeddings/gemini.py
@@ -1,7 +1,8 @@
 import os
 from typing import Literal, Optional
 
-import google.generativeai as genai
+from google import genai
+from google.genai import types
 
 from mem0.configs.embeddings.base import BaseEmbedderConfig
 from mem0.embeddings.base import EmbeddingBase
@@ -12,23 +13,34 @@ class GoogleGenAIEmbedding(EmbeddingBase):
         super().__init__(config)
 
         self.config.model = self.config.model or "models/text-embedding-004"
-        self.config.embedding_dims = self.config.embedding_dims or 768
+        self.config.embedding_dims = self.config.embedding_dims or getattr(self.config, "output_dimensionality", None) or 768
 
         api_key = self.config.api_key or os.getenv("GOOGLE_API_KEY")
-        genai.configure(api_key=api_key)
+        if api_key:
+            self.client = genai.Client(api_key=api_key)
+        else:
+            # Let the SDK pick up credentials from the environment.
+            self.client = genai.Client()
 
     def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
         """
         Get the embedding for the given text using Google Generative AI.
 
         Args:
             text (str): The text to embed.
-            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
+            memory_action (optional): The type of embedding to use. Currently unused; no task_type is passed to the API.
 
         Returns:
             list: The embedding vector.
         """
         text = text.replace("\n", " ")
-        response = genai.embed_content(
-            model=self.config.model, content=text, output_dimensionality=self.config.embedding_dims
+
+        response = self.client.models.embed_content(
+            model=self.config.model,
+            contents=text,
+            config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dims),
         )
-        return response["embedding"]
+
+        if response is None or not response.embeddings:
+            return []
+
+        return response.embeddings[0].values
diff --git a/mem0/llms/gemini.py b/mem0/llms/gemini.py
index 3c48c5da..34a9c0cf 100644
--- a/mem0/llms/gemini.py
+++ b/mem0/llms/gemini.py
@@ -4,7 +4,7 @@ from typing import Dict, List, Optional
 try:
     from google import genai
     from google.genai import types
-    
+
 except ImportError:
     raise ImportError(
-        "The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'."
+        "The 'google-genai' library is required. Please install it using 'pip install google-genai'."
@@ -49,16 +49,17 @@ class GeminiLLM(LLMBase):
             for part in candidate.content.parts:
                 fn = getattr(part, "function_call", None)
                 if fn:
-                    processed_response["tool_calls"].append({
-                        "name": fn.name,
-                        "arguments": fn.args,
-                    })
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": fn.name,
+                            "arguments": fn.args,
+                        }
+                    )
 
             return processed_response
 
         return content
 
-
     def _reformat_messages(self, messages: List[Dict[str, str]]) -> List[types.Content]:
         """
         Reformat messages for Gemini using google.genai.types.
@@ -78,15 +79,11 @@ class GeminiLLM(LLMBase):
                 content = message["content"]
 
             new_messages.append(
-                types.Content(
-                    role="model" if message["role"] == "model" else "user",
-                    parts=[types.Part(text=content)]
-                )
+                types.Content(role="model" if message["role"] == "model" else "user", parts=[types.Part(text=content)])
             )
 
         return new_messages
 
-
     def _reformat_tools(self, tools: Optional[List[Dict]]):
         """
         Reformat tools for Gemini.
@@ -131,7 +128,6 @@ class GeminiLLM(LLMBase):
         tools: Optional[List[Dict]] = None,
         tool_choice: str = "auto",
     ):
-
         """
         Generate a response based on the given messages using Gemini.
@@ -161,31 +157,22 @@ class GeminiLLM(LLMBase):
             tool_config = types.ToolConfig(
                 function_calling_config=types.FunctionCallingConfig(
                     mode=tool_choice.upper(),  # Assuming 'any' should become 'ANY', etc.
-                    allowed_function_names=[
-                        tool["function"]["name"] for tool in tools
-                    ] if tool_choice == "any" else None
+                    allowed_function_names=[tool["function"]["name"] for tool in tools]
+                    if tool_choice == "any"
+                    else None,
                 )
             )
 
-        print(f"Tool config: {tool_config}")
-        print(f"Params: {params}" )
-        print(f"Messages: {messages}")
-        print(f"Tools: {tools}")
-        print(f"Reformatted messages: {self._reformat_messages(messages)}")
-        print(f"Reformatted tools: {self._reformat_tools(tools)}")
-
         response = self.client_gemini.models.generate_content(
-            model=self.config.model,
-            contents=self._reformat_messages(messages),
-            config=types.GenerateContentConfig(
-                temperature= self.config.temperature,
-                max_output_tokens= self.config.max_tokens,
-                top_p= self.config.top_p,
-                tools=self._reformat_tools(tools),
-                tool_config=tool_config,
-
-            ),
-        )
-        print(f"Response test: {response}")
+            model=self.config.model,
+            contents=self._reformat_messages(messages),
+            config=types.GenerateContentConfig(
+                temperature=self.config.temperature,
+                max_output_tokens=self.config.max_tokens,
+                top_p=self.config.top_p,
+                tools=self._reformat_tools(tools),
+                tool_config=tool_config,
+            ),
+        )
 
         return self._parse_response(response, tools)
diff --git a/tests/embeddings/test_gemini.py b/tests/embeddings/test_gemini_embeddings.py
similarity index 53%
rename from tests/embeddings/test_gemini.py
rename to tests/embeddings/test_gemini_embeddings.py
index a49e585a..32c55982 100644
--- a/tests/embeddings/test_gemini.py
+++ b/tests/embeddings/test_gemini_embeddings.py
@@ -28,3 +28,29 @@ def test_embed_query(mock_genai, config):
 
     assert embedding == [0.1, 0.2, 0.3, 0.4]
     mock_genai.assert_called_once_with(model="test_model", content="Hello, world!", output_dimensionality=786)
+
+
+def test_embed_returns_empty_list_if_none(mock_genai, config):
+    mock_genai.return_value = None
+
+    embedder = GoogleGenAIEmbedding(config)
+    result = embedder.embed("test")
+
+    assert result == []
+    mock_genai.assert_called_once()
+
+
+def test_embed_raises_on_error(mock_genai, config):
+    mock_genai.side_effect = RuntimeError("Embedding failed")
+
+    embedder = GoogleGenAIEmbedding(config)
+
+    with pytest.raises(RuntimeError, match="Embedding failed"):
+        embedder.embed("some input")
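+
+    # Note: embed() is expected to propagate SDK errors unchanged; the
+    # pytest.raises above pins down that contract.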
input") + +def test_config_initialization(config): + embedder = GoogleGenAIEmbedding(config) + + assert embedder.config.api_key == "dummy_api_key" + assert embedder.config.model == "test_model" + assert embedder.config.embedding_dims == 786 + diff --git a/tests/llms/test_gemini_llm.py b/tests/llms/test_gemini.py similarity index 54% rename from tests/llms/test_gemini_llm.py rename to tests/llms/test_gemini.py index ffdec4fb..b086d150 100644 --- a/tests/llms/test_gemini_llm.py +++ b/tests/llms/test_gemini.py @@ -1,8 +1,7 @@ from unittest.mock import Mock, patch import pytest -from google.generativeai import GenerationConfig -from google.generativeai.types import content_types +from google.genai import types from mem0.configs.llms.base import BaseLlmConfig from mem0.llms.gemini import GeminiLLM @@ -10,14 +9,14 @@ from mem0.llms.gemini import GeminiLLM @pytest.fixture def mock_gemini_client(): - with patch("mem0.llms.gemini.GenerativeModel") as mock_gemini: + with patch("mem0.llms.gemini.genai") as mock_client_class: mock_client = Mock() - mock_gemini.return_value = mock_client + mock_client_class.return_value = mock_client yield mock_client def test_generate_response_without_tools(mock_gemini_client: Mock): - config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0) + config = BaseLlmConfig(model="gemini-2.0-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0) llm = GeminiLLM(config) messages = [ {"role": "system", "content": "You are a helpful assistant."}, @@ -25,6 +24,15 @@ def test_generate_response_without_tools(mock_gemini_client: Mock): ] mock_part = Mock(text="I'm doing well, thank you for asking!") + mock_embedding = Mock() + mock_embedding.values = [0.1, 0.2, 0.3] + + mock_response = Mock() + mock_response.candidates = [Mock()] + mock_response.candidates[0].content.parts = [Mock()] + mock_response.candidates[0].content.parts[0].text = "I'm doing well, thank you for asking!" + + mock_gemini_client.models.generate_content.return_value = mock_response mock_content = Mock(parts=[mock_part]) mock_message = Mock(content=mock_content) mock_response = Mock(candidates=[mock_message]) @@ -37,15 +45,24 @@ def test_generate_response_without_tools(mock_gemini_client: Mock): {"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"}, {"parts": "Hello, how are you?", "role": "user"}, ], - generation_config=GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0), - tools=None, - tool_config=content_types.to_tool_config( - {"function_calling_config": {"mode": "auto", "allowed_function_names": None}} - ), - ) + config=types.GenerateContentConfig( + temperature=0.7, + max_output_tokens=100, + top_p=1.0, + tools=None, + tool_config=types.ToolConfig( + function_calling_config=types.FunctionCallingConfig( + allowed_function_names=None, + mode="auto" + + ) + ) + ) ) + assert response == "I'm doing well, thank you for asking!" + def test_generate_response_with_tools(mock_gemini_client: Mock): config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0) llm = GeminiLLM(config) @@ -89,28 +106,46 @@ def test_generate_response_with_tools(mock_gemini_client: Mock): mock_gemini_client.generate_content.assert_called_once_with( contents=[ - {"parts": "THIS IS A SYSTEM PROMPT. 
-            {"parts": "Add a new memory: Today is a sunny day.", "role": "user"},
-        ],
-        generation_config=GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0),
-        tools=[
-            {
-                "function_declarations": [
-                    {
-                        "name": "add_memory",
-                        "description": "Add a memory",
-                        "parameters": {
-                            "type": "object",
-                            "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
-                            "required": ["data"],
-                        },
-                    }
-                ]
-            }
-        ],
-        tool_config=content_types.to_tool_config(
-            {"function_calling_config": {"mode": "auto", "allowed_function_names": None}}
-        ),
+            {
+                "parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.",
+                "role": "user",
+            },
+            {
+                "parts": "Add a new memory: Today is a sunny day.",
+                "role": "user",
+            },
+        ],
+        config=types.GenerateContentConfig(
+            temperature=0.7,
+            max_output_tokens=100,
+            top_p=1.0,
+            tools=[
+                types.Tool(
+                    function_declarations=[
+                        types.FunctionDeclaration(
+                            name="add_memory",
+                            description="Add a memory",
+                            parameters={
+                                "type": "object",
+                                "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
+                                "required": ["data"],
+                            },
+                        )
+                    ]
+                )
+            ],
+            tool_config=types.ToolConfig(
+                function_calling_config=types.FunctionCallingConfig(
+                    allowed_function_names=None,
+                    mode="AUTO",
+                )
+            ),
+        ),
     )
 
     assert response["content"] == "I've added the memory for you."
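
-- 
For reference, not part of the patch: the embedding call shape this migration
moves to, as a minimal sketch. It assumes the google-genai SDK is installed and
GOOGLE_API_KEY is set; the model name and dimensions are illustrative.

    import os

    from google import genai
    from google.genai import types

    client = genai.Client(api_key=os.environ["GOOGLE_API_KEY"])

    # Old SDK (google-generativeai) returned a dict:
    #     genai.embed_content(model=..., content=text, output_dimensionality=768)["embedding"]
    # New SDK (google-genai) takes `contents` plus an EmbedContentConfig and
    # returns a response object; the vector lives at response.embeddings[0].values.
    response = client.models.embed_content(
        model="models/text-embedding-004",
        contents="Hello, world!",
        config=types.EmbedContentConfig(output_dimensionality=768),
    )
    print(len(response.embeddings[0].values))  # 768

Wiring the embedder through mem0 keeps the shape from the docs table above
(assuming the embedder provider key is "gemini"):

    config = {
        "embedder": {
            "provider": "gemini",
            "config": {"model": "models/text-embedding-004", "embedding_dims": 768},
        }
    }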