From f6ddd5ffc509d0a7d2f8db1fc321bb38a25e6635 Mon Sep 17 00:00:00 2001 From: Dev Khant Date: Fri, 21 Jun 2024 21:27:21 +0530 Subject: [PATCH] Add HF endpoint in embedder (#1436) --- docs/api-reference/advanced/configuration.mdx | 1 + embedchain/config/embedder/base.py | 10 ++++++++++ embedchain/embedder/huggingface.py | 16 +++++++++++++++- embedchain/utils/misc.py | 3 ++- 4 files changed, 28 insertions(+), 2 deletions(-) diff --git a/docs/api-reference/advanced/configuration.mdx b/docs/api-reference/advanced/configuration.mdx index 9d24f2a2..da90a5da 100644 --- a/docs/api-reference/advanced/configuration.mdx +++ b/docs/api-reference/advanced/configuration.mdx @@ -224,6 +224,7 @@ Alright, let's dive into what each key means in the yaml config above: - `model` (String): The specific model used for text embedding, 'text-embedding-ada-002'. - `vector_dimension` (Integer): The vector dimension of the embedding model. [Defaults](https://github.com/embedchain/embedchain/blob/main/embedchain/models/vector_dimensions.py) - `api_key` (String): The API key for the embedding model. + - `endpoint` (String): The endpoint for the HuggingFace embedding model. - `deployment_name` (String): The deployment name for the embedding model. - `title` (String): The title for the embedding model for Google Embedder. - `task_type` (String): The task type for the embedding model for Google Embedder. diff --git a/embedchain/config/embedder/base.py b/embedchain/config/embedder/base.py index f4229183..073be151 100644 --- a/embedchain/config/embedder/base.py +++ b/embedchain/config/embedder/base.py @@ -10,6 +10,7 @@ class BaseEmbedderConfig: model: Optional[str] = None, deployment_name: Optional[str] = None, vector_dimension: Optional[int] = None, + endpoint: Optional[str] = None, api_key: Optional[str] = None, api_base: Optional[str] = None, ): @@ -20,9 +21,18 @@ class BaseEmbedderConfig: :type model: Optional[str], optional :param deployment_name: deployment name for llm embedding model, defaults to None :type deployment_name: Optional[str], optional + :param vector_dimension: vector dimension of the embedding model, defaults to None + :type vector_dimension: Optional[int], optional + :param endpoint: endpoint for the embedding model, defaults to None + :type endpoint: Optional[str], optional + :param api_key: hugginface api key, defaults to None + :type api_key: Optional[str], optional + :param api_base: huggingface api base, defaults to None + :type api_base: Optional[str], optional """ self.model = model self.deployment_name = deployment_name self.vector_dimension = vector_dimension + self.endpoint = endpoint self.api_key = api_key self.api_base = api_base diff --git a/embedchain/embedder/huggingface.py b/embedchain/embedder/huggingface.py index 88bc2890..51c3322b 100644 --- a/embedchain/embedder/huggingface.py +++ b/embedchain/embedder/huggingface.py @@ -1,6 +1,8 @@ +import os from typing import Optional from langchain_community.embeddings import HuggingFaceEmbeddings +from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEmbeddings from embedchain.config import BaseEmbedderConfig from embedchain.embedder.base import BaseEmbedder @@ -11,7 +13,19 @@ class HuggingFaceEmbedder(BaseEmbedder): def __init__(self, config: Optional[BaseEmbedderConfig] = None): super().__init__(config=config) - embeddings = HuggingFaceEmbeddings(model_name=self.config.model) + if self.config.endpoint: + if not self.config.api_key and "HUGGINGFACE_ACCESS_TOKEN" not in os.environ: + raise ValueError( + "Please set the HUGGINGFACE_ACCESS_TOKEN environment variable or pass API Key in the config." + ) + + embeddings = HuggingFaceInferenceAPIEmbeddings( + model_name=self.config.model, + api_url=self.config.endpoint, + api_key=self.config.api_key or os.getenv("HUGGINGFACE_ACCESS_TOKEN"), + ) + else: + embeddings = HuggingFaceEmbeddings(model_name=self.config.model) embedding_fn = BaseEmbedder._langchain_default_concept(embeddings) self.set_embedding_fn(embedding_fn=embedding_fn) diff --git a/embedchain/utils/misc.py b/embedchain/utils/misc.py index ba1494bc..1e966328 100644 --- a/embedchain/utils/misc.py +++ b/embedchain/utils/misc.py @@ -441,7 +441,7 @@ def validate_config(config_data): Optional("local"): bool, Optional("base_url"): str, Optional("default_headers"): dict, - Optional("api_version"): Or(str, datetime.date) + Optional("api_version"): Or(str, datetime.date), }, }, Optional("vectordb"): { @@ -473,6 +473,7 @@ def validate_config(config_data): Optional("task_type"): str, Optional("vector_dimension"): int, Optional("base_url"): str, + Optional("endpoint"): str, }, }, Optional("embedding_model"): {