Fix: Migrate Gemini Embeddings (#3002)
Co-authored-by: Dev-Khant <devkhant24@gmail.com>
@@ -39,5 +39,5 @@ Here are the parameters available for configuring Gemini embedder:
| Parameter | Description | Default Value |
| --- | --- | --- |
| `model` | The name of the embedding model to use | `models/text-embedding-004` |
-| `embedding_dims` | Dimensions of the embedding model | `768` |
+| `embedding_dims` | Dimensions of the embedding model (passed to the Gemini API as `output_dimensionality`, so set `embedding_dims` to the dimension you want back) | `768` |
| `api_key` | The Gemini API key | `None` |
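For readers wiring this up, a minimal sketch of how these parameters fit into a mem0 config (following mem0's documented embedder config shape; adjust keys to your installed version):

```python
import os

from mem0 import Memory

# Minimal sketch: Gemini embedder with an explicit dimension.
# `embedding_dims` is what the embedder forwards to the Gemini API
# as `output_dimensionality`.
config = {
    "embedder": {
        "provider": "gemini",
        "config": {
            "model": "models/text-embedding-004",
            "embedding_dims": 768,
            "api_key": os.getenv("GOOGLE_API_KEY"),
        },
    },
}

memory = Memory.from_config(config)
```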
@@ -1,7 +1,7 @@
import os
from typing import Literal, Optional

-import google.generativeai as genai
+import google.genai as genai

from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase
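The import swap above is the heart of the migration: the legacy `google-generativeai` package configures a module-level singleton, while the newer `google-genai` SDK hands out an explicit `Client`. A minimal sketch of the difference (assuming recent SDK behavior, where the client can also pick the key up from the environment):

```python
import os

# Legacy SDK (google-generativeai) used module-level configuration:
#   import google.generativeai as genai
#   genai.configure(api_key="...")

# New SDK (google-genai) is client-based:
import google.genai as genai

# Passing api_key=None makes recent SDK versions fall back to the
# GOOGLE_API_KEY / GEMINI_API_KEY environment variables.
client = genai.Client(api_key=os.getenv("GOOGLE_API_KEY"))
```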
@@ -12,23 +12,28 @@ class GoogleGenAIEmbedding(EmbeddingBase):
        super().__init__(config)

        self.config.model = self.config.model or "models/text-embedding-004"
-        self.config.embedding_dims = self.config.embedding_dims or 768
+        self.config.embedding_dims = self.config.embedding_dims or self.config.output_dimensionality or 768

        api_key = self.config.api_key or os.getenv("GOOGLE_API_KEY")

-        genai.configure(api_key=api_key)
+        if api_key:
+            self.client = genai.Client(api_key=api_key)
+        else:
+            self.client = genai.Client()

    def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text using Google Generative AI.

        Args:
            text (str): The text to embed.
-            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
+            memory_action (optional): The type of embedding to use. (Currently not used by Gemini for task_type.)

        Returns:
            list: The embedding vector.
        """
        text = text.replace("\n", " ")
-        response = genai.embed_content(
+        response = self.client.models.embed_content(
            model=self.config.model, content=text, output_dimensionality=self.config.embedding_dims
        )

        return response["embedding"]
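As a standalone sanity check, the same call against the real SDK looks roughly like this (based on the google-genai client API; note that current SDK releases spell the arguments `contents` and `config=types.EmbedContentConfig(...)`, which may differ slightly from the keyword style in the diff above):

```python
import os

import google.genai as genai
from google.genai import types

client = genai.Client(api_key=os.getenv("GOOGLE_API_KEY"))

# Request a 768-dimensional embedding; output_dimensionality is the
# SDK-side name for what mem0 exposes as embedding_dims.
response = client.models.embed_content(
    model="models/text-embedding-004",
    contents="Hello, world!",
    config=types.EmbedContentConfig(output_dimensionality=768),
)

vector = response.embeddings[0].values
print(len(vector))  # expected: 768
```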
@@ -49,16 +49,17 @@ class GeminiLLM(LLMBase):
            for part in candidate.content.parts:
                fn = getattr(part, "function_call", None)
                if fn:
-                    processed_response["tool_calls"].append({
-                        "name": fn.name,
-                        "arguments": fn.args,
-                    })
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": fn.name,
+                            "arguments": fn.args,
+                        }
+                    )

            return processed_response

        return content

    def _reformat_messages(self, messages: List[Dict[str, str]]) -> List[types.Content]:
        """
        Reformat messages for Gemini using google.genai.types.
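To make the shape `_parse_response` walks over concrete, here is a toy response built from plain stand-in objects (not the real SDK types) run through the same extraction loop:

```python
from types import SimpleNamespace

# Stand-ins mimicking the attribute layout the parser expects:
# response.candidates[0].content.parts[*].function_call
fn_call = SimpleNamespace(name="add_memory", args={"data": "Today is a sunny day."})
part = SimpleNamespace(text=None, function_call=fn_call)
candidate = SimpleNamespace(content=SimpleNamespace(parts=[part]))
response = SimpleNamespace(candidates=[candidate])

processed_response = {"tool_calls": []}
for part in response.candidates[0].content.parts:
    fn = getattr(part, "function_call", None)
    if fn:
        processed_response["tool_calls"].append({"name": fn.name, "arguments": fn.args})

print(processed_response)
# {'tool_calls': [{'name': 'add_memory', 'arguments': {'data': 'Today is a sunny day.'}}]}
```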
@@ -78,15 +79,11 @@ class GeminiLLM(LLMBase):
            content = message["content"]

            new_messages.append(
-                types.Content(
-                    role="model" if message["role"] == "model" else "user",
-                    parts=[types.Part(text=content)]
-                )
+                types.Content(role="model" if message["role"] == "model" else "user", parts=[types.Part(text=content)])
            )

        return new_messages

    def _reformat_tools(self, tools: Optional[List[Dict]]):
        """
        Reformat tools for Gemini.
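For reference, the reformatting collapses an OpenAI-style message list into `google.genai.types` objects; a small sketch of the same role mapping outside the class:

```python
from google.genai import types

messages = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "model", "content": "I'm doing well, thank you for asking!"},
]

# Any role other than "model" collapses to "user", matching _reformat_messages.
contents = [
    types.Content(
        role="model" if message["role"] == "model" else "user",
        parts=[types.Part(text=message["content"])],
    )
    for message in messages
]
```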
@@ -131,7 +128,6 @@ class GeminiLLM(LLMBase):
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
    ):
-
        """
        Generate a response based on the given messages using Gemini.
@@ -161,31 +157,22 @@ class GeminiLLM(LLMBase):
            tool_config = types.ToolConfig(
                function_calling_config=types.FunctionCallingConfig(
                    mode=tool_choice.upper(),  # Assuming 'any' should become 'ANY', etc.
-                    allowed_function_names=[
-                        tool["function"]["name"] for tool in tools
-                    ] if tool_choice == "any" else None
+                    allowed_function_names=[tool["function"]["name"] for tool in tools]
+                    if tool_choice == "any"
+                    else None,
                )
            )

-        print(f"Tool config: {tool_config}")
-        print(f"Params: {params}")
-        print(f"Messages: {messages}")
-        print(f"Tools: {tools}")
-        print(f"Reformatted messages: {self._reformat_messages(messages)}")
-        print(f"Reformatted tools: {self._reformat_tools(tools)}")
-
        response = self.client_gemini.models.generate_content(
            model=self.config.model,
            contents=self._reformat_messages(messages),
            config=types.GenerateContentConfig(
-                temperature= self.config.temperature,
-                max_output_tokens= self.config.max_tokens,
-                top_p= self.config.top_p,
+                temperature=self.config.temperature,
+                max_output_tokens=self.config.max_tokens,
+                top_p=self.config.top_p,
                tools=self._reformat_tools(tools),
                tool_config=tool_config,
-
            ),
        )
-        print(f"Response test: {response}")

        return self._parse_response(response, tools)
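The `tool_config` block above encodes the mapping from mem0's `tool_choice` string to the SDK's function-calling modes. A hedged sketch of just that mapping, relying on the same assumption the inline comment makes (that "auto"/"any"/"none" upper-case cleanly to the SDK's AUTO/ANY/NONE modes):

```python
from typing import Dict, List, Optional

from google.genai import types


def build_tool_config(tools: Optional[List[Dict]], tool_choice: str = "auto") -> types.ToolConfig:
    # "any" forces a call to one of the declared functions, so only then do we
    # restrict allowed_function_names; "auto" and "none" leave it unset.
    return types.ToolConfig(
        function_calling_config=types.FunctionCallingConfig(
            mode=tool_choice.upper(),
            allowed_function_names=(
                [tool["function"]["name"] for tool in tools] if tool_choice == "any" else None
            ),
        )
    )
```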
@@ -28,3 +28,29 @@ def test_embed_query(mock_genai, config):

    assert embedding == [0.1, 0.2, 0.3, 0.4]
    mock_genai.assert_called_once_with(model="test_model", content="Hello, world!", output_dimensionality=786)
+
+
+def test_embed_returns_empty_list_if_none(mock_genai, config):
+    mock_genai.return_value = None
+
+    embedder = GoogleGenAIEmbedding(config)
+    result = embedder.embed("test")
+
+    assert result == []
+    mock_genai.assert_called_once()
+
+
+def test_embed_raises_on_error(mock_genai, config):
+    mock_genai.side_effect = RuntimeError("Embedding failed")
+
+    embedder = GoogleGenAIEmbedding(config)
+
+    with pytest.raises(RuntimeError, match="Embedding failed"):
+        embedder.embed("some input")
+
+
+def test_config_initialization(config):
+    embedder = GoogleGenAIEmbedding(config)
+
+    assert embedder.config.api_key == "dummy_api_key"
+    assert embedder.config.model == "test_model"
+    assert embedder.config.embedding_dims == 786
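The `mock_genai` and `config` fixtures these tests use are not part of this diff; a plausible (hypothetical) sketch of what they could look like under the client-based API, with the patch target adjusted to wherever the suite actually patches:

```python
from unittest.mock import patch

import pytest

from mem0.configs.embeddings.base import BaseEmbedderConfig


@pytest.fixture
def config():
    # 786 rather than 768 is deliberate: it proves the configured value is
    # forwarded instead of a default being silently used.
    return BaseEmbedderConfig(api_key="dummy_api_key", model="test_model", embedding_dims=786)


@pytest.fixture
def mock_genai():
    # Hypothetical patch target: the embed_content method on the client the
    # embedder constructs in __init__.
    with patch("mem0.embeddings.gemini.genai.Client") as mock_client_cls:
        yield mock_client_cls.return_value.models.embed_content
```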
@@ -1,8 +1,7 @@
from unittest.mock import Mock, patch

import pytest
-from google.generativeai import GenerationConfig
-from google.generativeai.types import content_types
+from google.genai import types

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.gemini import GeminiLLM
@@ -10,14 +9,14 @@ from mem0.llms.gemini import GeminiLLM

@pytest.fixture
def mock_gemini_client():
-    with patch("mem0.llms.gemini.GenerativeModel") as mock_gemini:
+    with patch("mem0.llms.gemini.genai") as mock_client_class:
        mock_client = Mock()
-        mock_gemini.return_value = mock_client
+        mock_client_class.return_value = mock_client
        yield mock_client


def test_generate_response_without_tools(mock_gemini_client: Mock):
-    config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0)
+    config = BaseLlmConfig(model="gemini-2.0-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0)
    llm = GeminiLLM(config)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
@@ -25,6 +24,15 @@ def test_generate_response_without_tools(mock_gemini_client: Mock):
    ]

-    mock_part = Mock(text="I'm doing well, thank you for asking!")
-    mock_content = Mock(parts=[mock_part])
-    mock_message = Mock(content=mock_content)
-    mock_response = Mock(candidates=[mock_message])
+    mock_embedding = Mock()
+    mock_embedding.values = [0.1, 0.2, 0.3]
+
+    mock_response = Mock()
+    mock_response.candidates = [Mock()]
+    mock_response.candidates[0].content.parts = [Mock()]
+    mock_response.candidates[0].content.parts[0].text = "I'm doing well, thank you for asking!"
+
+    mock_gemini_client.models.generate_content.return_value = mock_response
@@ -37,15 +45,24 @@ def test_generate_response_without_tools(mock_gemini_client: Mock):
            {"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"},
            {"parts": "Hello, how are you?", "role": "user"},
        ],
-        generation_config=GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0),
-        tool_config=content_types.to_tool_config(
-            {"function_calling_config": {"mode": "auto", "allowed_function_names": None}}
-        ),
+        config=types.GenerateContentConfig(
+            temperature=0.7,
+            max_output_tokens=100,
+            top_p=1.0,
+            tools=None,
+            tool_config=types.ToolConfig(
+                function_calling_config=types.FunctionCallingConfig(
+                    allowed_function_names=None,
+                    mode="auto"
+                )
+            )
+        ),
    )

    assert response == "I'm doing well, thank you for asking!"


def test_generate_response_with_tools(mock_gemini_client: Mock):
    config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0)
    llm = GeminiLLM(config)
@@ -89,28 +106,46 @@ def test_generate_response_with_tools(mock_gemini_client: Mock):

    mock_gemini_client.generate_content.assert_called_once_with(
        contents=[
-            {"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"},
-            {"parts": "Add a new memory: Today is a sunny day.", "role": "user"},
+            {
+                "parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.",
+                "role": "user"
+            },
+            {
+                "parts": "Add a new memory: Today is a sunny day.",
+                "role": "user"
+            },
        ],
-        generation_config=GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0),
-        tools=[
-            {
-                "function_declarations": [
-                    {
-                        "name": "add_memory",
-                        "description": "Add a memory",
-                        "parameters": {
-                            "type": "object",
-                            "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
-                            "required": ["data"],
-                        },
-                    }
-                ]
-            }
-        ],
-        tool_config=content_types.to_tool_config(
-            {"function_calling_config": {"mode": "auto", "allowed_function_names": None}}
-        ),
+        config=types.GenerateContentConfig(
+            temperature=0.7,
+            max_output_tokens=100,
+            top_p=1.0,
+            tools=[
+                types.Tool(
+                    function_declarations=[
+                        types.FunctionDeclaration(
+                            name="add_memory",
+                            description="Add a memory",
+                            parameters={
+                                "type": "object",
+                                "properties": {
+                                    "data": {
+                                        "type": "string",
+                                        "description": "Data to add to memory"
+                                    }
+                                },
+                                "required": ["data"]
+                            }
+                        )
+                    ]
+                )
+            ],
+            tool_config=types.ToolConfig(
+                function_calling_config=types.FunctionCallingConfig(
+                    allowed_function_names=None,
+                    mode="auto"
+                )
+            )
+        )
    )

    assert response["content"] == "I've added the memory for you."
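Finally, the with-tools assertion implies a conversion from OpenAI-style tool dicts to `types.Tool`/`types.FunctionDeclaration`; a minimal sketch of that conversion (mem0's actual `_reformat_tools` may handle more cases):

```python
from typing import Dict, List, Optional

from google.genai import types


def to_gemini_tools(tools: Optional[List[Dict]]) -> Optional[List[types.Tool]]:
    # Convert {"type": "function", "function": {...}} dicts into the
    # declaration objects the assertion above expects.
    if not tools:
        return None
    return [
        types.Tool(
            function_declarations=[
                types.FunctionDeclaration(
                    name=tool["function"]["name"],
                    description=tool["function"].get("description", ""),
                    parameters=tool["function"]["parameters"],
                )
            ]
        )
        for tool in tools
    ]
```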