Add HF endpoint in embedder (#1436)

This commit is contained in:
Dev Khant
2024-06-21 21:27:21 +05:30
committed by GitHub
parent b43a116b3c
commit f6ddd5ffc5
4 changed files with 28 additions and 2 deletions

View File

@@ -224,6 +224,7 @@ Alright, let's dive into what each key means in the yaml config above:
- `model` (String): The specific model used for text embedding, 'text-embedding-ada-002'.
- `vector_dimension` (Integer): The vector dimension of the embedding model. [Defaults](https://github.com/embedchain/embedchain/blob/main/embedchain/models/vector_dimensions.py)
- `api_key` (String): The API key for the embedding model.
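- `endpoint` (String): The URL of the HuggingFace inference endpoint for the embedding model. When set, embeddings are requested from this endpoint through the HuggingFace Inference API, which requires either `api_key` in the config or the `HUGGINGFACE_ACCESS_TOKEN` environment variable.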
- `deployment_name` (String): The deployment name for the embedding model.
- `title` (String): The title for the embedding model for Google Embedder.
- `task_type` (String): The task type for the embedding model for Google Embedder.

View File

@@ -10,6 +10,7 @@ class BaseEmbedderConfig:
model: Optional[str] = None,
deployment_name: Optional[str] = None,
vector_dimension: Optional[int] = None,
endpoint: Optional[str] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
):
@@ -20,9 +21,18 @@ class BaseEmbedderConfig:
:type model: Optional[str], optional
:param deployment_name: deployment name for llm embedding model, defaults to None
:type deployment_name: Optional[str], optional
:param vector_dimension: vector dimension of the embedding model, defaults to None
:type vector_dimension: Optional[int], optional
:param endpoint: endpoint for the embedding model, defaults to None
:type endpoint: Optional[str], optional
:param api_key: huggingface api key, defaults to None
:type api_key: Optional[str], optional
:param api_base: huggingface api base, defaults to None
:type api_base: Optional[str], optional
"""
self.model = model
self.deployment_name = deployment_name
self.vector_dimension = vector_dimension
self.endpoint = endpoint
self.api_key = api_key
self.api_base = api_base

View File

@@ -1,6 +1,8 @@
import os
from typing import Optional
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEmbeddings
from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
@@ -11,6 +13,18 @@ class HuggingFaceEmbedder(BaseEmbedder):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config=config)
if self.config.endpoint:
if not self.config.api_key and "HUGGINGFACE_ACCESS_TOKEN" not in os.environ:
raise ValueError(
"Please set the HUGGINGFACE_ACCESS_TOKEN environment variable or pass API Key in the config."
)
embeddings = HuggingFaceInferenceAPIEmbeddings(
model_name=self.config.model,
api_url=self.config.endpoint,
api_key=self.config.api_key or os.getenv("HUGGINGFACE_ACCESS_TOKEN"),
)
else:
embeddings = HuggingFaceEmbeddings(model_name=self.config.model)
embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
self.set_embedding_fn(embedding_fn=embedding_fn)

View File

@@ -441,7 +441,7 @@ def validate_config(config_data):
Optional("local"): bool,
Optional("base_url"): str,
Optional("default_headers"): dict,
Optional("api_version"): Or(str, datetime.date)
Optional("api_version"): Or(str, datetime.date),
},
},
Optional("vectordb"): {
@@ -473,6 +473,7 @@ def validate_config(config_data):
Optional("task_type"): str,
Optional("vector_dimension"): int,
Optional("base_url"): str,
Optional("endpoint"): str,
},
},
Optional("embedding_model"): {