From 6ecdadfd977f1c5416a9c7057962b2a89a0d9163 Mon Sep 17 00:00:00 2001
From: Pranav Puranik <54378813+PranavPuranik@users.noreply.github.com>
Date: Tue, 11 Jun 2024 13:20:04 -0500
Subject: [PATCH] Add model_kwargs to OpenAI call (#1402)

---
 docs/api-reference/advanced/configuration.mdx |  9 +++++++--
 embedchain/llm/openai.py                      |  2 +-
 tests/llm/test_openai.py                      | 17 +++++++++++++++++
 3 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/docs/api-reference/advanced/configuration.mdx b/docs/api-reference/advanced/configuration.mdx
index 0e93225e..cc91ecbe 100644
--- a/docs/api-reference/advanced/configuration.mdx
+++ b/docs/api-reference/advanced/configuration.mdx
@@ -26,6 +26,9 @@ llm:
     top_p: 1
     stream: false
     api_key: sk-xxx
+    model_kwargs:
+      response_format:
+        type: json_object
     prompt: |
       Use the following pieces of context to answer the query at the end.
       If you don't know the answer, just say that you don't know, don't try to make up an answer.
@@ -83,7 +86,8 @@ cache:
             "stream": false,
             "prompt": "Use the following pieces of context to answer the query at the end.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n$context\n\nQuery: $query\n\nHelpful Answer:",
             "system_prompt": "Act as William Shakespeare. Answer the following questions in the style of William Shakespeare.",
-            "api_key": "sk-xxx"
+            "api_key": "sk-xxx",
+            "model_kwargs": {"response_format": {"type": "json_object"}}
         }
     },
     "vectordb": {
@@ -143,7 +147,8 @@ config = {
             'system_prompt': (
                 "Act as William Shakespeare. Answer the following questions in the style of William Shakespeare."
             ),
-            'api_key': 'sk-xxx'
+            'api_key': 'sk-xxx',
+            "model_kwargs": {"response_format": {"type": "json_object"}}
         }
     },
     'vectordb': {
diff --git a/embedchain/llm/openai.py b/embedchain/llm/openai.py
index aec91b2a..c69ca39c 100644
--- a/embedchain/llm/openai.py
+++ b/embedchain/llm/openai.py
@@ -36,7 +36,7 @@ class OpenAILlm(BaseLlm):
             "model": config.model or "gpt-3.5-turbo",
             "temperature": config.temperature,
             "max_tokens": config.max_tokens,
-            "model_kwargs": {},
+            "model_kwargs": config.model_kwargs or {},
         }
         api_key = config.api_key or os.environ["OPENAI_API_KEY"]
         base_url = config.base_url or os.environ.get("OPENAI_API_BASE", None)
diff --git a/tests/llm/test_openai.py b/tests/llm/test_openai.py
index 339b4757..b7c886b2 100644
--- a/tests/llm/test_openai.py
+++ b/tests/llm/test_openai.py
@@ -96,6 +96,23 @@ def test_get_llm_model_answer_with_special_headers(config, mocker):
     )
 
 
+def test_get_llm_model_answer_with_model_kwargs(config, mocker):
+    config.model_kwargs = {"response_format": {"type": "json_object"}}
+    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
+
+    llm = OpenAILlm(config)
+    llm.get_llm_model_answer("Test query")
+
+    mocked_openai_chat.assert_called_once_with(
+        model=config.model,
+        temperature=config.temperature,
+        max_tokens=config.max_tokens,
+        model_kwargs={"top_p": config.top_p, "response_format": {"type": "json_object"}},
+        api_key=os.environ["OPENAI_API_KEY"],
+        base_url=os.environ["OPENAI_API_BASE"],
+    )
+
+
 @pytest.mark.parametrize(
     "mock_return, expected",
     [
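
Usage sketch: a minimal example of the option this patch adds, assuming
embedchain's documented App.from_config entry point, that OPENAI_API_KEY
is set in the environment, and that the configured model supports OpenAI's
JSON mode; the source URL and query are illustrative only.

from embedchain import App

config = {
    "llm": {
        "provider": "openai",
        "config": {
            "model": "gpt-3.5-turbo",
            # Forwarded to ChatOpenAI(model_kwargs=...) by this patch.
            "model_kwargs": {"response_format": {"type": "json_object"}},
        },
    },
}

app = App.from_config(config=config)
app.add("https://www.forbes.com/profile/elon-musk")  # illustrative source
# OpenAI's JSON mode requires the word "json" to appear in the prompt.
print(app.query("Describe Elon Musk as a JSON object."))

Because the kwargs dict now reads config.model_kwargs or {}, configs that
never set model_kwargs keep the previous behavior (an empty dict), so the
change is backward compatible.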