Fix: Migrate Gemini Embeddings (#3002)
Co-authored-by: Dev-Khant <devkhant24@gmail.com>
The change below migrates the Gemini embedder from the deprecated `google-generativeai` SDK to the new `google-genai` client, updating the docs, the embedder, the Gemini LLM wrapper, and their tests to match.

In the Gemini embedder docs, the `embedding_dims` description now notes that the value is forwarded to the API as `output_dimensionality`:

```diff
@@ -39,5 +39,5 @@ Here are the parameters available for configuring Gemini embedder:
 | Parameter | Description | Default Value |
 | --- | --- | --- |
 | `model` | The name of the embedding model to use | `models/text-embedding-004` |
-| `embedding_dims` | Dimensions of the embedding model | `768` |
+| `embedding_dims` | Dimensions of the embedding model (forwarded to the API as `output_dimensionality`, so set `embedding_dims` accordingly) | `768` |
 | `api_key` | The Gemini API key | `None` |
```
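For readers configuring this through mem0, here is a minimal sketch of how these parameters are typically wired up, assuming the standard `Memory.from_config` layout (the API key value is a placeholder):

```python
# A minimal configuration sketch, assuming mem0's standard config dict layout.
# The api_key value is a placeholder; per the table above, embedding_dims is
# forwarded to Gemini as output_dimensionality.
from mem0 import Memory

config = {
    "embedder": {
        "provider": "gemini",
        "config": {
            "model": "models/text-embedding-004",
            "embedding_dims": 768,
            "api_key": "your-gemini-api-key",
        },
    }
}

memory = Memory.from_config(config)
```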
`mem0/embeddings/gemini.py`: the embedder now imports the `google-genai` package:

```diff
@@ -1,7 +1,7 @@
 import os
 from typing import Literal, Optional
 
-import google.generativeai as genai
+import google.genai as genai
 
 from mem0.configs.embeddings.base import BaseEmbedderConfig
 from mem0.embeddings.base import EmbeddingBase
```
Module-level `genai.configure()` is replaced with an explicit `genai.Client`, and `embedding_dims` now falls back to `output_dimensionality` before defaulting to 768:

```diff
@@ -12,23 +12,28 @@ class GoogleGenAIEmbedding(EmbeddingBase):
         super().__init__(config)
 
         self.config.model = self.config.model or "models/text-embedding-004"
-        self.config.embedding_dims = self.config.embedding_dims or 768
+        self.config.embedding_dims = self.config.embedding_dims or self.config.output_dimensionality or 768
 
         api_key = self.config.api_key or os.getenv("GOOGLE_API_KEY")
 
-        genai.configure(api_key=api_key)
+        if api_key:
+            self.client = genai.Client(api_key=api_key)
+        else:
+            self.client = genai.Client()
 
     def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
         """
         Get the embedding for the given text using Google Generative AI.
 
         Args:
             text (str): The text to embed.
-            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
+            memory_action (optional): The type of embedding to use. Currently not used by Gemini as a task_type.
 
         Returns:
             list: The embedding vector.
         """
         text = text.replace("\n", " ")
-        response = genai.embed_content(
+        response = self.client.models.embed_content(
             model=self.config.model, content=text, output_dimensionality=self.config.embedding_dims
         )
 
         return response["embedding"]
```
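For context on the migration itself, here is a standalone sketch of the old and new SDK call styles. The `EmbedContentConfig` form is how the `google-genai` client documents `output_dimensionality`; treat this as an illustration of the SDK, not a claim about the exact call mem0 ends up making:

```python
# Sketch of the new google-genai call style for embeddings.
# Assumes `pip install google-genai` and a valid API key.
from google import genai
from google.genai import types

# Old SDK (google-generativeai) relied on module-level state:
#   genai.configure(api_key=...)
#   response = genai.embed_content(model=..., content=text)
#   vector = response["embedding"]

# New SDK (google-genai) uses an explicit client object:
client = genai.Client(api_key="your-gemini-api-key")
response = client.models.embed_content(
    model="models/text-embedding-004",
    contents="Hello, world!",
    config=types.EmbedContentConfig(output_dimensionality=768),
)
vector = response.embeddings[0].values  # a list of 768 floats
```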
`mem0/llms/gemini.py`: the import guard's error message now names the SDK that is actually imported:

```diff
@@ -4,7 +4,7 @@ from typing import Dict, List, Optional
 try:
     from google import genai
     from google.genai import types
 
 except ImportError:
     raise ImportError(
-        "The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'."
+        "The 'google-genai' library is required. Please install it using 'pip install google-genai'."
```
Tool-call collection in `_parse_response` is reformatted with no behavior change:

```diff
@@ -49,16 +49,17 @@ class GeminiLLM(LLMBase):
             for part in candidate.content.parts:
                 fn = getattr(part, "function_call", None)
                 if fn:
-                    processed_response["tool_calls"].append({
-                        "name": fn.name,
-                        "arguments": fn.args,
-                    })
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": fn.name,
+                            "arguments": fn.args,
+                        }
+                    )
 
             return processed_response
 
         return content
 
 
     def _reformat_messages(self, messages: List[Dict[str, str]]) -> List[types.Content]:
         """
         Reformat messages for Gemini using google.genai.types.
```
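The shape `_parse_response` builds is unchanged by this reformat; for reference, a parsed tool-call result looks like this (the keys come from the code above, the string values are illustrative):

```python
# Illustrative shape of the dict _parse_response returns when tools are passed.
processed_response = {
    "content": "I've added the memory for you.",
    "tool_calls": [
        {"name": "add_memory", "arguments": {"data": "Today is a sunny day."}},
    ],
}
```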
`_reformat_messages` collapses the `types.Content` construction onto one line:

```diff
@@ -78,15 +79,11 @@ class GeminiLLM(LLMBase):
             content = message["content"]
 
             new_messages.append(
-                types.Content(
-                    role="model" if message["role"] == "model" else "user",
-                    parts=[types.Part(text=content)]
-                )
+                types.Content(role="model" if message["role"] == "model" else "user", parts=[types.Part(text=content)])
             )
 
         return new_messages
 
 
     def _reformat_tools(self, tools: Optional[List[Dict]]):
         """
         Reformat tools for Gemini.
```
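Equivalently, the collapsed call maps plain role/content dicts onto `google.genai` content objects; a self-contained sketch of the same mapping:

```python
# Self-contained sketch of the mapping _reformat_messages performs.
# Any role other than "model" (including "user") is sent as "user".
from google.genai import types

messages = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "model", "content": "I'm doing well, thank you for asking!"},
]
contents = [
    types.Content(
        role="model" if m["role"] == "model" else "user",
        parts=[types.Part(text=m["content"])],
    )
    for m in messages
]
```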
A stray blank line between the `generate_response` signature and its docstring is removed:

```diff
@@ -131,7 +128,6 @@ class GeminiLLM(LLMBase):
         tools: Optional[List[Dict]] = None,
         tool_choice: str = "auto",
     ):
-
         """
         Generate a response based on the given messages using Gemini.
 
```
Leftover debug `print` statements are dropped and the `generate_content` call is tidied:

```diff
@@ -161,31 +157,22 @@ class GeminiLLM(LLMBase):
             tool_config = types.ToolConfig(
                 function_calling_config=types.FunctionCallingConfig(
                     mode=tool_choice.upper(),  # Assuming 'any' should become 'ANY', etc.
-                    allowed_function_names=[
-                        tool["function"]["name"] for tool in tools
-                    ] if tool_choice == "any" else None
+                    allowed_function_names=[tool["function"]["name"] for tool in tools]
+                    if tool_choice == "any"
+                    else None,
                 )
             )
 
-        print(f"Tool config: {tool_config}")
-        print(f"Params: {params}")
-        print(f"Messages: {messages}")
-        print(f"Tools: {tools}")
-        print(f"Reformatted messages: {self._reformat_messages(messages)}")
-        print(f"Reformatted tools: {self._reformat_tools(tools)}")
-
         response = self.client_gemini.models.generate_content(
             model=self.config.model,
             contents=self._reformat_messages(messages),
             config=types.GenerateContentConfig(
-                temperature= self.config.temperature,
-                max_output_tokens= self.config.max_tokens,
-                top_p= self.config.top_p,
+                temperature=self.config.temperature,
+                max_output_tokens=self.config.max_tokens,
+                top_p=self.config.top_p,
                 tools=self._reformat_tools(tools),
                 tool_config=tool_config,
             ),
         )
-        print(f"Response test: {response}")
 
         return self._parse_response(response, tools)
```
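End to end, the migrated LLM path is driven the same way as before; a usage sketch, assuming `GOOGLE_API_KEY` is set in the environment and using an illustrative model name:

```python
# Usage sketch for the migrated GeminiLLM; assumes GOOGLE_API_KEY is set.
from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.gemini import GeminiLLM

config = BaseLlmConfig(model="gemini-2.0-flash", temperature=0.7, max_tokens=100, top_p=1.0)
llm = GeminiLLM(config)

reply = llm.generate_response(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello, how are you?"},
    ]
)
print(reply)  # a plain string when no tools are passed
```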
Gemini embedder tests: new cases cover a `None` response, error propagation, and config wiring. The fixtures they use are sketched after this hunk.

```diff
@@ -28,3 +28,29 @@ def test_embed_query(mock_genai, config):
 
     assert embedding == [0.1, 0.2, 0.3, 0.4]
     mock_genai.assert_called_once_with(model="test_model", content="Hello, world!", output_dimensionality=786)
+
+
+def test_embed_returns_empty_list_if_none(mock_genai, config):
+    mock_genai.return_value = None
+
+    embedder = GoogleGenAIEmbedding(config)
+    result = embedder.embed("test")
+
+    assert result == []
+    mock_genai.assert_called_once()
+
+
+def test_embed_raises_on_error(mock_genai, config):
+    mock_genai.side_effect = RuntimeError("Embedding failed")
+
+    embedder = GoogleGenAIEmbedding(config)
+
+    with pytest.raises(RuntimeError, match="Embedding failed"):
+        embedder.embed("some input")
+
+def test_config_initialization(config):
+    embedder = GoogleGenAIEmbedding(config)
+
+    assert embedder.config.api_key == "dummy_api_key"
+    assert embedder.config.model == "test_model"
+    assert embedder.config.embedding_dims == 786
```
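These tests lean on `mock_genai` and `config` fixtures that live outside this hunk; a hypothetical reconstruction consistent with how they are used above:

```python
# Hypothetical fixtures consistent with the tests above; the real ones are
# defined elsewhere in the test module and may differ.
from unittest.mock import patch

import pytest

from mem0.configs.embeddings.base import BaseEmbedderConfig


@pytest.fixture
def config():
    return BaseEmbedderConfig(api_key="dummy_api_key", model="test_model", embedding_dims=786)


@pytest.fixture
def mock_genai():
    # Patch the client class so embed() hits a mock instead of the network,
    # and yield the embed_content mock the tests assert against.
    with patch("mem0.embeddings.gemini.genai.Client") as mock_client_cls:
        yield mock_client_cls.return_value.models.embed_content
```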
Gemini LLM tests: imports move to `google.genai.types`:

```diff
@@ -1,8 +1,7 @@
 from unittest.mock import Mock, patch
 
 import pytest
-from google.generativeai import GenerationConfig
-from google.generativeai.types import content_types
+from google.genai import types
 
 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.gemini import GeminiLLM
```
The fixture now patches the `genai` module rather than `GenerativeModel`, and the no-tools test moves to a newer model name:

```diff
@@ -10,14 +9,14 @@ from mem0.llms.gemini import GeminiLLM
 
 @pytest.fixture
 def mock_gemini_client():
-    with patch("mem0.llms.gemini.GenerativeModel") as mock_gemini:
+    with patch("mem0.llms.gemini.genai") as mock_client_class:
         mock_client = Mock()
-        mock_gemini.return_value = mock_client
+        mock_client_class.return_value = mock_client
         yield mock_client
 
 
 def test_generate_response_without_tools(mock_gemini_client: Mock):
-    config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0)
+    config = BaseLlmConfig(model="gemini-2.0-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0)
     llm = GeminiLLM(config)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
```
The response is now mocked through `models.generate_content` on the client:

```diff
@@ -25,6 +24,15 @@ def test_generate_response_without_tools(mock_gemini_client: Mock):
     ]
 
     mock_part = Mock(text="I'm doing well, thank you for asking!")
+    mock_embedding = Mock()
+    mock_embedding.values = [0.1, 0.2, 0.3]
+
+    mock_response = Mock()
+    mock_response.candidates = [Mock()]
+    mock_response.candidates[0].content.parts = [Mock()]
+    mock_response.candidates[0].content.parts[0].text = "I'm doing well, thank you for asking!"
+
+    mock_gemini_client.models.generate_content.return_value = mock_response
     mock_content = Mock(parts=[mock_part])
     mock_message = Mock(content=mock_content)
     mock_response = Mock(candidates=[mock_message])
```
The expected call now asserts the `types.GenerateContentConfig` shape instead of `GenerationConfig` plus `content_types.to_tool_config`:

```diff
@@ -37,15 +45,24 @@ def test_generate_response_without_tools(mock_gemini_client: Mock):
             {"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"},
             {"parts": "Hello, how are you?", "role": "user"},
         ],
-        generation_config=GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0),
-        tools=None,
-        tool_config=content_types.to_tool_config(
-            {"function_calling_config": {"mode": "auto", "allowed_function_names": None}}
-        ),
+        config=types.GenerateContentConfig(
+            temperature=0.7,
+            max_output_tokens=100,
+            top_p=1.0,
+            tools=None,
+            tool_config=types.ToolConfig(
+                function_calling_config=types.FunctionCallingConfig(
+                    allowed_function_names=None,
+                    mode="auto",
+                )
+            ),
+        ),
     )
 
     assert response == "I'm doing well, thank you for asking!"
 
 
 def test_generate_response_with_tools(mock_gemini_client: Mock):
     config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0)
     llm = GeminiLLM(config)
```
The same migration applies to the tools test, with the expected tools expressed as `types.Tool` / `types.FunctionDeclaration`:

```diff
@@ -89,28 +106,46 @@ def test_generate_response_with_tools(mock_gemini_client: Mock):
 
     mock_gemini_client.generate_content.assert_called_once_with(
         contents=[
-            {"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"},
-            {"parts": "Add a new memory: Today is a sunny day.", "role": "user"},
+            {
+                "parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.",
+                "role": "user",
+            },
+            {
+                "parts": "Add a new memory: Today is a sunny day.",
+                "role": "user",
+            },
         ],
-        generation_config=GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0),
-        tools=[
-            {
-                "function_declarations": [
-                    {
-                        "name": "add_memory",
-                        "description": "Add a memory",
-                        "parameters": {
-                            "type": "object",
-                            "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
-                            "required": ["data"],
-                        },
-                    }
-                ]
-            }
-        ],
-        tool_config=content_types.to_tool_config(
-            {"function_calling_config": {"mode": "auto", "allowed_function_names": None}}
-        ),
+        config=types.GenerateContentConfig(
+            temperature=0.7,
+            max_output_tokens=100,
+            top_p=1.0,
+            tools=[
+                types.Tool(
+                    function_declarations=[
+                        types.FunctionDeclaration(
+                            name="add_memory",
+                            description="Add a memory",
+                            parameters={
+                                "type": "object",
+                                "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
+                                "required": ["data"],
+                            },
+                        )
+                    ]
+                )
+            ],
+            tool_config=types.ToolConfig(
+                function_calling_config=types.FunctionCallingConfig(
+                    allowed_function_names=None,
+                    mode="auto",
+                )
+            ),
+        ),
     )
 
     assert response["content"] == "I've added the memory for you."
```