Rename embedchain to mem0 and open sourcing code for long term memory (#1474)
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
This commit is contained in:
40
embedchain/tests/llm/test_llama2.py
Normal file
40
embedchain/tests/llm/test_llama2.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from embedchain.llm.llama2 import Llama2Llm
|
||||
|
||||
|
||||
@pytest.fixture
def llama2_llm(monkeypatch):
    """Provide a ``Llama2Llm`` instance backed by a dummy Replicate API token.

    Uses pytest's ``monkeypatch`` fixture so the fake ``REPLICATE_API_TOKEN``
    is automatically removed/restored after each test.  The original fixture
    wrote directly to ``os.environ``, permanently leaking the fake token into
    every subsequent test in the session.

    Returns:
        Llama2Llm: a freshly constructed instance.
    """
    monkeypatch.setenv("REPLICATE_API_TOKEN", "test_api_token")
    return Llama2Llm()
def test_init_raises_value_error_without_api_key(mocker):
    """Constructing Llama2Llm with no REPLICATE_API_TOKEN must raise ValueError."""
    # Wipe the entire process environment so the token lookup cannot succeed.
    # patch.dict restores the original environment when the test ends.
    mocker.patch.dict(os.environ, clear=True)
    # Callable form of pytest.raises: invoke the constructor and expect it to fail.
    pytest.raises(ValueError, Llama2Llm)
def test_get_llm_model_answer_raises_value_error_for_system_prompt(llama2_llm):
    """get_llm_model_answer must reject configs that set a system prompt."""
    # Llama2Llm does not support system prompts; configuring one must fail.
    llama2_llm.config.system_prompt = "system_prompt"
    # Callable form of pytest.raises: call the method and expect ValueError.
    pytest.raises(ValueError, llama2_llm.get_llm_model_answer, "prompt")
def test_get_llm_model_answer(llama2_llm, mocker):
    """get_llm_model_answer should return the text produced by the Replicate LLM.

    Patches the ``Replicate`` class used inside ``embedchain.llm.llama2`` so no
    network call is made, then checks both the returned answer and that the
    mocked Replicate path was actually exercised.  (The original test only
    compared the answer string, so a code path that never touched Replicate
    but happened to return the same text would still have passed silently.)
    """
    mocked_replicate = mocker.patch("embedchain.llm.llama2.Replicate")
    mocked_replicate_instance = mocker.MagicMock()
    mocked_replicate.return_value = mocked_replicate_instance
    mocked_replicate_instance.invoke.return_value = "Test answer"

    # Exercise a fully specified config; values are arbitrary but explicit.
    llama2_llm.config.model = "test_model"
    llama2_llm.config.max_tokens = 50
    llama2_llm.config.temperature = 0.7
    llama2_llm.config.top_p = 0.8

    answer = llama2_llm.get_llm_model_answer("Test query")

    assert answer == "Test answer"
    # The answer can only come from the mock chain, so both the constructor
    # and invoke() must have been hit — assert it explicitly for clearer
    # failures if the implementation changes.
    assert mocked_replicate.called
    assert mocked_replicate_instance.invoke.called
Reference in New Issue
Block a user