Improve and add more tests (#807)

This commit is contained in:
Sidharth Mohanty
2023-10-18 02:36:47 +05:30
committed by GitHub
parent d065cbf934
commit e8a2846449
11 changed files with 403 additions and 260 deletions

47
tests/llm/test_llama2.py Normal file
View File

@@ -0,0 +1,47 @@
import os
import pytest
from embedchain.llm.llama2 import Llama2Llm
@pytest.fixture
def llama2_llm():
os.environ["REPLICATE_API_TOKEN"] = "test_api_token"
llm = Llama2Llm()
return llm
def test_init_raises_value_error_without_api_key(mocker):
mocker.patch.dict(os.environ, clear=True)
with pytest.raises(ValueError):
Llama2Llm()
def test_get_llm_model_answer_raises_value_error_for_system_prompt(llama2_llm):
llama2_llm.config.system_prompt = "system_prompt"
with pytest.raises(ValueError):
llama2_llm.get_llm_model_answer("prompt")
def test_get_llm_model_answer(llama2_llm, mocker):
mocked_replicate = mocker.patch("embedchain.llm.llama2.Replicate")
mocked_replicate_instance = mocker.MagicMock()
mocked_replicate.return_value = mocked_replicate_instance
mocked_replicate_instance.return_value = "Test answer"
llama2_llm.config.model = "test_model"
llama2_llm.config.max_tokens = 50
llama2_llm.config.temperature = 0.7
llama2_llm.config.top_p = 0.8
answer = llama2_llm.get_llm_model_answer("Test query")
assert answer == "Test answer"
mocked_replicate.assert_called_once_with(
model="test_model",
input={
"temperature": 0.7,
"max_length": 50,
"top_p": 0.8,
},
)
mocked_replicate_instance.assert_called_once_with("Test query")