Improve and add more tests (#807)
This commit is contained in:
@@ -13,13 +13,9 @@ class OpenAILlm(BaseLlm):
|
||||
def __init__(self, config: Optional[BaseLlmConfig] = None):
|
||||
super().__init__(config=config)
|
||||
|
||||
def get_llm_model_answer(self, prompt):
|
||||
def get_llm_model_answer(self, prompt) -> str:
|
||||
response = OpenAILlm._get_answer(prompt, self.config)
|
||||
|
||||
if self.config.stream:
|
||||
return response
|
||||
else:
|
||||
return response.content
|
||||
return response
|
||||
|
||||
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
|
||||
messages = []
|
||||
@@ -35,10 +31,9 @@ class OpenAILlm(BaseLlm):
|
||||
if config.top_p:
|
||||
kwargs["model_kwargs"]["top_p"] = config.top_p
|
||||
if config.stream:
|
||||
from langchain.callbacks.streaming_stdout import \
|
||||
StreamingStdOutCallbackHandler
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
|
||||
else:
|
||||
chat = ChatOpenAI(**kwargs)
|
||||
return chat(messages)
|
||||
return chat(messages).content
|
||||
|
||||
@@ -133,6 +133,7 @@ isort = "^5.12.0"
|
||||
pytest-cov = "^4.1.0"
|
||||
responses = "^0.23.3"
|
||||
mock = "^5.1.0"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
|
||||
[tool.poetry.extras]
|
||||
streamlit = ["streamlit"]
|
||||
|
||||
@@ -104,6 +104,19 @@ class TestConfigForAppComponents:
|
||||
|
||||
assert isinstance(embedder_config, BaseEmbedderConfig)
|
||||
|
||||
def test_components_raises_type_error_if_not_proper_instances(self):
|
||||
wrong_llm = "wrong_llm"
|
||||
with pytest.raises(TypeError):
|
||||
App(llm=wrong_llm)
|
||||
|
||||
wrong_db = "wrong_db"
|
||||
with pytest.raises(TypeError):
|
||||
App(db=wrong_db)
|
||||
|
||||
wrong_embedder = "wrong_embedder"
|
||||
with pytest.raises(TypeError):
|
||||
App(embedder=wrong_embedder)
|
||||
|
||||
|
||||
class TestAppFromConfig:
|
||||
def load_config_data(self, yaml_path):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import pytest
|
||||
|
||||
from string import Template
|
||||
from embedchain.llm.base import BaseLlm, BaseLlmConfig
|
||||
|
||||
|
||||
@@ -14,6 +14,18 @@ def test_is_get_llm_model_answer_not_implemented(base_llm):
|
||||
base_llm.get_llm_model_answer()
|
||||
|
||||
|
||||
def test_is_stream_bool():
|
||||
with pytest.raises(ValueError):
|
||||
config = BaseLlmConfig(stream="test value")
|
||||
BaseLlm(config=config)
|
||||
|
||||
|
||||
def test_template_string_gets_converted_to_Template_instance():
|
||||
config = BaseLlmConfig(template="test value $query $context")
|
||||
llm = BaseLlm(config=config)
|
||||
assert isinstance(llm.config.template, Template)
|
||||
|
||||
|
||||
def test_is_get_llm_model_answer_implemented():
|
||||
class TestLlm(BaseLlm):
|
||||
def get_llm_model_answer(self):
|
||||
|
||||
@@ -1,33 +1,55 @@
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.cohere import CohereLlm
|
||||
|
||||
|
||||
class TestCohereLlm(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["COHERE_API_KEY"] = "test_api_key"
|
||||
self.config = BaseLlmConfig(model="gptd-instruct-tft", max_tokens=50, temperature=0.7, top_p=0.8)
|
||||
@pytest.fixture
|
||||
def cohere_llm_config():
|
||||
os.environ["COHERE_API_KEY"] = "test_api_key"
|
||||
config = BaseLlmConfig(model="gptd-instruct-tft", max_tokens=50, temperature=0.7, top_p=0.8)
|
||||
yield config
|
||||
os.environ.pop("COHERE_API_KEY")
|
||||
|
||||
def test_init_raises_value_error_without_api_key(self):
|
||||
os.environ.pop("COHERE_API_KEY")
|
||||
with self.assertRaises(ValueError):
|
||||
CohereLlm()
|
||||
|
||||
def test_get_llm_model_answer_raises_value_error_for_system_prompt(self):
|
||||
llm = CohereLlm(self.config)
|
||||
llm.config.system_prompt = "system_prompt"
|
||||
with self.assertRaises(ValueError):
|
||||
llm.get_llm_model_answer("prompt")
|
||||
def test_init_raises_value_error_without_api_key(mocker):
|
||||
mocker.patch.dict(os.environ, clear=True)
|
||||
with pytest.raises(ValueError):
|
||||
CohereLlm()
|
||||
|
||||
@patch("embedchain.llm.cohere.CohereLlm._get_answer")
|
||||
def test_get_llm_model_answer(self, mock_get_answer):
|
||||
mock_get_answer.return_value = "Test answer"
|
||||
|
||||
llm = CohereLlm(self.config)
|
||||
answer = llm.get_llm_model_answer("Test query")
|
||||
def test_get_llm_model_answer_raises_value_error_for_system_prompt(cohere_llm_config):
|
||||
llm = CohereLlm(cohere_llm_config)
|
||||
llm.config.system_prompt = "system_prompt"
|
||||
with pytest.raises(ValueError):
|
||||
llm.get_llm_model_answer("prompt")
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
mock_get_answer.assert_called_once()
|
||||
|
||||
def test_get_llm_model_answer(cohere_llm_config, mocker):
|
||||
mocker.patch("embedchain.llm.cohere.CohereLlm._get_answer", return_value="Test answer")
|
||||
|
||||
llm = CohereLlm(cohere_llm_config)
|
||||
answer = llm.get_llm_model_answer("Test query")
|
||||
|
||||
assert answer == "Test answer"
|
||||
|
||||
|
||||
def test_get_answer_mocked_cohere(cohere_llm_config, mocker):
|
||||
mocked_cohere = mocker.patch("embedchain.llm.cohere.Cohere")
|
||||
mock_instance = mocked_cohere.return_value
|
||||
mock_instance.return_value = "Mocked answer"
|
||||
|
||||
llm = CohereLlm(cohere_llm_config)
|
||||
prompt = "Test query"
|
||||
answer = llm.get_llm_model_answer(prompt)
|
||||
|
||||
assert answer == "Mocked answer"
|
||||
mocked_cohere.assert_called_once_with(
|
||||
cohere_api_key="test_api_key",
|
||||
model="gptd-instruct-tft",
|
||||
max_tokens=50,
|
||||
temperature=0.7,
|
||||
p=0.8,
|
||||
)
|
||||
mock_instance.assert_called_once_with(prompt)
|
||||
|
||||
@@ -1,64 +1,61 @@
|
||||
import importlib
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.huggingface import HuggingFaceLlm
|
||||
|
||||
|
||||
class TestHuggingFaceLlm(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "test_access_token"
|
||||
self.config = BaseLlmConfig(model="google/flan-t5-xxl", max_tokens=50, temperature=0.7, top_p=0.8)
|
||||
@pytest.fixture
|
||||
def huggingface_llm_config():
|
||||
os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "test_access_token"
|
||||
config = BaseLlmConfig(model="google/flan-t5-xxl", max_tokens=50, temperature=0.7, top_p=0.8)
|
||||
yield config
|
||||
os.environ.pop("HUGGINGFACE_ACCESS_TOKEN")
|
||||
|
||||
def test_init_raises_value_error_without_api_key(self):
|
||||
os.environ.pop("HUGGINGFACE_ACCESS_TOKEN")
|
||||
with self.assertRaises(ValueError):
|
||||
HuggingFaceLlm()
|
||||
|
||||
def test_get_llm_model_answer_raises_value_error_for_system_prompt(self):
|
||||
llm = HuggingFaceLlm(self.config)
|
||||
llm.config.system_prompt = "system_prompt"
|
||||
with self.assertRaises(ValueError):
|
||||
llm.get_llm_model_answer("prompt")
|
||||
def test_init_raises_value_error_without_api_key(mocker):
|
||||
mocker.patch.dict(os.environ, clear=True)
|
||||
with pytest.raises(ValueError):
|
||||
HuggingFaceLlm()
|
||||
|
||||
def test_top_p_value_within_range(self):
|
||||
config = BaseLlmConfig(top_p=1.0)
|
||||
with self.assertRaises(ValueError):
|
||||
HuggingFaceLlm._get_answer("test_prompt", config)
|
||||
|
||||
def test_dependency_is_imported(self):
|
||||
importlib_installed = True
|
||||
try:
|
||||
importlib.import_module("huggingface_hub")
|
||||
except ImportError:
|
||||
importlib_installed = False
|
||||
self.assertTrue(importlib_installed)
|
||||
def test_get_llm_model_answer_raises_value_error_for_system_prompt(huggingface_llm_config):
|
||||
llm = HuggingFaceLlm(huggingface_llm_config)
|
||||
llm.config.system_prompt = "system_prompt"
|
||||
with pytest.raises(ValueError):
|
||||
llm.get_llm_model_answer("prompt")
|
||||
|
||||
@patch("embedchain.llm.huggingface.HuggingFaceLlm._get_answer")
|
||||
def test_get_llm_model_answer(self, mock_get_answer):
|
||||
mock_get_answer.return_value = "Test answer"
|
||||
|
||||
llm = HuggingFaceLlm(self.config)
|
||||
answer = llm.get_llm_model_answer("Test query")
|
||||
def test_top_p_value_within_range():
|
||||
config = BaseLlmConfig(top_p=1.0)
|
||||
with pytest.raises(ValueError):
|
||||
HuggingFaceLlm._get_answer("test_prompt", config)
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
mock_get_answer.assert_called_once()
|
||||
|
||||
@patch("embedchain.llm.huggingface.HuggingFaceHub")
|
||||
def test_hugging_face_mock(self, mock_huggingface):
|
||||
mock_llm_instance = MagicMock()
|
||||
mock_llm_instance.return_value = "Test answer"
|
||||
mock_huggingface.return_value = mock_llm_instance
|
||||
def test_dependency_is_imported():
|
||||
importlib_installed = True
|
||||
try:
|
||||
importlib.import_module("huggingface_hub")
|
||||
except ImportError:
|
||||
importlib_installed = False
|
||||
assert importlib_installed
|
||||
|
||||
llm = HuggingFaceLlm(self.config)
|
||||
answer = llm.get_llm_model_answer("Test query")
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
mock_huggingface.assert_called_once_with(
|
||||
huggingfacehub_api_token="test_access_token",
|
||||
repo_id="google/flan-t5-xxl",
|
||||
model_kwargs={"temperature": 0.7, "max_new_tokens": 50, "top_p": 0.8},
|
||||
)
|
||||
mock_llm_instance.assert_called_once_with("Test query")
|
||||
def test_get_llm_model_answer(huggingface_llm_config, mocker):
|
||||
mocker.patch("embedchain.llm.huggingface.HuggingFaceLlm._get_answer", return_value="Test answer")
|
||||
|
||||
llm = HuggingFaceLlm(huggingface_llm_config)
|
||||
answer = llm.get_llm_model_answer("Test query")
|
||||
|
||||
assert answer == "Test answer"
|
||||
|
||||
|
||||
def test_hugging_face_mock(huggingface_llm_config, mocker):
|
||||
mock_llm_instance = mocker.Mock(return_value="Test answer")
|
||||
mocker.patch("embedchain.llm.huggingface.HuggingFaceHub", return_value=mock_llm_instance)
|
||||
|
||||
llm = HuggingFaceLlm(huggingface_llm_config)
|
||||
answer = llm.get_llm_model_answer("Test query")
|
||||
|
||||
assert answer == "Test answer"
|
||||
mock_llm_instance.assert_called_once_with("Test query")
|
||||
|
||||
@@ -1,40 +1,76 @@
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.jina import JinaLlm
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
|
||||
class TestJinaLlm(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["JINACHAT_API_KEY"] = "test_api_key"
|
||||
self.config = BaseLlmConfig(
|
||||
temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt"
|
||||
)
|
||||
@pytest.fixture
|
||||
def config():
|
||||
os.environ["JINACHAT_API_KEY"] = "test_api_key"
|
||||
config = BaseLlmConfig(temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt")
|
||||
yield config
|
||||
os.environ.pop("JINACHAT_API_KEY")
|
||||
|
||||
def test_init_raises_value_error_without_api_key(self):
|
||||
os.environ.pop("JINACHAT_API_KEY")
|
||||
with self.assertRaises(ValueError):
|
||||
JinaLlm()
|
||||
|
||||
@patch("embedchain.llm.jina.JinaLlm._get_answer")
|
||||
def test_get_llm_model_answer(self, mock_get_answer):
|
||||
mock_get_answer.return_value = "Test answer"
|
||||
def test_init_raises_value_error_without_api_key(mocker):
|
||||
mocker.patch.dict(os.environ, clear=True)
|
||||
with pytest.raises(ValueError):
|
||||
JinaLlm()
|
||||
|
||||
llm = JinaLlm(self.config)
|
||||
answer = llm.get_llm_model_answer("Test query")
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
mock_get_answer.assert_called_once()
|
||||
def test_get_llm_model_answer(config, mocker):
|
||||
mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
|
||||
|
||||
@patch("embedchain.llm.jina.JinaLlm._get_answer")
|
||||
def test_get_llm_model_answer_with_system_prompt(self, mock_get_answer):
|
||||
self.config.system_prompt = "Custom system prompt"
|
||||
mock_get_answer.return_value = "Test answer"
|
||||
llm = JinaLlm(config)
|
||||
answer = llm.get_llm_model_answer("Test query")
|
||||
|
||||
llm = JinaLlm(self.config)
|
||||
answer = llm.get_llm_model_answer("Test query")
|
||||
assert answer == "Test answer"
|
||||
mocked_get_answer.assert_called_once_with("Test query", config)
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
mock_get_answer.assert_called_once()
|
||||
|
||||
def test_get_llm_model_answer_with_system_prompt(config, mocker):
|
||||
config.system_prompt = "Custom system prompt"
|
||||
mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
|
||||
|
||||
llm = JinaLlm(config)
|
||||
answer = llm.get_llm_model_answer("Test query")
|
||||
|
||||
assert answer == "Test answer"
|
||||
mocked_get_answer.assert_called_once_with("Test query", config)
|
||||
|
||||
|
||||
def test_get_llm_model_answer_empty_prompt(config, mocker):
|
||||
mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
|
||||
|
||||
llm = JinaLlm(config)
|
||||
answer = llm.get_llm_model_answer("")
|
||||
|
||||
assert answer == "Test answer"
|
||||
mocked_get_answer.assert_called_once_with("", config)
|
||||
|
||||
|
||||
def test_get_llm_model_answer_with_streaming(config, mocker):
|
||||
config.stream = True
|
||||
mocked_jinachat = mocker.patch("embedchain.llm.jina.JinaChat")
|
||||
|
||||
llm = JinaLlm(config)
|
||||
llm.get_llm_model_answer("Test query")
|
||||
|
||||
mocked_jinachat.assert_called_once()
|
||||
callbacks = [callback[1]["callbacks"] for callback in mocked_jinachat.call_args_list]
|
||||
assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)
|
||||
|
||||
|
||||
def test_get_llm_model_answer_without_system_prompt(config, mocker):
|
||||
config.system_prompt = None
|
||||
mocked_jinachat = mocker.patch("embedchain.llm.jina.JinaChat")
|
||||
|
||||
llm = JinaLlm(config)
|
||||
llm.get_llm_model_answer("Test query")
|
||||
|
||||
mocked_jinachat.assert_called_once_with(
|
||||
temperature=config.temperature,
|
||||
max_tokens=config.max_tokens,
|
||||
model_kwargs={"top_p": config.top_p},
|
||||
)
|
||||
|
||||
47
tests/llm/test_llama2.py
Normal file
47
tests/llm/test_llama2.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import os
|
||||
import pytest
|
||||
from embedchain.llm.llama2 import Llama2Llm
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llama2_llm():
|
||||
os.environ["REPLICATE_API_TOKEN"] = "test_api_token"
|
||||
llm = Llama2Llm()
|
||||
return llm
|
||||
|
||||
|
||||
def test_init_raises_value_error_without_api_key(mocker):
|
||||
mocker.patch.dict(os.environ, clear=True)
|
||||
with pytest.raises(ValueError):
|
||||
Llama2Llm()
|
||||
|
||||
|
||||
def test_get_llm_model_answer_raises_value_error_for_system_prompt(llama2_llm):
|
||||
llama2_llm.config.system_prompt = "system_prompt"
|
||||
with pytest.raises(ValueError):
|
||||
llama2_llm.get_llm_model_answer("prompt")
|
||||
|
||||
|
||||
def test_get_llm_model_answer(llama2_llm, mocker):
|
||||
mocked_replicate = mocker.patch("embedchain.llm.llama2.Replicate")
|
||||
mocked_replicate_instance = mocker.MagicMock()
|
||||
mocked_replicate.return_value = mocked_replicate_instance
|
||||
mocked_replicate_instance.return_value = "Test answer"
|
||||
|
||||
llama2_llm.config.model = "test_model"
|
||||
llama2_llm.config.max_tokens = 50
|
||||
llama2_llm.config.temperature = 0.7
|
||||
llama2_llm.config.top_p = 0.8
|
||||
|
||||
answer = llama2_llm.get_llm_model_answer("Test query")
|
||||
|
||||
assert answer == "Test answer"
|
||||
mocked_replicate.assert_called_once_with(
|
||||
model="test_model",
|
||||
input={
|
||||
"temperature": 0.7,
|
||||
"max_length": 50,
|
||||
"top_p": 0.8,
|
||||
},
|
||||
)
|
||||
mocked_replicate_instance.assert_called_once_with("Test query")
|
||||
73
tests/llm/test_openai.py
Normal file
73
tests/llm/test_openai.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import os
|
||||
import pytest
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.llm.openai import OpenAILlm
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config():
|
||||
os.environ["OPENAI_API_KEY"] = "test_api_key"
|
||||
config = BaseLlmConfig(
|
||||
temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt", model="gpt-3.5-turbo"
|
||||
)
|
||||
yield config
|
||||
os.environ.pop("OPENAI_API_KEY")
|
||||
|
||||
|
||||
def test_get_llm_model_answer(config, mocker):
|
||||
mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
|
||||
|
||||
llm = OpenAILlm(config)
|
||||
answer = llm.get_llm_model_answer("Test query")
|
||||
|
||||
assert answer == "Test answer"
|
||||
mocked_get_answer.assert_called_once_with("Test query", config)
|
||||
|
||||
|
||||
def test_get_llm_model_answer_with_system_prompt(config, mocker):
|
||||
config.system_prompt = "Custom system prompt"
|
||||
mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
|
||||
|
||||
llm = OpenAILlm(config)
|
||||
answer = llm.get_llm_model_answer("Test query")
|
||||
|
||||
assert answer == "Test answer"
|
||||
mocked_get_answer.assert_called_once_with("Test query", config)
|
||||
|
||||
|
||||
def test_get_llm_model_answer_empty_prompt(config, mocker):
|
||||
mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
|
||||
|
||||
llm = OpenAILlm(config)
|
||||
answer = llm.get_llm_model_answer("")
|
||||
|
||||
assert answer == "Test answer"
|
||||
mocked_get_answer.assert_called_once_with("", config)
|
||||
|
||||
|
||||
def test_get_llm_model_answer_with_streaming(config, mocker):
|
||||
config.stream = True
|
||||
mocked_jinachat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
|
||||
|
||||
llm = OpenAILlm(config)
|
||||
llm.get_llm_model_answer("Test query")
|
||||
|
||||
mocked_jinachat.assert_called_once()
|
||||
callbacks = [callback[1]["callbacks"] for callback in mocked_jinachat.call_args_list]
|
||||
assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)
|
||||
|
||||
|
||||
def test_get_llm_model_answer_without_system_prompt(config, mocker):
|
||||
config.system_prompt = None
|
||||
mocked_jinachat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
|
||||
|
||||
llm = OpenAILlm(config)
|
||||
llm.get_llm_model_answer("Test query")
|
||||
|
||||
mocked_jinachat.assert_called_once_with(
|
||||
model=config.model,
|
||||
temperature=config.temperature,
|
||||
max_tokens=config.max_tokens,
|
||||
model_kwargs={"top_p": config.top_p},
|
||||
)
|
||||
@@ -1,135 +1,85 @@
|
||||
import os
|
||||
import unittest
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig, BaseLlmConfig
|
||||
|
||||
|
||||
class TestApp(unittest.TestCase):
|
||||
os.environ["OPENAI_API_KEY"] = "test_key"
|
||||
@pytest.fixture
|
||||
def app():
|
||||
os.environ["OPENAI_API_KEY"] = "test_api_key"
|
||||
app = App(config=AppConfig(collect_metrics=False))
|
||||
return app
|
||||
|
||||
def setUp(self):
|
||||
self.app = App(config=AppConfig(collect_metrics=False))
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_query(self):
|
||||
"""
|
||||
This test checks the functionality of the 'query' method in the App class.
|
||||
It simulates a scenario where the 'retrieve_from_database' method returns a context list and
|
||||
'get_llm_model_answer' returns an expected answer string.
|
||||
|
||||
The 'query' method is expected to call 'retrieve_from_database' and 'get_llm_model_answer' methods
|
||||
appropriately and return the right answer.
|
||||
|
||||
Key assumptions tested:
|
||||
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
|
||||
BaseLlmConfig.
|
||||
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
|
||||
- 'query' method returns the value it received from 'get_llm_model_answer'.
|
||||
|
||||
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
|
||||
'get_llm_model_answer' methods.
|
||||
"""
|
||||
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||
mock_retrieve.return_value = ["Test context"]
|
||||
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
_answer = self.app.query(input_query="Test query")
|
||||
|
||||
# Ensure retrieve_from_database was called
|
||||
mock_retrieve.assert_called_once()
|
||||
|
||||
# Check the call arguments
|
||||
args, kwargs = mock_retrieve.call_args
|
||||
input_query_arg = kwargs.get("input_query")
|
||||
self.assertEqual(input_query_arg, "Test query")
|
||||
mock_answer.assert_called_once()
|
||||
|
||||
@patch("embedchain.llm.openai.OpenAILlm._get_answer")
|
||||
def test_query_config_app_passing(self, mock_get_answer):
|
||||
mock_get_answer.return_value = MagicMock()
|
||||
mock_get_answer.return_value.content = "Test answer"
|
||||
|
||||
config = AppConfig(collect_metrics=False)
|
||||
chat_config = BaseLlmConfig(system_prompt="Test system prompt")
|
||||
app = App(config=config, llm_config=chat_config)
|
||||
answer = app.llm.get_llm_model_answer("Test query")
|
||||
|
||||
self.assertEqual(app.llm.config.system_prompt, "Test system prompt")
|
||||
self.assertEqual(answer, "Test answer")
|
||||
|
||||
@patch("embedchain.llm.openai.OpenAILlm._get_answer")
|
||||
def test_app_passing(self, mock_get_answer):
|
||||
mock_get_answer.return_value = MagicMock()
|
||||
mock_get_answer.return_value.content = "Test answer"
|
||||
config = AppConfig(collect_metrics=False)
|
||||
chat_config = BaseLlmConfig()
|
||||
app = App(config=config, llm_config=chat_config, system_prompt="Test system prompt")
|
||||
answer = app.llm.get_llm_model_answer("Test query")
|
||||
self.assertEqual(app.llm.config.system_prompt, "Test system prompt")
|
||||
self.assertEqual(answer, "Test answer")
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_query_with_where_in_params(self):
|
||||
"""
|
||||
This test checks the functionality of the 'query' method in the App class.
|
||||
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
|
||||
a where filter and 'get_llm_model_answer' returns an expected answer string.
|
||||
|
||||
The 'query' method is expected to call 'retrieve_from_database' with the where filter and
|
||||
'get_llm_model_answer' methods appropriately and return the right answer.
|
||||
|
||||
Key assumptions tested:
|
||||
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
|
||||
BaseLlmConfig.
|
||||
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
|
||||
- 'query' method returns the value it received from 'get_llm_model_answer'.
|
||||
|
||||
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
|
||||
'get_llm_model_answer' methods.
|
||||
"""
|
||||
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||
mock_retrieve.return_value = ["Test context"]
|
||||
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
answer = self.app.query("Test query", where={"attribute": "value"})
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
_args, kwargs = mock_retrieve.call_args
|
||||
self.assertEqual(kwargs.get("input_query"), "Test query")
|
||||
self.assertEqual(kwargs.get("where"), {"attribute": "value"})
|
||||
mock_answer.assert_called_once()
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_query_with_where_in_query_config(self):
|
||||
"""
|
||||
This test checks the functionality of the 'query' method in the App class.
|
||||
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
|
||||
a where filter and 'get_llm_model_answer' returns an expected answer string.
|
||||
|
||||
The 'query' method is expected to call 'retrieve_from_database' with the where filter and
|
||||
'get_llm_model_answer' methods appropriately and return the right answer.
|
||||
|
||||
Key assumptions tested:
|
||||
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
|
||||
BaseLlmConfig.
|
||||
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
|
||||
- 'query' method returns the value it received from 'get_llm_model_answer'.
|
||||
|
||||
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
|
||||
'get_llm_model_answer' methods.
|
||||
"""
|
||||
|
||||
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_query(app):
|
||||
with patch.object(app, "retrieve_from_database") as mock_retrieve:
|
||||
mock_retrieve.return_value = ["Test context"]
|
||||
with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
with patch.object(self.app.db, "query") as mock_database_query:
|
||||
mock_database_query.return_value = ["Test context"]
|
||||
llm_config = BaseLlmConfig(where={"attribute": "value"})
|
||||
answer = self.app.query("Test query", llm_config)
|
||||
answer = app.query(input_query="Test query")
|
||||
assert answer == "Test answer"
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
_args, kwargs = mock_database_query.call_args
|
||||
self.assertEqual(kwargs.get("input_query"), "Test query")
|
||||
self.assertEqual(kwargs.get("where"), {"attribute": "value"})
|
||||
mock_answer.assert_called_once()
|
||||
mock_retrieve.assert_called_once()
|
||||
_, kwargs = mock_retrieve.call_args
|
||||
input_query_arg = kwargs.get("input_query")
|
||||
assert input_query_arg == "Test query"
|
||||
mock_answer.assert_called_once()
|
||||
|
||||
|
||||
@patch("embedchain.llm.openai.OpenAILlm._get_answer")
|
||||
def test_query_config_app_passing(mock_get_answer):
|
||||
mock_get_answer.return_value = MagicMock()
|
||||
mock_get_answer.return_value = "Test answer"
|
||||
|
||||
config = AppConfig(collect_metrics=False)
|
||||
chat_config = BaseLlmConfig(system_prompt="Test system prompt")
|
||||
app = App(config=config, llm_config=chat_config)
|
||||
answer = app.llm.get_llm_model_answer("Test query")
|
||||
|
||||
assert app.llm.config.system_prompt == "Test system prompt"
|
||||
assert answer == "Test answer"
|
||||
|
||||
|
||||
@patch("embedchain.llm.openai.OpenAILlm._get_answer")
|
||||
def test_app_passing(mock_get_answer):
|
||||
mock_get_answer.return_value = MagicMock()
|
||||
mock_get_answer.return_value = "Test answer"
|
||||
config = AppConfig(collect_metrics=False)
|
||||
chat_config = BaseLlmConfig()
|
||||
app = App(config=config, llm_config=chat_config, system_prompt="Test system prompt")
|
||||
answer = app.llm.get_llm_model_answer("Test query")
|
||||
assert app.llm.config.system_prompt == "Test system prompt"
|
||||
assert answer == "Test answer"
|
||||
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_query_with_where_in_params(app):
|
||||
with patch.object(app, "retrieve_from_database") as mock_retrieve:
|
||||
mock_retrieve.return_value = ["Test context"]
|
||||
with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
answer = app.query("Test query", where={"attribute": "value"})
|
||||
|
||||
assert answer == "Test answer"
|
||||
_, kwargs = mock_retrieve.call_args
|
||||
assert kwargs.get("input_query") == "Test query"
|
||||
assert kwargs.get("where") == {"attribute": "value"}
|
||||
mock_answer.assert_called_once()
|
||||
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_query_with_where_in_query_config(app):
|
||||
with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
with patch.object(app.db, "query") as mock_database_query:
|
||||
mock_database_query.return_value = ["Test context"]
|
||||
llm_config = BaseLlmConfig(where={"attribute": "value"})
|
||||
answer = app.query("Test query", llm_config)
|
||||
|
||||
assert answer == "Test answer"
|
||||
_, kwargs = mock_database_query.call_args
|
||||
assert kwargs.get("input_query") == "Test query"
|
||||
assert kwargs.get("where") == {"attribute": "value"}
|
||||
mock_answer.assert_called_once()
|
||||
|
||||
@@ -1,32 +1,29 @@
|
||||
import unittest
|
||||
|
||||
from embedchain.models.data_type import (DataType, DirectDataType,
|
||||
IndirectDataType, SpecialDataType)
|
||||
from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
|
||||
|
||||
|
||||
class TestDataTypeEnums(unittest.TestCase):
|
||||
def test_subclass_types_in_data_type(self):
|
||||
"""Test that all data type category subclasses are contained in the composite data type"""
|
||||
# Check if DirectDataType values are in DataType
|
||||
for data_type in DirectDataType:
|
||||
self.assertIn(data_type.value, DataType._value2member_map_)
|
||||
def test_subclass_types_in_data_type():
|
||||
"""Test that all data type category subclasses are contained in the composite data type"""
|
||||
# Check if DirectDataType values are in DataType
|
||||
for data_type in DirectDataType:
|
||||
assert data_type.value in DataType._value2member_map_
|
||||
|
||||
# Check if IndirectDataType values are in DataType
|
||||
for data_type in IndirectDataType:
|
||||
self.assertIn(data_type.value, DataType._value2member_map_)
|
||||
# Check if IndirectDataType values are in DataType
|
||||
for data_type in IndirectDataType:
|
||||
assert data_type.value in DataType._value2member_map_
|
||||
|
||||
# Check if SpecialDataType values are in DataType
|
||||
for data_type in SpecialDataType:
|
||||
self.assertIn(data_type.value, DataType._value2member_map_)
|
||||
# Check if SpecialDataType values are in DataType
|
||||
for data_type in SpecialDataType:
|
||||
assert data_type.value in DataType._value2member_map_
|
||||
|
||||
def test_data_type_in_subclasses(self):
|
||||
"""Test that all data types in the composite data type are categorized in a subclass"""
|
||||
for data_type in DataType:
|
||||
if data_type.value in DirectDataType._value2member_map_:
|
||||
self.assertIn(data_type.value, DirectDataType._value2member_map_)
|
||||
elif data_type.value in IndirectDataType._value2member_map_:
|
||||
self.assertIn(data_type.value, IndirectDataType._value2member_map_)
|
||||
elif data_type.value in SpecialDataType._value2member_map_:
|
||||
self.assertIn(data_type.value, SpecialDataType._value2member_map_)
|
||||
else:
|
||||
self.fail(f"{data_type.value} not found in any subclass enums")
|
||||
|
||||
def test_data_type_in_subclasses():
|
||||
"""Test that all data types in the composite data type are categorized in a subclass"""
|
||||
for data_type in DataType:
|
||||
if data_type.value in DirectDataType._value2member_map_:
|
||||
assert data_type.value in DirectDataType._value2member_map_
|
||||
elif data_type.value in IndirectDataType._value2member_map_:
|
||||
assert data_type.value in IndirectDataType._value2member_map_
|
||||
elif data_type.value in SpecialDataType._value2member_map_:
|
||||
assert data_type.value in SpecialDataType._value2member_map_
|
||||
else:
|
||||
assert False, f"{data_type.value} not found in any subclass enums"
|
||||
|
||||
Reference in New Issue
Block a user