Adding model_kwargs for huggingface embedders. (#1450)
This commit is contained in:
@@ -229,6 +229,7 @@ Alright, let's dive into what each key means in the yaml config above:
|
|||||||
- `deployment_name` (String): The deployment name for the embedding model.
|
- `deployment_name` (String): The deployment name for the embedding model.
|
||||||
- `title` (String): The title for the embedding model for Google Embedder.
|
- `title` (String): The title for the embedding model for Google Embedder.
|
||||||
- `task_type` (String): The task type for the embedding model for Google Embedder.
|
- `task_type` (String): The task type for the embedding model for Google Embedder.
|
||||||
|
- `model_kwargs` (Dict): Used to pass extra arguments to embedders.
|
||||||
5. `chunker` Section:
|
5. `chunker` Section:
|
||||||
- `chunk_size` (Integer): The size of each chunk of text that is sent to the language model.
|
- `chunk_size` (Integer): The size of each chunk of text that is sent to the language model.
|
||||||
- `chunk_overlap` (Integer): The amount of overlap between each chunk of text.
|
- `chunk_overlap` (Integer): The amount of overlap between each chunk of text.
|
||||||
|
|||||||
@@ -192,6 +192,8 @@ embedder:
|
|||||||
provider: huggingface
|
provider: huggingface
|
||||||
config:
|
config:
|
||||||
model: 'sentence-transformers/all-mpnet-base-v2'
|
model: 'sentence-transformers/all-mpnet-base-v2'
|
||||||
|
model_kwargs:
|
||||||
|
trust_remote_code: True # Only use if you trust your embedder
|
||||||
```
|
```
|
||||||
|
|
||||||
</CodeGroup>
|
</CodeGroup>
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from embedchain.helpers.json_serializable import register_deserializable
|
from embedchain.helpers.json_serializable import register_deserializable
|
||||||
|
|
||||||
@@ -13,6 +13,7 @@ class BaseEmbedderConfig:
|
|||||||
endpoint: Optional[str] = None,
|
endpoint: Optional[str] = None,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
|
model_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize a new instance of an embedder config class.
|
Initialize a new instance of an embedder config class.
|
||||||
@@ -29,6 +30,8 @@ class BaseEmbedderConfig:
|
|||||||
:type api_key: Optional[str], optional
|
:type api_key: Optional[str], optional
|
||||||
:param api_base: huggingface api base, defaults to None
|
:param api_base: huggingface api base, defaults to None
|
||||||
:type api_base: Optional[str], optional
|
:type api_base: Optional[str], optional
|
||||||
|
:param model_kwargs: key-value arguments for the embedding model, defaults a dict inside init.
|
||||||
|
:type model_kwargs: Optional[Dict[str, Any]], defaults a dict inside init.
|
||||||
"""
|
"""
|
||||||
self.model = model
|
self.model = model
|
||||||
self.deployment_name = deployment_name
|
self.deployment_name = deployment_name
|
||||||
@@ -36,3 +39,4 @@ class BaseEmbedderConfig:
|
|||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.api_base = api_base
|
self.api_base = api_base
|
||||||
|
self.model_kwargs = model_kwargs or {}
|
||||||
|
|||||||
@@ -31,7 +31,8 @@ class HuggingFaceEmbedder(BaseEmbedder):
|
|||||||
huggingfacehub_api_token=self.config.api_key or os.getenv("HUGGINGFACE_ACCESS_TOKEN"),
|
huggingfacehub_api_token=self.config.api_key or os.getenv("HUGGINGFACE_ACCESS_TOKEN"),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
embeddings = HuggingFaceEmbeddings(model_name=self.config.model)
|
embeddings = HuggingFaceEmbeddings(model_name=self.config.model, model_kwargs=self.config.model_kwargs)
|
||||||
|
|
||||||
embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
|
embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
|
||||||
self.set_embedding_fn(embedding_fn=embedding_fn)
|
self.set_embedding_fn(embedding_fn=embedding_fn)
|
||||||
|
|
||||||
|
|||||||
@@ -474,6 +474,7 @@ def validate_config(config_data):
|
|||||||
Optional("vector_dimension"): int,
|
Optional("vector_dimension"): int,
|
||||||
Optional("base_url"): str,
|
Optional("base_url"): str,
|
||||||
Optional("endpoint"): str,
|
Optional("endpoint"): str,
|
||||||
|
Optional("model_kwargs"): dict,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Optional("embedding_model"): {
|
Optional("embedding_model"): {
|
||||||
|
|||||||
18
tests/embedder/test_huggingface_embedder.py
Normal file
18
tests/embedder/test_huggingface_embedder.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
from embedchain.config import BaseEmbedderConfig
|
||||||
|
from embedchain.embedder.huggingface import HuggingFaceEmbedder
|
||||||
|
|
||||||
|
|
||||||
|
def test_huggingface_embedder_with_model(monkeypatch):
|
||||||
|
config = BaseEmbedderConfig(model="test-model", model_kwargs={"param": "value"})
|
||||||
|
with patch('embedchain.embedder.huggingface.HuggingFaceEmbeddings') as mock_embeddings:
|
||||||
|
embedder = HuggingFaceEmbedder(config=config)
|
||||||
|
assert embedder.config.model == "test-model"
|
||||||
|
assert embedder.config.model_kwargs == {"param": "value"}
|
||||||
|
mock_embeddings.assert_called_once_with(
|
||||||
|
model_name="test-model",
|
||||||
|
model_kwargs={"param": "value"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
Reference in New Issue
Block a user