(bug-fix) : fix VertexAI missing configurations (#1926)
This commit is contained in:
@@ -126,6 +126,7 @@
|
|||||||
"components/embedders/models/azure_openai",
|
"components/embedders/models/azure_openai",
|
||||||
"components/embedders/models/ollama",
|
"components/embedders/models/ollama",
|
||||||
"components/embedders/models/huggingface",
|
"components/embedders/models/huggingface",
|
||||||
|
"components/embedders/models/vertexai",
|
||||||
"components/embedders/models/gemini"
|
"components/embedders/models/gemini"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@@ -231,4 +232,4 @@
|
|||||||
"apiHost": "https://us.i.posthog.com"
|
"apiHost": "https://us.i.posthog.com"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ class BaseEmbedderConfig(ABC):
|
|||||||
# AzureOpenAI specific
|
# AzureOpenAI specific
|
||||||
azure_kwargs: Optional[AzureConfig] = {},
|
azure_kwargs: Optional[AzureConfig] = {},
|
||||||
http_client_proxies: Optional[Union[Dict, str]] = None,
|
http_client_proxies: Optional[Union[Dict, str]] = None,
|
||||||
|
# VertexAI specific
|
||||||
|
vertex_credentials_json: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initializes a configuration class instance for the Embeddings.
|
Initializes a configuration class instance for the Embeddings.
|
||||||
@@ -63,3 +65,6 @@ class BaseEmbedderConfig(ABC):
|
|||||||
|
|
||||||
# AzureOpenAI specific
|
# AzureOpenAI specific
|
||||||
self.azure_kwargs = AzureConfig(**azure_kwargs) or {}
|
self.azure_kwargs = AzureConfig(**azure_kwargs) or {}
|
||||||
|
|
||||||
|
# VertexAI specific
|
||||||
|
self.vertex_credentials_json = vertex_credentials_json
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from mem0.configs.embeddings.base import BaseEmbedderConfig
|
|||||||
from mem0.embeddings.base import EmbeddingBase
|
from mem0.embeddings.base import EmbeddingBase
|
||||||
|
|
||||||
|
|
||||||
class VertexAI(EmbeddingBase):
|
class VertexAIEmbedding(EmbeddingBase):
|
||||||
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
|
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ class EmbedderFactory:
|
|||||||
"huggingface": "mem0.embeddings.huggingface.HuggingFaceEmbedding",
|
"huggingface": "mem0.embeddings.huggingface.HuggingFaceEmbedding",
|
||||||
"azure_openai": "mem0.embeddings.azure_openai.AzureOpenAIEmbedding",
|
"azure_openai": "mem0.embeddings.azure_openai.AzureOpenAIEmbedding",
|
||||||
"gemini": "mem0.embeddings.gemini.GoogleGenAIEmbedding",
|
"gemini": "mem0.embeddings.gemini.GoogleGenAIEmbedding",
|
||||||
|
"vertexai": "mem0.embeddings.vertexai.VertexAIEmbedding",
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
from mem0.embeddings.vertexai import VertexAI
|
from mem0.embeddings.vertexai import VertexAIEmbedding
|
||||||
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
from mem0.configs.embeddings.base import BaseEmbedderConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -32,7 +32,7 @@ def test_embed_default_model(mock_text_embedding_model, mock_os_environ, mock_co
|
|||||||
mock_config.return_value.embedding_dims = 256
|
mock_config.return_value.embedding_dims = 256
|
||||||
|
|
||||||
config = mock_config()
|
config = mock_config()
|
||||||
embedder = VertexAI(config)
|
embedder = VertexAIEmbedding(config)
|
||||||
|
|
||||||
mock_embedding = Mock(values=[0.1, 0.2, 0.3])
|
mock_embedding = Mock(values=[0.1, 0.2, 0.3])
|
||||||
mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [
|
mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [
|
||||||
@@ -57,7 +57,7 @@ def test_embed_custom_model(mock_text_embedding_model, mock_os_environ, mock_con
|
|||||||
|
|
||||||
config = mock_config()
|
config = mock_config()
|
||||||
|
|
||||||
embedder = VertexAI(config)
|
embedder = VertexAIEmbedding(config)
|
||||||
|
|
||||||
mock_embedding = Mock(values=[0.4, 0.5, 0.6])
|
mock_embedding = Mock(values=[0.4, 0.5, 0.6])
|
||||||
mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [
|
mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [
|
||||||
@@ -81,7 +81,7 @@ def test_credentials_from_environment(mock_os, mock_text_embedding_model, mock_c
|
|||||||
mock_os.getenv.return_value = "/path/to/env/credentials.json"
|
mock_os.getenv.return_value = "/path/to/env/credentials.json"
|
||||||
mock_config.vertex_credentials_json = None
|
mock_config.vertex_credentials_json = None
|
||||||
config = mock_config()
|
config = mock_config()
|
||||||
VertexAI(config)
|
VertexAIEmbedding(config)
|
||||||
|
|
||||||
mock_os.environ.setitem.assert_not_called()
|
mock_os.environ.setitem.assert_not_called()
|
||||||
|
|
||||||
@@ -96,7 +96,7 @@ def test_missing_credentials(mock_os, mock_text_embedding_model, mock_config):
|
|||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError, match="Google application credentials JSON is not provided"
|
ValueError, match="Google application credentials JSON is not provided"
|
||||||
):
|
):
|
||||||
VertexAI(config)
|
VertexAIEmbedding(config)
|
||||||
|
|
||||||
|
|
||||||
@patch("mem0.embeddings.vertexai.TextEmbeddingModel")
|
@patch("mem0.embeddings.vertexai.TextEmbeddingModel")
|
||||||
@@ -107,7 +107,7 @@ def test_embed_with_different_dimensions(
|
|||||||
mock_config.return_value.embedding_dims = 1024
|
mock_config.return_value.embedding_dims = 1024
|
||||||
|
|
||||||
config = mock_config()
|
config = mock_config()
|
||||||
embedder = VertexAI(config)
|
embedder = VertexAIEmbedding(config)
|
||||||
|
|
||||||
mock_embedding = Mock(values=[0.1] * 1024)
|
mock_embedding = Mock(values=[0.1] * 1024)
|
||||||
mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [
|
mock_text_embedding_model.from_pretrained.return_value.get_embeddings.return_value = [
|
||||||
|
|||||||
Reference in New Issue
Block a user