# Viewer metadata: 54 lines, 1.5 KiB, Python
import pytest
|
|
|
|
from embedchain.llm.base import BaseLlm, BaseLlmConfig
|
|
|
|
|
|
@pytest.fixture
def base_llm():
    """Provide a ``BaseLlm`` built from a default-constructed ``BaseLlmConfig``."""
    return BaseLlm(config=BaseLlmConfig())
|
|
|
|
|
|
def test_is_get_llm_model_answer_not_implemented(base_llm):
    """Expect ``NotImplementedError`` from ``get_llm_model_answer`` on the base class."""
    expected_error = NotImplementedError
    # Keep the call alone inside the context manager so that only this call
    # can satisfy the expectation.
    with pytest.raises(expected_error):
        base_llm.get_llm_model_answer()
|
|
|
|
|
|
def test_is_get_llm_model_answer_implemented():
    """Check that a subclass overriding ``get_llm_model_answer`` returns its own value."""

    class _ImplementedLlm(BaseLlm):
        """Minimal subclass that supplies the overridden method."""

        def get_llm_model_answer(self):
            return "Implemented"

    subject = _ImplementedLlm(config=BaseLlmConfig())
    answer = subject.get_llm_model_answer()
    assert answer == "Implemented"
|
|
|
|
|
|
def test_stream_query_response(base_llm):
    """Collecting ``_stream_query_response`` output must give back the input chunks."""
    chunks = ["Chunk1", "Chunk2", "Chunk3"]
    stream = base_llm._stream_query_response(chunks)
    # Materialise the stream so it can be compared against the input list.
    assert list(stream) == chunks
|
|
|
|
|
|
def test_stream_chat_response(base_llm):
    """Collecting ``_stream_chat_response`` output must give back the input chunks."""
    chunks = ["Chunk1", "Chunk2", "Chunk3"]
    stream = base_llm._stream_chat_response(chunks)
    # Materialise the stream so it can be compared against the input list.
    assert list(stream) == chunks
|
|
|
|
|
|
def test_append_search_and_context(base_llm):
    """Check the string produced by ``_append_search_and_context``.

    The expected value is the context, a newline, then the web search result
    prefixed with ``"Web Search Result: "``.
    """
    combined = base_llm._append_search_and_context("Context", "Web Search Result")
    assert combined == "Context\nWeb Search Result: Web Search Result"
|
|
|
|
|
|
def test_access_search_and_get_results(base_llm, mocker):
    """Check that a patched ``access_search_and_get_results`` receives the query.

    The method is replaced by a mock, so the real search implementation is
    not exercised here. The test pins down the call made through the
    instance and the value handed back to the caller.
    """
    # ``patch.object`` installs the mock on ``base_llm`` and returns it, so
    # assigning the attribute again by hand is unnecessary.
    mocked_search = mocker.patch.object(
        base_llm, "access_search_and_get_results", return_value="Search Results"
    )
    input_query = "Test query"

    result = base_llm.access_search_and_get_results(input_query)

    assert result == "Search Results"
    # Confirm the query reached the mock exactly once and unchanged.
    mocked_search.assert_called_once_with(input_query)
|