OpenAI function calling support (#1011)
This commit is contained in:
@@ -50,24 +50,24 @@ def test_get_llm_model_answer_empty_prompt(config, mocker):
|
||||
|
||||
def test_get_llm_model_answer_with_streaming(config, mocker):
|
||||
config.stream = True
|
||||
mocked_jinachat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
|
||||
mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
|
||||
|
||||
llm = OpenAILlm(config)
|
||||
llm.get_llm_model_answer("Test query")
|
||||
|
||||
mocked_jinachat.assert_called_once()
|
||||
callbacks = [callback[1]["callbacks"] for callback in mocked_jinachat.call_args_list]
|
||||
mocked_openai_chat.assert_called_once()
|
||||
callbacks = [callback[1]["callbacks"] for callback in mocked_openai_chat.call_args_list]
|
||||
assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)
|
||||
|
||||
|
||||
def test_get_llm_model_answer_without_system_prompt(config, mocker):
|
||||
config.system_prompt = None
|
||||
mocked_jinachat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
|
||||
mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
|
||||
|
||||
llm = OpenAILlm(config)
|
||||
llm.get_llm_model_answer("Test query")
|
||||
|
||||
mocked_jinachat.assert_called_once_with(
|
||||
mocked_openai_chat.assert_called_once_with(
|
||||
model=config.model,
|
||||
temperature=config.temperature,
|
||||
max_tokens=config.max_tokens,
|
||||
|
||||
Reference in New Issue
Block a user