[Feature] OpenAI Function Calling (#1224)
This commit is contained in:
@@ -24,7 +24,7 @@ def test_get_llm_model_answer(anthropic_llm):
|
||||
|
||||
|
||||
def test_get_answer(anthropic_llm):
|
||||
with patch("langchain.chat_models.ChatAnthropic") as mock_chat:
|
||||
with patch("langchain_community.chat_models.ChatAnthropic") as mock_chat:
|
||||
mock_chat_instance = mock_chat.return_value
|
||||
mock_chat_instance.return_value = MagicMock(content="Test Response")
|
||||
|
||||
@@ -53,7 +53,7 @@ def test_get_messages(anthropic_llm):
|
||||
|
||||
|
||||
def test_get_answer_max_tokens_is_provided(anthropic_llm, caplog):
|
||||
with patch("langchain.chat_models.ChatAnthropic") as mock_chat:
|
||||
with patch("langchain_community.chat_models.ChatAnthropic") as mock_chat:
|
||||
mock_chat_instance = mock_chat.return_value
|
||||
mock_chat_instance.return_value = MagicMock(content="Test Response")
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ def test_get_llm_model_answer(azure_openai_llm):
|
||||
|
||||
|
||||
def test_get_answer(azure_openai_llm):
|
||||
with patch("langchain.chat_models.AzureChatOpenAI") as mock_chat:
|
||||
with patch("langchain_community.chat_models.AzureChatOpenAI") as mock_chat:
|
||||
mock_chat_instance = mock_chat.return_value
|
||||
mock_chat_instance.return_value = MagicMock(content="Test Response")
|
||||
|
||||
@@ -60,7 +60,7 @@ def test_get_messages(azure_openai_llm):
|
||||
|
||||
|
||||
def test_get_answer_top_p_is_provided(azure_openai_llm, caplog):
|
||||
with patch("langchain.chat_models.AzureChatOpenAI") as mock_chat:
|
||||
with patch("langchain_community.chat_models.AzureChatOpenAI") as mock_chat:
|
||||
mock_chat_instance = mock_chat.return_value
|
||||
mock_chat_instance.return_value = MagicMock(content="Test Response")
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import pytest
|
||||
from langchain.llms.gpt4all import GPT4All as LangchainGPT4All
|
||||
from langchain_community.llms.gpt4all import GPT4All as LangchainGPT4All
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.gpt4all import GPT4ALLLlm
|
||||
|
||||
@@ -74,3 +74,32 @@ def test_get_llm_model_answer_without_system_prompt(config, mocker):
|
||||
model_kwargs={"top_p": config.top_p},
|
||||
api_key=os.environ["OPENAI_API_KEY"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mock_return, expected",
|
||||
[
|
||||
([{"test": "test"}], '{"test": "test"}'),
|
||||
([], "Input could not be mapped to the function!"),
|
||||
],
|
||||
)
|
||||
def test_get_llm_model_answer_with_tools(config, mocker, mock_return, expected):
|
||||
mocked_openai_chat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
|
||||
mocked_convert_to_openai_tool = mocker.patch("langchain_core.utils.function_calling.convert_to_openai_tool")
|
||||
mocked_json_output_tools_parser = mocker.patch("langchain.output_parsers.openai_tools.JsonOutputToolsParser")
|
||||
mocked_openai_chat.return_value.bind.return_value.pipe.return_value.invoke.return_value = mock_return
|
||||
|
||||
llm = OpenAILlm(config, tools={"test": "test"})
|
||||
answer = llm.get_llm_model_answer("Test query")
|
||||
|
||||
mocked_openai_chat.assert_called_once_with(
|
||||
model=config.model,
|
||||
temperature=config.temperature,
|
||||
max_tokens=config.max_tokens,
|
||||
model_kwargs={"top_p": config.top_p},
|
||||
api_key=os.environ["OPENAI_API_KEY"],
|
||||
)
|
||||
mocked_convert_to_openai_tool.assert_called_once_with({"test": "test"})
|
||||
mocked_json_output_tools_parser.assert_called_once()
|
||||
|
||||
assert answer == expected
|
||||
|
||||
@@ -22,7 +22,7 @@ def test_get_llm_model_answer(vertexai_llm):
|
||||
|
||||
|
||||
def test_get_answer_with_warning(vertexai_llm, caplog):
|
||||
with patch("langchain.chat_models.ChatVertexAI") as mock_chat:
|
||||
with patch("langchain_community.chat_models.ChatVertexAI") as mock_chat:
|
||||
mock_chat_instance = mock_chat.return_value
|
||||
mock_chat_instance.return_value = MagicMock(content="Test Response")
|
||||
|
||||
@@ -39,7 +39,7 @@ def test_get_answer_with_warning(vertexai_llm, caplog):
|
||||
|
||||
|
||||
def test_get_answer_no_warning(vertexai_llm, caplog):
|
||||
with patch("langchain.chat_models.ChatVertexAI") as mock_chat:
|
||||
with patch("langchain_community.chat_models.ChatVertexAI") as mock_chat:
|
||||
mock_chat_instance = mock_chat.return_value
|
||||
mock_chat_instance.return_value = MagicMock(content="Test Response")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user