Fix: Migrate Gemini Embeddings (#3002)

Co-authored-by: Dev-Khant <devkhant24@gmail.com>
Akshat Jain
2025-06-23 13:16:10 +05:30
committed by GitHub
parent c173ec32d0
commit 386d8b87ae
5 changed files with 124 additions and 71 deletions

View File

@@ -39,5 +39,5 @@ Here are the parameters available for configuring Gemini embedder:
 | Parameter | Description | Default Value |
 | --- | --- | --- |
 | `model` | The name of the embedding model to use | `models/text-embedding-004` |
-| `embedding_dims` | Dimensions of the embedding model | `768` |
+| `embedding_dims` | Dimensions of the embedding model (`embedding_dims` is passed to the API as `output_dimensionality`, so set it accordingly) | `768` |
 | `api_key` | The Gemini API key | `None` |
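
For reference, these parameters land in mem0's standard embedder config block. A minimal sketch of a configuration that exercises them (the `"gemini"` provider key and the nested `"embedder"` block follow mem0's usual config conventions; adjust to your setup):

```python
import os

from mem0 import Memory

# Sketch of a Gemini embedder config; keys follow mem0's embedder conventions.
config = {
    "embedder": {
        "provider": "gemini",
        "config": {
            "model": "models/text-embedding-004",
            "embedding_dims": 768,  # passed through as output_dimensionality
            "api_key": os.getenv("GOOGLE_API_KEY"),
        },
    }
}

m = Memory.from_config(config)
```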

View File

@@ -1,7 +1,7 @@
 import os
 from typing import Literal, Optional

-import google.generativeai as genai
+import google.genai as genai

 from mem0.configs.embeddings.base import BaseEmbedderConfig
 from mem0.embeddings.base import EmbeddingBase
@@ -12,23 +12,28 @@ class GoogleGenAIEmbedding(EmbeddingBase):
         super().__init__(config)

         self.config.model = self.config.model or "models/text-embedding-004"
-        self.config.embedding_dims = self.config.embedding_dims or 768
+        self.config.embedding_dims = self.config.embedding_dims or self.config.output_dimensionality or 768

         api_key = self.config.api_key or os.getenv("GOOGLE_API_KEY")
-        genai.configure(api_key=api_key)
+        if api_key:
+            self.client = genai.Client(api_key=api_key)
+        else:
+            self.client = genai.Client()

     def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
         """
         Get the embedding for the given text using Google Generative AI.

         Args:
             text (str): The text to embed.
-            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
+            memory_action (optional): The type of embedding to use. (Currently not used by Gemini for task_type.)

         Returns:
             list: The embedding vector.
         """
         text = text.replace("\n", " ")
-        response = genai.embed_content(
+        response = self.client.models.embed_content(
             model=self.config.model, content=text, output_dimensionality=self.config.embedding_dims
         )

         return response["embedding"]
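
The core of the migration is the move from the module-level `google.generativeai` API (`genai.configure(...)` plus `genai.embed_content(...)`) to an explicit `Client` object in the new `google-genai` SDK. For standalone use outside mem0, the documented call shape in the new SDK takes `contents=` and an `EmbedContentConfig`; a minimal sketch, assuming a current `google-genai` release:

```python
import os

from google import genai
from google.genai import types

# Client() also falls back to the GOOGLE_API_KEY environment variable if
# api_key is omitted.
client = genai.Client(api_key=os.getenv("GOOGLE_API_KEY"))

result = client.models.embed_content(
    model="models/text-embedding-004",
    contents="Hello, world!",
    config=types.EmbedContentConfig(output_dimensionality=768),
)

# The new SDK returns an EmbedContentResponse; each entry in .embeddings
# exposes the vector as .values.
vector = result.embeddings[0].values
```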

View File

@@ -49,16 +49,17 @@ class GeminiLLM(LLMBase):
                 for part in candidate.content.parts:
                     fn = getattr(part, "function_call", None)
                     if fn:
-                        processed_response["tool_calls"].append({
-                            "name": fn.name,
-                            "arguments": fn.args,
-                        })
+                        processed_response["tool_calls"].append(
+                            {
+                                "name": fn.name,
+                                "arguments": fn.args,
+                            }
+                        )

             return processed_response

         return content

     def _reformat_messages(self, messages: List[Dict[str, str]]) -> List[types.Content]:
         """
         Reformat messages for Gemini using google.genai.types.
@@ -78,15 +79,11 @@ class GeminiLLM(LLMBase):
                 content = message["content"]

             new_messages.append(
-                types.Content(
-                    role="model" if message["role"] == "model" else "user",
-                    parts=[types.Part(text=content)]
-                )
+                types.Content(role="model" if message["role"] == "model" else "user", parts=[types.Part(text=content)])
             )

         return new_messages

     def _reformat_tools(self, tools: Optional[List[Dict]]):
         """
         Reformat tools for Gemini.
@@ -131,7 +128,6 @@ class GeminiLLM(LLMBase):
         tools: Optional[List[Dict]] = None,
         tool_choice: str = "auto",
     ):
-
         """
         Generate a response based on the given messages using Gemini.
@@ -161,31 +157,22 @@ class GeminiLLM(LLMBase):
             tool_config = types.ToolConfig(
                 function_calling_config=types.FunctionCallingConfig(
                     mode=tool_choice.upper(),  # Assuming 'any' should become 'ANY', etc.
-                    allowed_function_names=[
-                        tool["function"]["name"] for tool in tools
-                    ] if tool_choice == "any" else None
+                    allowed_function_names=[tool["function"]["name"] for tool in tools]
+                    if tool_choice == "any"
+                    else None,
                 )
             )

-        print(f"Tool config: {tool_config}")
-        print(f"Params: {params}")
-        print(f"Messages: {messages}")
-        print(f"Tools: {tools}")
-        print(f"Reformatted messages: {self._reformat_messages(messages)}")
-        print(f"Reformatted tools: {self._reformat_tools(tools)}")
-
         response = self.client_gemini.models.generate_content(
             model=self.config.model,
             contents=self._reformat_messages(messages),
             config=types.GenerateContentConfig(
-                temperature= self.config.temperature,
-                max_output_tokens= self.config.max_tokens,
-                top_p= self.config.top_p,
+                temperature=self.config.temperature,
+                max_output_tokens=self.config.max_tokens,
+                top_p=self.config.top_p,
                 tools=self._reformat_tools(tools),
                 tool_config=tool_config,
             ),
         )
-        print(f"Response test: {response}")

         return self._parse_response(response, tools)
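
For context on the `tool_choice.upper()` mapping: `FunctionCallingConfig` in `google-genai` accepts the modes `AUTO`, `ANY`, and `NONE`, and `allowed_function_names` only applies when calling is forced. A minimal standalone sketch of a forced tool call (the model name and declaration here are illustrative, not taken from this commit):

```python
from google import genai
from google.genai import types

client = genai.Client()  # assumes GOOGLE_API_KEY is set in the environment

# Declare one callable function; mirrors the add_memory tool used in the tests below.
add_memory = types.FunctionDeclaration(
    name="add_memory",
    description="Add a memory",
    parameters={
        "type": "object",
        "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
        "required": ["data"],
    },
)

response = client.models.generate_content(
    model="gemini-2.0-flash",  # illustrative model name
    contents="Add a new memory: Today is a sunny day.",
    config=types.GenerateContentConfig(
        tools=[types.Tool(function_declarations=[add_memory])],
        tool_config=types.ToolConfig(
            # "ANY" forces a function call; "AUTO" lets the model decide,
            # "NONE" disables calls entirely.
            function_calling_config=types.FunctionCallingConfig(
                mode="ANY", allowed_function_names=["add_memory"]
            )
        ),
    ),
)

# Function calls come back as parts on the first candidate, which is what
# _parse_response above iterates over.
for part in response.candidates[0].content.parts:
    if getattr(part, "function_call", None):
        print(part.function_call.name, part.function_call.args)
```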

View File

@@ -28,3 +28,29 @@ def test_embed_query(mock_genai, config):
     assert embedding == [0.1, 0.2, 0.3, 0.4]
     mock_genai.assert_called_once_with(model="test_model", content="Hello, world!", output_dimensionality=786)
+
+
+def test_embed_returns_empty_list_if_none(mock_genai, config):
+    mock_genai.return_value = None
+
+    embedder = GoogleGenAIEmbedding(config)
+    result = embedder.embed("test")
+
+    assert result == []
+    mock_genai.assert_called_once()
+
+
+def test_embed_raises_on_error(mock_genai, config):
+    mock_genai.side_effect = RuntimeError("Embedding failed")
+
+    embedder = GoogleGenAIEmbedding(config)
+
+    with pytest.raises(RuntimeError, match="Embedding failed"):
+        embedder.embed("some input")
+
+
+def test_config_initialization(config):
+    embedder = GoogleGenAIEmbedding(config)
+
+    assert embedder.config.api_key == "dummy_api_key"
+    assert embedder.config.model == "test_model"
+    assert embedder.config.embedding_dims == 786
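
These tests lean on `mock_genai` and `config` fixtures defined earlier in the test file and not shown in this hunk. A plausible sketch of those fixtures, assuming `mock_genai` patches the embed call on the wrapper's client; the patch target here is hypothetical and may differ from the real test file:

```python
from unittest.mock import patch

import pytest

from mem0.configs.embeddings.base import BaseEmbedderConfig


@pytest.fixture
def mock_genai():
    # Hypothetical patch target: the Client class used by GoogleGenAIEmbedding.
    with patch("mem0.embeddings.gemini.genai.Client") as mock_client_cls:
        yield mock_client_cls.return_value.models.embed_content


@pytest.fixture
def config():
    # Values taken from the assertions in test_config_initialization above.
    return BaseEmbedderConfig(api_key="dummy_api_key", model="test_model", embedding_dims=786)
```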

View File

@@ -1,8 +1,7 @@
 from unittest.mock import Mock, patch

 import pytest
-from google.generativeai import GenerationConfig
-from google.generativeai.types import content_types
+from google.genai import types

 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.gemini import GeminiLLM
@@ -10,14 +9,14 @@ from mem0.llms.gemini import GeminiLLM
 @pytest.fixture
 def mock_gemini_client():
-    with patch("mem0.llms.gemini.GenerativeModel") as mock_gemini:
+    with patch("mem0.llms.gemini.genai") as mock_client_class:
         mock_client = Mock()
-        mock_gemini.return_value = mock_client
+        mock_client_class.return_value = mock_client
         yield mock_client


 def test_generate_response_without_tools(mock_gemini_client: Mock):
-    config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0)
+    config = BaseLlmConfig(model="gemini-2.0-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0)
     llm = GeminiLLM(config)

     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
@@ -25,6 +24,15 @@ def test_generate_response_without_tools(mock_gemini_client: Mock):
     ]

     mock_part = Mock(text="I'm doing well, thank you for asking!")
+    mock_embedding = Mock()
+    mock_embedding.values = [0.1, 0.2, 0.3]
+
+    mock_response = Mock()
+    mock_response.candidates = [Mock()]
+    mock_response.candidates[0].content.parts = [Mock()]
+    mock_response.candidates[0].content.parts[0].text = "I'm doing well, thank you for asking!"
+
+    mock_gemini_client.models.generate_content.return_value = mock_response
     mock_content = Mock(parts=[mock_part])
     mock_message = Mock(content=mock_content)
     mock_response = Mock(candidates=[mock_message])
@@ -37,15 +45,24 @@ def test_generate_response_without_tools(mock_gemini_client: Mock):
{"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"}, {"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"},
{"parts": "Hello, how are you?", "role": "user"}, {"parts": "Hello, how are you?", "role": "user"},
], ],
generation_config=GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0), config=types.GenerateContentConfig(
tools=None, temperature=0.7,
tool_config=content_types.to_tool_config( max_output_tokens=100,
{"function_calling_config": {"mode": "auto", "allowed_function_names": None}} top_p=1.0,
), tools=None,
) tool_config=types.ToolConfig(
function_calling_config=types.FunctionCallingConfig(
allowed_function_names=None,
mode="auto"
)
)
) )
assert response == "I'm doing well, thank you for asking!" assert response == "I'm doing well, thank you for asking!"
def test_generate_response_with_tools(mock_gemini_client: Mock): def test_generate_response_with_tools(mock_gemini_client: Mock):
config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0) config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0)
llm = GeminiLLM(config) llm = GeminiLLM(config)
@@ -89,28 +106,46 @@ def test_generate_response_with_tools(mock_gemini_client: Mock):
     mock_gemini_client.generate_content.assert_called_once_with(
         contents=[
-            {"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"},
-            {"parts": "Add a new memory: Today is a sunny day.", "role": "user"},
+            {
+                "parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.",
+                "role": "user"
+            },
+            {
+                "parts": "Add a new memory: Today is a sunny day.",
+                "role": "user"
+            },
         ],
-        generation_config=GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0),
-        tools=[
-            {
-                "function_declarations": [
-                    {
-                        "name": "add_memory",
-                        "description": "Add a memory",
-                        "parameters": {
-                            "type": "object",
-                            "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
-                            "required": ["data"],
-                        },
-                    }
-                ]
-            }
-        ],
-        tool_config=content_types.to_tool_config(
-            {"function_calling_config": {"mode": "auto", "allowed_function_names": None}}
-        ),
+        config=types.GenerateContentConfig(
+            temperature=0.7,
+            max_output_tokens=100,
+            top_p=1.0,
+            tools=[
+                types.Tool(
+                    function_declarations=[
+                        types.FunctionDeclaration(
+                            name="add_memory",
+                            description="Add a memory",
+                            parameters={
+                                "type": "object",
+                                "properties": {
+                                    "data": {
+                                        "type": "string",
+                                        "description": "Data to add to memory"
+                                    }
+                                },
+                                "required": ["data"]
+                            }
+                        )
+                    ]
+                )
+            ],
+            tool_config=types.ToolConfig(
+                function_calling_config=types.FunctionCallingConfig(
+                    allowed_function_names=None,
+                    mode="auto"
+                )
+            )
+        )
     )

     assert response["content"] == "I've added the memory for you."
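
For reference, when tools are supplied, `_parse_response` (see the LLM diff above) returns a dict rather than a bare string; the final assertion checks the `content` key of a structure shaped roughly like this (values illustrative):

```python
{
    "content": "I've added the memory for you.",
    "tool_calls": [
        {"name": "add_memory", "arguments": {"data": "Today is a sunny day."}},
    ],
}
```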