Add model_kwargs to OpenAI call (#1402)

This commit is contained in:
Pranav Puranik
2024-06-11 13:20:04 -05:00
committed by GitHub
parent 4119040005
commit 6ecdadfd97
3 changed files with 25 additions and 3 deletions

View File

@@ -26,6 +26,9 @@ llm:
top_p: 1
stream: false
api_key: sk-xxx
model_kwargs:
response_format:
type: json_object
prompt: |
Use the following pieces of context to answer the query at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
@@ -83,7 +86,8 @@ cache:
"stream": false, "stream": false,
"prompt": "Use the following pieces of context to answer the query at the end.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n$context\n\nQuery: $query\n\nHelpful Answer:", "prompt": "Use the following pieces of context to answer the query at the end.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n$context\n\nQuery: $query\n\nHelpful Answer:",
"system_prompt": "Act as William Shakespeare. Answer the following questions in the style of William Shakespeare.", "system_prompt": "Act as William Shakespeare. Answer the following questions in the style of William Shakespeare.",
"api_key": "sk-xxx" "api_key": "sk-xxx",
"model_kwargs": {"response_format": {"type": "json_object"}}
}
},
"vectordb": {
@@ -143,7 +147,8 @@ config = {
'system_prompt': (
"Act as William Shakespeare. Answer the following questions in the style of William Shakespeare."
),
'api_key': 'sk-xxx' 'api_key': 'sk-xxx',
"model_kwargs": {"response_format": {"type": "json_object"}}
}
},
'vectordb': {

View File

@@ -36,7 +36,7 @@ class OpenAILlm(BaseLlm):
"model": config.model or "gpt-3.5-turbo", "model": config.model or "gpt-3.5-turbo",
"temperature": config.temperature, "temperature": config.temperature,
"max_tokens": config.max_tokens, "max_tokens": config.max_tokens,
"model_kwargs": {}, "model_kwargs": config.model_kwargs or {},
}
api_key = config.api_key or os.environ["OPENAI_API_KEY"]
base_url = config.base_url or os.environ.get("OPENAI_API_BASE", None)

View File

@@ -96,6 +96,23 @@ def test_get_llm_model_answer_with_special_headers(config, mocker):
)
def test_get_llm_model_answer_with_model_kwargs(config, mocker):
    """Verify that user-supplied ``config.model_kwargs`` are forwarded to the
    ChatOpenAI constructor, merged with the kwargs the LLM adds itself.

    Uses the ``config`` fixture for a baseline OpenAI config and ``mocker``
    (pytest-mock) to stub out the real OpenAI client.
    """
    config.model_kwargs = {"response_format": {"type": "json_object"}}
    # Patch ChatOpenAI inside the embedchain module so no network call is made
    # and the constructor arguments can be inspected.
    mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
    llm = OpenAILlm(config)
    llm.get_llm_model_answer("Test query")
    # The implementation merges top_p into model_kwargs alongside the
    # user-provided entries, so both keys must appear in the final dict.
    mocked_openai_chat.assert_called_once_with(
        model=config.model,
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        model_kwargs={"top_p": config.top_p, "response_format": {"type": "json_object"}},
        api_key=os.environ["OPENAI_API_KEY"],
        base_url=os.environ["OPENAI_API_BASE"],
    )
@pytest.mark.parametrize(
"mock_return, expected",
[