Add model_kwargs to OpenAI call (#1402)

This commit is contained in:
Pranav Puranik
2024-06-11 13:20:04 -05:00
committed by GitHub
parent 4119040005
commit 6ecdadfd97
3 changed files with 25 additions and 3 deletions

View File

@@ -96,6 +96,23 @@ def test_get_llm_model_answer_with_special_headers(config, mocker):
)
def test_get_llm_model_answer_with_model_kwargs(config, mocker):
    """Verify that user-supplied ``model_kwargs`` are forwarded to ChatOpenAI.

    Sets ``response_format`` on the config and asserts that the underlying
    ``ChatOpenAI`` constructor receives it merged with the default
    ``top_p`` entry inside ``model_kwargs``.

    Args:
        config: OpenAI LLM config fixture (assumed to provide model,
            temperature, max_tokens, top_p — TODO confirm fixture shape).
        mocker: pytest-mock fixture used to patch the ChatOpenAI class.
    """
    config.model_kwargs = {"response_format": {"type": "json_object"}}
    # Patch ChatOpenAI where it is looked up (embedchain.llm.openai), not
    # where it is defined, so the call is intercepted.
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")
    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        # top_p is always passed through model_kwargs; the user-supplied
        # kwargs are merged alongside it.
        model_kwargs={"top_p": config.top_p, "response_format": {"type": "json_object"}},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
    )
@pytest.mark.parametrize(
"mock_return, expected",
[