diff --git a/docs/mint.json b/docs/mint.json index 5e301898..1b982580 100644 --- a/docs/mint.json +++ b/docs/mint.json @@ -126,6 +126,7 @@ "components/embedders/models/azure_openai", "components/embedders/models/ollama", "components/embedders/models/huggingface", + "components/embedders/models/vertexai", "components/embedders/models/gemini" ] } @@ -231,4 +232,4 @@ "apiHost": "https://us.i.posthog.com" } } -} \ No newline at end of file +} diff --git a/mem0/configs/embeddings/base.py b/mem0/configs/embeddings/base.py index 63245872..a3d989ee 100644 --- a/mem0/configs/embeddings/base.py +++ b/mem0/configs/embeddings/base.py @@ -25,6 +25,8 @@ class BaseEmbedderConfig(ABC): # AzureOpenAI specific azure_kwargs: Optional[AzureConfig] = {}, http_client_proxies: Optional[Union[Dict, str]] = None, + # VertexAI specific + vertex_credentials_json: Optional[str] = None, ): """ Initializes a configuration class instance for the Embeddings. @@ -63,3 +65,6 @@ class BaseEmbedderConfig(ABC): # AzureOpenAI specific self.azure_kwargs = AzureConfig(**azure_kwargs) or {} + + # VertexAI specific + self.vertex_credentials_json = vertex_credentials_json diff --git a/mem0/embeddings/vertexai.py b/mem0/embeddings/vertexai.py index bcdaaab2..740e6303 100644 --- a/mem0/embeddings/vertexai.py +++ b/mem0/embeddings/vertexai.py @@ -7,7 +7,7 @@ from mem0.configs.embeddings.base import BaseEmbedderConfig from mem0.embeddings.base import EmbeddingBase -class VertexAI(EmbeddingBase): +class VertexAIEmbedding(EmbeddingBase): def __init__(self, config: Optional[BaseEmbedderConfig] = None): super().__init__(config) diff --git a/mem0/utils/factory.py b/mem0/utils/factory.py index 034fc6fc..43f853c6 100644 --- a/mem0/utils/factory.py +++ b/mem0/utils/factory.py @@ -42,6 +42,7 @@ class EmbedderFactory: "huggingface": "mem0.embeddings.huggingface.HuggingFaceEmbedding", "azure_openai": "mem0.embeddings.azure_openai.AzureOpenAIEmbedding", "gemini": "mem0.embeddings.gemini.GoogleGenAIEmbedding", + "vertexai": "mem0.embeddings.vertexai.VertexAIEmbedding", } @classmethod diff --git a/tests/embeddings/test_vertexai_embeddings.py b/tests/embeddings/test_vertexai_embeddings.py index 4b85e077..26ac6341 100644 --- a/tests/embeddings/test_vertexai_embeddings.py +++ b/tests/embeddings/test_vertexai_embeddings.py @@ -1,6 +1,6 @@ import pytest from unittest.mock import Mock, patch -from mem0.embeddings.vertexai import VertexAI +from mem0.embeddings.vertexai import VertexAIEmbedding from mem0.configs.embeddings.base import BaseEmbedderConfig @@ -32,7 +32,7 @@ def test_embed_default_model(mock_text_embedding_model, mock_os_environ, mock_co mock_config.return_value.embedding_dims = 256 config = mock_config() - embedder = VertexAI(config) + embedder = VertexAIEmbedding(config) mock_embedding = Mock(values=[0.1, 0.2, 0.3]) mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [ @@ -57,7 +57,7 @@ def test_embed_custom_model(mock_text_embedding_model, mock_os_environ, mock_con config = mock_config() - embedder = VertexAI(config) + embedder = VertexAIEmbedding(config) mock_embedding = Mock(values=[0.4, 0.5, 0.6]) mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [ @@ -81,7 +81,7 @@ def test_credentials_from_environment(mock_os, mock_text_embedding_model, mock_c mock_os.getenv.return_value = "/path/to/env/credentials.json" mock_config.vertex_credentials_json = None config = mock_config() - VertexAI(config) + VertexAIEmbedding(config) mock_os.environ.setitem.assert_not_called() @@ -96,7 +96,7 @@ def test_missing_credentials(mock_os, mock_text_embedding_model, mock_config): with pytest.raises( ValueError, match="Google application credentials JSON is not provided" ): - VertexAI(config) + VertexAIEmbedding(config) @patch("mem0.embeddings.vertexai.TextEmbeddingModel") @@ -107,7 +107,7 @@ def test_embed_with_different_dimensions( mock_config.return_value.embedding_dims = 1024 config = mock_config() - embedder = VertexAI(config) + embedder = VertexAIEmbedding(config) mock_embedding = Mock(values=[0.1] * 1024) mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [