diff --git a/embedchain/config/llm/base.py b/embedchain/config/llm/base.py
index 1efe8acb..26b094c9 100644
--- a/embedchain/config/llm/base.py
+++ b/embedchain/config/llm/base.py
@@ -1,7 +1,7 @@
 import logging
 import re
 from string import Template
-from typing import Any, Optional
+from typing import Any, Mapping, Optional
 
 from embedchain.config.base_config import BaseConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -99,6 +99,7 @@ class BaseLlmConfig(BaseConfig):
         endpoint: Optional[str] = None,
         model_kwargs: Optional[dict[str, Any]] = None,
         local: Optional[bool] = False,
+        default_headers: Optional[Mapping[str, str]] = None,
     ):
         """
         Initializes a configuration class instance for the LLM.
@@ -144,6 +145,8 @@ class BaseLlmConfig(BaseConfig):
         :type query_type: Optional[str], optional
         :param local: If True, the model will be run locally, defaults to False (for huggingface provider)
         :type local: Optional[bool], optional
+        :param default_headers: Set additional HTTP headers to be sent with requests to OpenAI
+        :type default_headers: Optional[Mapping[str, str]], optional
         :raises ValueError: If the template is not valid as template should contain $context and $query (and optionally $history)
         :raises ValueError: Stream is not boolean
         """
@@ -173,6 +176,7 @@ class BaseLlmConfig(BaseConfig):
         self.endpoint = endpoint
         self.model_kwargs = model_kwargs
         self.local = local
+        self.default_headers = default_headers
 
         if isinstance(prompt, str):
             prompt = Template(prompt)
diff --git a/embedchain/llm/openai.py b/embedchain/llm/openai.py
index 4432b5f8..c1b2d0af 100644
--- a/embedchain/llm/openai.py
+++ b/embedchain/llm/openai.py
@@ -42,6 +42,8 @@ class OpenAILlm(BaseLlm):
         base_url = config.base_url or os.environ.get("OPENAI_API_BASE", None)
         if config.top_p:
             kwargs["model_kwargs"]["top_p"] = config.top_p
+        if config.default_headers:
+            kwargs["default_headers"] = config.default_headers
         if config.stream:
             callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
             chat = ChatOpenAI(
@@ -65,8 +67,7 @@ class OpenAILlm(BaseLlm):
         messages: list[BaseMessage],
     ) -> str:
         from langchain.output_parsers.openai_tools import JsonOutputToolsParser
-        from langchain_core.utils.function_calling import \
-            convert_to_openai_tool
+        from langchain_core.utils.function_calling import convert_to_openai_tool
 
         openai_tools = [convert_to_openai_tool(tools)]
         chat = chat.bind(tools=openai_tools).pipe(JsonOutputToolsParser())
diff --git a/embedchain/utils/misc.py b/embedchain/utils/misc.py
index 4b991116..6c046f43 100644
--- a/embedchain/utils/misc.py
+++ b/embedchain/utils/misc.py
@@ -431,6 +431,7 @@ def validate_config(config_data):
                 Optional("model_kwargs"): dict,
                 Optional("local"): bool,
                 Optional("base_url"): str,
+                Optional("default_headers"): dict,
             },
         },
         Optional("vectordb"): {
diff --git a/tests/llm/test_openai.py b/tests/llm/test_openai.py
index d95f942a..51150ca7 100644
--- a/tests/llm/test_openai.py
+++ b/tests/llm/test_openai.py
@@ -76,6 +76,23 @@ def test_get_llm_model_answer_without_system_prompt(config, mocker):
         base_url=os.environ["OPENAI_API_BASE"],
     )
 
+def test_get_llm_model_answer_with_special_headers(config, mocker):
+    config.default_headers = {'test': 'test'}
+    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
+
+    llm = OpenAILlm(config)
+    llm.get_llm_model_answer("Test query")
+
+    mocked_openai_chat.assert_called_once_with(
+        model=config.model,
+        temperature=config.temperature,
+        max_tokens=config.max_tokens,
+        model_kwargs={"top_p": config.top_p},
+        api_key=os.environ["OPENAI_API_KEY"],
+        base_url=os.environ["OPENAI_API_BASE"],
+        default_headers={'test': 'test'}
+    )
+
 
 @pytest.mark.parametrize(
     "mock_return, expected",
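
For reviewers, a minimal end-to-end sketch of how the new option is meant to be used. Import paths are taken from this diff; the header name, header value, and prompt are illustrative, and a real `OPENAI_API_KEY` is assumed to be exported in the environment:

```python
import os

from embedchain.config.llm.base import BaseLlmConfig
from embedchain.llm.openai import OpenAILlm

# Assumes a real key is already exported; the value here is a placeholder.
os.environ.setdefault("OPENAI_API_KEY", "sk-...")

# Illustrative header name/value; any str-to-str mapping is accepted. The
# config forwards it to ChatOpenAI(default_headers=...), and the underlying
# OpenAI client attaches the headers to every HTTP request, which is useful
# for API gateways and observability proxies.
config = BaseLlmConfig(default_headers={"X-Custom-Header": "my-value"})
llm = OpenAILlm(config)
print(llm.get_llm_model_answer("What is Embedchain?"))
```

One asymmetry worth noting: the Python signature accepts any `Mapping[str, str]`, while the YAML schema in `validate_config` only checks for a plain `dict`, which is the only mapping type YAML parsing produces anyway.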