Adding model_kwargs for huggingface embedders. (#1450)
This commit is contained in:
@@ -229,6 +229,7 @@ Alright, let's dive into what each key means in the yaml config above:
|
||||
- `deployment_name` (String): The deployment name for the embedding model.
|
||||
- `title` (String): The title for the embedding model for Google Embedder.
|
||||
- `task_type` (String): The task type for the embedding model for Google Embedder.
|
||||
- `model_kwargs` (Dict): Extra keyword arguments passed through to the embedding model (for example, `trust_remote_code` for Hugging Face embedders).
|
||||
5. `chunker` Section:
|
||||
- `chunk_size` (Integer): The size of each chunk of text that is sent to the language model.
|
||||
- `chunk_overlap` (Integer): The amount of overlap between each chunk of text.
|
||||
|
||||
@@ -192,6 +192,8 @@ embedder:
|
||||
provider: huggingface
|
||||
config:
|
||||
model: 'sentence-transformers/all-mpnet-base-v2'
|
||||
model_kwargs:
|
||||
trust_remote_code: True # Only enable this if you trust the code shipped with the embedding model
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
@@ -13,6 +13,7 @@ class BaseEmbedderConfig:
|
||||
endpoint: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize a new instance of an embedder config class.
|
||||
@@ -29,6 +30,8 @@ class BaseEmbedderConfig:
|
||||
:type api_key: Optional[str], optional
|
||||
:param api_base: huggingface api base, defaults to None
|
||||
:type api_base: Optional[str], optional
|
||||
:param model_kwargs: key-value arguments for the embedding model, defaults to None (stored as an empty dict inside init)
|
||||
:type model_kwargs: Optional[Dict[str, Any]], optional
|
||||
"""
|
||||
self.model = model
|
||||
self.deployment_name = deployment_name
|
||||
@@ -36,3 +39,4 @@ class BaseEmbedderConfig:
|
||||
self.endpoint = endpoint
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.model_kwargs = model_kwargs or {}
|
||||
|
||||
@@ -31,7 +31,8 @@ class HuggingFaceEmbedder(BaseEmbedder):
|
||||
huggingfacehub_api_token=self.config.api_key or os.getenv("HUGGINGFACE_ACCESS_TOKEN"),
|
||||
)
|
||||
else:
|
||||
embeddings = HuggingFaceEmbeddings(model_name=self.config.model)
|
||||
embeddings = HuggingFaceEmbeddings(model_name=self.config.model, model_kwargs=self.config.model_kwargs)
|
||||
|
||||
embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
|
||||
self.set_embedding_fn(embedding_fn=embedding_fn)
|
||||
|
||||
|
||||
@@ -474,6 +474,7 @@ def validate_config(config_data):
|
||||
Optional("vector_dimension"): int,
|
||||
Optional("base_url"): str,
|
||||
Optional("endpoint"): str,
|
||||
Optional("model_kwargs"): dict,
|
||||
},
|
||||
},
|
||||
Optional("embedding_model"): {
|
||||
|
||||
18
tests/embedder/test_huggingface_embedder.py
Normal file
18
tests/embedder/test_huggingface_embedder.py
Normal file
@@ -0,0 +1,18 @@
|
||||
|
||||
from unittest.mock import patch
|
||||
from embedchain.config import BaseEmbedderConfig
|
||||
from embedchain.embedder.huggingface import HuggingFaceEmbedder
|
||||
|
||||
|
||||
def test_huggingface_embedder_with_model(monkeypatch):
|
||||
config = BaseEmbedderConfig(model="test-model", model_kwargs={"param": "value"})
|
||||
with patch('embedchain.embedder.huggingface.HuggingFaceEmbeddings') as mock_embeddings:
|
||||
embedder = HuggingFaceEmbedder(config=config)
|
||||
assert embedder.config.model == "test-model"
|
||||
assert embedder.config.model_kwargs == {"param": "value"}
|
||||
mock_embeddings.assert_called_once_with(
|
||||
model_name="test-model",
|
||||
model_kwargs={"param": "value"}
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user