Support supplying custom headers to OpenAI requests (#1356)

Author: Niv Hertz
Date: 2024-05-06 13:26:12 -04:00
Committed by: GitHub
parent a0ff764f0a
commit 797dea1dca
4 changed files with 26 additions and 3 deletions
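
Taken together, the change threads an optional default_headers mapping from the LLM config through to the OpenAI client, so every request can carry extra HTTP headers (useful behind auth proxies or observability gateways). A minimal usage sketch, assuming the public import paths below; the header name and value are illustrative, not part of this commit:

from embedchain.config import BaseLlmConfig
from embedchain.llm.openai import OpenAILlm

# Any Mapping[str, str] is accepted; this header name is hypothetical.
config = BaseLlmConfig(
    model="gpt-3.5-turbo",
    default_headers={"X-Proxy-Token": "secret-token"},
)
llm = OpenAILlm(config)
answer = llm.get_llm_model_answer("What changed in this commit?")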

View File

@@ -1,7 +1,7 @@
 import logging
 import re
 from string import Template
-from typing import Any, Optional
+from typing import Any, Mapping, Optional
 
 from embedchain.config.base_config import BaseConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -99,6 +99,7 @@ class BaseLlmConfig(BaseConfig):
         endpoint: Optional[str] = None,
         model_kwargs: Optional[dict[str, Any]] = None,
         local: Optional[bool] = False,
+        default_headers: Optional[Mapping[str, str]] = None,
     ):
         """
         Initializes a configuration class instance for the LLM.
@@ -144,6 +145,8 @@
         :type query_type: Optional[str], optional
         :param local: If True, the model will be run locally, defaults to False (for huggingface provider)
         :type local: Optional[bool], optional
+        :param default_headers: Set additional HTTP headers to be sent with requests to OpenAI
+        :type default_headers: Optional[Mapping[str, str]], optional
         :raises ValueError: If the template is not valid as template should
             contain $context and $query (and optionally $history)
         :raises ValueError: Stream is not boolean
@@ -173,6 +176,7 @@
         self.endpoint = endpoint
         self.model_kwargs = model_kwargs
         self.local = local
+        self.default_headers = default_headers
 
         if isinstance(prompt, str):
             prompt = Template(prompt)
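
Typing the new parameter as Mapping[str, str] (rather than dict) means callers may pass any read-only mapping, which the config stores as-is. A small sketch of that flexibility, assuming all other BaseLlmConfig parameters keep their defaults; the header name is hypothetical:

from types import MappingProxyType
from embedchain.config import BaseLlmConfig

# An immutable mapping satisfies Mapping[str, str] just as well as a plain dict.
headers = MappingProxyType({"X-Trace-Id": "abc-123"})
config = BaseLlmConfig(default_headers=headers)
assert config.default_headers["X-Trace-Id"] == "abc-123"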

View File

@@ -42,6 +42,8 @@ class OpenAILlm(BaseLlm):
         base_url = config.base_url or os.environ.get("OPENAI_API_BASE", None)
         if config.top_p:
             kwargs["model_kwargs"]["top_p"] = config.top_p
+        if config.default_headers:
+            kwargs["default_headers"] = config.default_headers
         if config.stream:
             callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
             chat = ChatOpenAI(
@@ -65,8 +67,7 @@
         messages: list[BaseMessage],
     ) -> str:
         from langchain.output_parsers.openai_tools import JsonOutputToolsParser
-        from langchain_core.utils.function_calling import \
-            convert_to_openai_tool
+        from langchain_core.utils.function_calling import convert_to_openai_tool
 
         openai_tools = [convert_to_openai_tool(tools)]
         chat = chat.bind(tools=openai_tools).pipe(JsonOutputToolsParser())
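
Because default_headers enters kwargs only when it is truthy, existing configs are untouched, and ChatOpenAI receives the mapping under the parameter name langchain-openai already defines; it forwards the headers to the underlying OpenAI client, which attaches them to every HTTP request. A standalone sketch of the equivalent direct call (model and header values are illustrative):

from langchain_openai import ChatOpenAI

# default_headers is passed through to the OpenAI client and sent on each request.
chat = ChatOpenAI(
    model="gpt-3.5-turbo",
    temperature=0.5,
    default_headers={"X-Proxy-Token": "secret-token"},
)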

View File

@@ -431,6 +431,7 @@ def validate_config(config_data):
                 Optional("model_kwargs"): dict,
                 Optional("local"): bool,
                 Optional("base_url"): str,
+                Optional("default_headers"): dict,
             },
         },
         Optional("vectordb"): {

View File

@@ -76,6 +76,23 @@ def test_get_llm_model_answer_without_system_prompt(config, mocker):
         base_url=os.environ["OPENAI_API_BASE"],
     )
 
 
+def test_get_llm_model_answer_with_special_headers(config, mocker):
+    config.default_headers = {'test': 'test'}
+    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
+    llm = OpenAILlm(config)
+    llm.get_llm_model_answer("Test query")
+
+    mocked_openai_chat.assert_called_once_with(
+        model=config.model,
+        temperature=config.temperature,
+        max_tokens=config.max_tokens,
+        model_kwargs={"top_p": config.top_p},
+        api_key=os.environ["OPENAI_API_KEY"],
+        base_url=os.environ["OPENAI_API_BASE"],
+        default_headers={'test': 'test'}
+    )
+
+
 @pytest.mark.parametrize(
     "mock_return, expected",