From 7c24601d0f2c3e6eb5d63bb9ed84732a3b6fc126 Mon Sep 17 00:00:00 2001 From: Pranav Puranik <54378813+PranavPuranik@users.noreply.github.com> Date: Sat, 29 Jun 2024 14:37:31 -0500 Subject: [PATCH] Adding model_kwargs for huggingface embedders. (#1450) --- docs/api-reference/advanced/configuration.mdx | 1 + docs/components/embedding-models.mdx | 2 ++ embedchain/config/embedder/base.py | 6 +++++- embedchain/embedder/huggingface.py | 3 ++- embedchain/utils/misc.py | 1 + tests/embedder/test_huggingface_embedder.py | 18 ++++++++++++++++++ 6 files changed, 29 insertions(+), 2 deletions(-) create mode 100644 tests/embedder/test_huggingface_embedder.py diff --git a/docs/api-reference/advanced/configuration.mdx b/docs/api-reference/advanced/configuration.mdx index d9938d3e..097d5121 100644 --- a/docs/api-reference/advanced/configuration.mdx +++ b/docs/api-reference/advanced/configuration.mdx @@ -229,6 +229,7 @@ Alright, let's dive into what each key means in the yaml config above: - `deployment_name` (String): The deployment name for the embedding model. - `title` (String): The title for the embedding model for Google Embedder. - `task_type` (String): The task type for the embedding model for Google Embedder. + - `model_kwargs` (Dict): Used to pass extra arguments to embedders. 5. `chunker` Section: - `chunk_size` (Integer): The size of each chunk of text that is sent to the language model. - `chunk_overlap` (Integer): The amount of overlap between each chunk of text. 
diff --git a/docs/components/embedding-models.mdx b/docs/components/embedding-models.mdx index 96ee815a..5de69c7b 100644 --- a/docs/components/embedding-models.mdx +++ b/docs/components/embedding-models.mdx @@ -192,6 +192,8 @@ embedder: provider: huggingface config: model: 'sentence-transformers/all-mpnet-base-v2' + model_kwargs: + trust_remote_code: True # Only use if you trust your embedder ``` diff --git a/embedchain/config/embedder/base.py b/embedchain/config/embedder/base.py index 073be151..9365dec1 100644 --- a/embedchain/config/embedder/base.py +++ b/embedchain/config/embedder/base.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Dict, Optional from embedchain.helpers.json_serializable import register_deserializable @@ -13,6 +13,7 @@ class BaseEmbedderConfig: endpoint: Optional[str] = None, api_key: Optional[str] = None, api_base: Optional[str] = None, + model_kwargs: Optional[Dict[str, Any]] = None, ): """ Initialize a new instance of an embedder config class. @@ -29,6 +30,8 @@ class BaseEmbedderConfig: :type api_key: Optional[str], optional :param api_base: huggingface api base, defaults to None :type api_base: Optional[str], optional + :param model_kwargs: key-value arguments for the embedding model, defaults to None (stored as an empty dict). + :type model_kwargs: Optional[Dict[str, Any]], optional 
""" self.model = model self.deployment_name = deployment_name @@ -36,3 +39,4 @@ class BaseEmbedderConfig: self.endpoint = endpoint self.api_key = api_key self.api_base = api_base + self.model_kwargs = model_kwargs or {} diff --git a/embedchain/embedder/huggingface.py b/embedchain/embedder/huggingface.py index cd11e5dd..062208e7 100644 --- a/embedchain/embedder/huggingface.py +++ b/embedchain/embedder/huggingface.py @@ -31,7 +31,8 @@ class HuggingFaceEmbedder(BaseEmbedder): huggingfacehub_api_token=self.config.api_key or os.getenv("HUGGINGFACE_ACCESS_TOKEN"), ) else: - embeddings = HuggingFaceEmbeddings(model_name=self.config.model) + embeddings = HuggingFaceEmbeddings(model_name=self.config.model, model_kwargs=self.config.model_kwargs) + embedding_fn = BaseEmbedder._langchain_default_concept(embeddings) self.set_embedding_fn(embedding_fn=embedding_fn) diff --git a/embedchain/utils/misc.py b/embedchain/utils/misc.py index 8213a601..2de1f088 100644 --- a/embedchain/utils/misc.py +++ b/embedchain/utils/misc.py @@ -474,6 +474,7 @@ def validate_config(config_data): Optional("vector_dimension"): int, Optional("base_url"): str, Optional("endpoint"): str, + Optional("model_kwargs"): dict, }, }, Optional("embedding_model"): { diff --git a/tests/embedder/test_huggingface_embedder.py b/tests/embedder/test_huggingface_embedder.py new file mode 100644 index 00000000..4760b926 --- /dev/null +++ b/tests/embedder/test_huggingface_embedder.py @@ -0,0 +1,18 @@ + +from unittest.mock import patch +from embedchain.config import BaseEmbedderConfig +from embedchain.embedder.huggingface import HuggingFaceEmbedder + + +def test_huggingface_embedder_with_model(monkeypatch): + config = BaseEmbedderConfig(model="test-model", model_kwargs={"param": "value"}) + with patch('embedchain.embedder.huggingface.HuggingFaceEmbeddings') as mock_embeddings: + embedder = HuggingFaceEmbedder(config=config) + assert embedder.config.model == "test-model" + assert embedder.config.model_kwargs == {"param": 
"value"} + mock_embeddings.assert_called_once_with( + model_name="test-model", + model_kwargs={"param": "value"} + ) + +