From e8a2846449b5b95350b8ca94536adc5531efbdb0 Mon Sep 17 00:00:00 2001
From: Sidharth Mohanty
Date: Wed, 18 Oct 2023 02:36:47 +0530
Subject: [PATCH] Improve and add more tests (#807)

---
 embedchain/llm/openai.py       |  13 +--
 pyproject.toml                 |   1 +
 tests/apps/test_apps.py        |  13 +++
 tests/llm/test_base_llm.py     |  14 ++-
 tests/llm/test_cohere.py       |  66 +++++++----
 tests/llm/test_huggingface.py  |  93 ++++++++-------
 tests/llm/test_jina.py         |  92 ++++++++++-----
 tests/llm/test_llama2.py       |  47 ++++++++
 tests/llm/test_openai.py       |  73 ++++++++++++
 tests/llm/test_query.py        | 200 +++++++++++++--------------------
 tests/models/test_data_type.py |  51 ++++-----
 11 files changed, 403 insertions(+), 260 deletions(-)
 create mode 100644 tests/llm/test_llama2.py
 create mode 100644 tests/llm/test_openai.py

diff --git a/embedchain/llm/openai.py b/embedchain/llm/openai.py
index 22189b03..f7b8bad4 100644
--- a/embedchain/llm/openai.py
+++ b/embedchain/llm/openai.py
@@ -13,13 +13,9 @@ class OpenAILlm(BaseLlm):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config=config)
 
-    def get_llm_model_answer(self, prompt):
+    def get_llm_model_answer(self, prompt) -> str:
         response = OpenAILlm._get_answer(prompt, self.config)
-
-        if self.config.stream:
-            return response
-        else:
-            return response.content
+        return response
 
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
         messages = []
@@ -35,10 +31,9 @@ class OpenAILlm(BaseLlm):
         if config.top_p:
             kwargs["model_kwargs"]["top_p"] = config.top_p
         if config.stream:
-            from langchain.callbacks.streaming_stdout import \
-                StreamingStdOutCallbackHandler
+            from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
             chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
         else:
             chat = ChatOpenAI(**kwargs)
-        return chat(messages)
+        return chat(messages).content
diff --git a/pyproject.toml b/pyproject.toml
index 268b6d0b..9d60c5a9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -133,6 +133,7 @@ isort = "^5.12.0"
 pytest-cov = "^4.1.0"
 responses = "^0.23.3"
 mock = "^5.1.0"
+pytest-asyncio = "^0.21.1"
 
 [tool.poetry.extras]
 streamlit = ["streamlit"]
diff --git a/tests/apps/test_apps.py b/tests/apps/test_apps.py
index 0bafce86..d7744c07 100644
--- a/tests/apps/test_apps.py
+++ b/tests/apps/test_apps.py
@@ -104,6 +104,19 @@ class TestConfigForAppComponents:
 
         assert isinstance(embedder_config, BaseEmbedderConfig)
 
+    def test_components_raises_type_error_if_not_proper_instances(self):
+        wrong_llm = "wrong_llm"
+        with pytest.raises(TypeError):
+            App(llm=wrong_llm)
+
+        wrong_db = "wrong_db"
+        with pytest.raises(TypeError):
+            App(db=wrong_db)
+
+        wrong_embedder = "wrong_embedder"
+        with pytest.raises(TypeError):
+            App(embedder=wrong_embedder)
+
 
 class TestAppFromConfig:
     def load_config_data(self, yaml_path):
diff --git a/tests/llm/test_base_llm.py b/tests/llm/test_base_llm.py
index 9112053d..c740e91a 100644
--- a/tests/llm/test_base_llm.py
+++ b/tests/llm/test_base_llm.py
@@ -1,5 +1,5 @@
 import pytest
-
+from string import Template
 from embedchain.llm.base import BaseLlm, BaseLlmConfig
 
 
@@ -14,6 +14,18 @@ def test_is_get_llm_model_answer_not_implemented(base_llm):
         base_llm.get_llm_model_answer()
 
 
+def test_is_stream_bool():
+    with pytest.raises(ValueError):
+        config = BaseLlmConfig(stream="test value")
+        BaseLlm(config=config)
+
+
+def test_template_string_gets_converted_to_Template_instance():
+    config = BaseLlmConfig(template="test value $query $context")
+    llm = BaseLlm(config=config)
+    assert isinstance(llm.config.template, Template)
+
+
 def test_is_get_llm_model_answer_implemented():
     class TestLlm(BaseLlm):
         def get_llm_model_answer(self):
diff --git a/tests/llm/test_cohere.py b/tests/llm/test_cohere.py
index c7372447..5d1a625d 100644
--- a/tests/llm/test_cohere.py
+++ b/tests/llm/test_cohere.py
@@ -1,33 +1,55 @@
 import os
-import unittest
-from unittest.mock import patch
+import pytest
 
 from embedchain.config import BaseLlmConfig
 from embedchain.llm.cohere import CohereLlm
 
 
-class TestCohereLlm(unittest.TestCase):
-    def setUp(self):
-        os.environ["COHERE_API_KEY"] = "test_api_key"
-        self.config = BaseLlmConfig(model="gptd-instruct-tft", max_tokens=50, temperature=0.7, top_p=0.8)
+@pytest.fixture
+def cohere_llm_config():
+    os.environ["COHERE_API_KEY"] = "test_api_key"
+    config = BaseLlmConfig(model="gptd-instruct-tft", max_tokens=50, temperature=0.7, top_p=0.8)
+    yield config
+    os.environ.pop("COHERE_API_KEY")
 
-    def test_init_raises_value_error_without_api_key(self):
-        os.environ.pop("COHERE_API_KEY")
-        with self.assertRaises(ValueError):
-            CohereLlm()
 
-    def test_get_llm_model_answer_raises_value_error_for_system_prompt(self):
-        llm = CohereLlm(self.config)
-        llm.config.system_prompt = "system_prompt"
-        with self.assertRaises(ValueError):
-            llm.get_llm_model_answer("prompt")
+def test_init_raises_value_error_without_api_key(mocker):
+    mocker.patch.dict(os.environ, clear=True)
+    with pytest.raises(ValueError):
+        CohereLlm()
 
-    @patch("embedchain.llm.cohere.CohereLlm._get_answer")
-    def test_get_llm_model_answer(self, mock_get_answer):
-        mock_get_answer.return_value = "Test answer"
-        llm = CohereLlm(self.config)
-        answer = llm.get_llm_model_answer("Test query")
+def test_get_llm_model_answer_raises_value_error_for_system_prompt(cohere_llm_config):
+    llm = CohereLlm(cohere_llm_config)
+    llm.config.system_prompt = "system_prompt"
+    with pytest.raises(ValueError):
+        llm.get_llm_model_answer("prompt")
 
-        self.assertEqual(answer, "Test answer")
-        mock_get_answer.assert_called_once()
+
+def test_get_llm_model_answer(cohere_llm_config, mocker):
+    mocker.patch("embedchain.llm.cohere.CohereLlm._get_answer", return_value="Test answer")
+
+    llm = CohereLlm(cohere_llm_config)
+    answer = llm.get_llm_model_answer("Test query")
+
+    assert answer == "Test answer"
+
+
+def test_get_answer_mocked_cohere(cohere_llm_config, mocker):
+    mocked_cohere = mocker.patch("embedchain.llm.cohere.Cohere")
+    mock_instance = mocked_cohere.return_value
+    mock_instance.return_value = "Mocked answer"
+
+    llm = CohereLlm(cohere_llm_config)
+    prompt = "Test query"
+    answer = llm.get_llm_model_answer(prompt)
+
+    assert answer == "Mocked answer"
+    mocked_cohere.assert_called_once_with(
+        cohere_api_key="test_api_key",
+        model="gptd-instruct-tft",
+        max_tokens=50,
+        temperature=0.7,
+        p=0.8,
+    )
+    mock_instance.assert_called_once_with(prompt)
diff --git a/tests/llm/test_huggingface.py b/tests/llm/test_huggingface.py
index cc1c5d9c..a8a7a646 100644
--- a/tests/llm/test_huggingface.py
+++ b/tests/llm/test_huggingface.py
@@ -1,64 +1,61 @@
 import importlib
 import os
-import unittest
-from unittest.mock import MagicMock, patch
-
+import pytest
 from embedchain.config import BaseLlmConfig
 from embedchain.llm.huggingface import HuggingFaceLlm
 
 
-class TestHuggingFaceLlm(unittest.TestCase):
-    def setUp(self):
-        os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "test_access_token"
-        self.config = BaseLlmConfig(model="google/flan-t5-xxl", max_tokens=50, temperature=0.7, top_p=0.8)
+@pytest.fixture
+def huggingface_llm_config():
+    os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "test_access_token"
os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "test_access_token" + config = BaseLlmConfig(model="google/flan-t5-xxl", max_tokens=50, temperature=0.7, top_p=0.8) + yield config + os.environ.pop("HUGGINGFACE_ACCESS_TOKEN") - def test_init_raises_value_error_without_api_key(self): - os.environ.pop("HUGGINGFACE_ACCESS_TOKEN") - with self.assertRaises(ValueError): - HuggingFaceLlm() - def test_get_llm_model_answer_raises_value_error_for_system_prompt(self): - llm = HuggingFaceLlm(self.config) - llm.config.system_prompt = "system_prompt" - with self.assertRaises(ValueError): - llm.get_llm_model_answer("prompt") +def test_init_raises_value_error_without_api_key(mocker): + mocker.patch.dict(os.environ, clear=True) + with pytest.raises(ValueError): + HuggingFaceLlm() - def test_top_p_value_within_range(self): - config = BaseLlmConfig(top_p=1.0) - with self.assertRaises(ValueError): - HuggingFaceLlm._get_answer("test_prompt", config) - def test_dependency_is_imported(self): - importlib_installed = True - try: - importlib.import_module("huggingface_hub") - except ImportError: - importlib_installed = False - self.assertTrue(importlib_installed) +def test_get_llm_model_answer_raises_value_error_for_system_prompt(huggingface_llm_config): + llm = HuggingFaceLlm(huggingface_llm_config) + llm.config.system_prompt = "system_prompt" + with pytest.raises(ValueError): + llm.get_llm_model_answer("prompt") - @patch("embedchain.llm.huggingface.HuggingFaceLlm._get_answer") - def test_get_llm_model_answer(self, mock_get_answer): - mock_get_answer.return_value = "Test answer" - llm = HuggingFaceLlm(self.config) - answer = llm.get_llm_model_answer("Test query") +def test_top_p_value_within_range(): + config = BaseLlmConfig(top_p=1.0) + with pytest.raises(ValueError): + HuggingFaceLlm._get_answer("test_prompt", config) - self.assertEqual(answer, "Test answer") - mock_get_answer.assert_called_once() - @patch("embedchain.llm.huggingface.HuggingFaceHub") - def test_hugging_face_mock(self, mock_huggingface): - mock_llm_instance = MagicMock() - mock_llm_instance.return_value = "Test answer" - mock_huggingface.return_value = mock_llm_instance +def test_dependency_is_imported(): + importlib_installed = True + try: + importlib.import_module("huggingface_hub") + except ImportError: + importlib_installed = False + assert importlib_installed - llm = HuggingFaceLlm(self.config) - answer = llm.get_llm_model_answer("Test query") - self.assertEqual(answer, "Test answer") - mock_huggingface.assert_called_once_with( - huggingfacehub_api_token="test_access_token", - repo_id="google/flan-t5-xxl", - model_kwargs={"temperature": 0.7, "max_new_tokens": 50, "top_p": 0.8}, - ) - mock_llm_instance.assert_called_once_with("Test query") +def test_get_llm_model_answer(huggingface_llm_config, mocker): + mocker.patch("embedchain.llm.huggingface.HuggingFaceLlm._get_answer", return_value="Test answer") + + llm = HuggingFaceLlm(huggingface_llm_config) + answer = llm.get_llm_model_answer("Test query") + + assert answer == "Test answer" + + +def test_hugging_face_mock(huggingface_llm_config, mocker): + mock_llm_instance = mocker.Mock(return_value="Test answer") + mocker.patch("embedchain.llm.huggingface.HuggingFaceHub", return_value=mock_llm_instance) + + llm = HuggingFaceLlm(huggingface_llm_config) + answer = llm.get_llm_model_answer("Test query") + + assert answer == "Test answer" + mock_llm_instance.assert_called_once_with("Test query") diff --git a/tests/llm/test_jina.py b/tests/llm/test_jina.py index 49793d57..9ca3f647 100644 --- 
+++ b/tests/llm/test_jina.py
@@ -1,40 +1,76 @@
 import os
-import unittest
-from unittest.mock import patch
-
+import pytest
 from embedchain.config import BaseLlmConfig
 from embedchain.llm.jina import JinaLlm
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
 
-class TestJinaLlm(unittest.TestCase):
-    def setUp(self):
-        os.environ["JINACHAT_API_KEY"] = "test_api_key"
-        self.config = BaseLlmConfig(
-            temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt"
-        )
+@pytest.fixture
+def config():
+    os.environ["JINACHAT_API_KEY"] = "test_api_key"
+    config = BaseLlmConfig(temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt")
+    yield config
+    os.environ.pop("JINACHAT_API_KEY")
 
-    def test_init_raises_value_error_without_api_key(self):
-        os.environ.pop("JINACHAT_API_KEY")
-        with self.assertRaises(ValueError):
-            JinaLlm()
 
-    @patch("embedchain.llm.jina.JinaLlm._get_answer")
-    def test_get_llm_model_answer(self, mock_get_answer):
-        mock_get_answer.return_value = "Test answer"
+def test_init_raises_value_error_without_api_key(mocker):
+    mocker.patch.dict(os.environ, clear=True)
+    with pytest.raises(ValueError):
+        JinaLlm()
 
-        llm = JinaLlm(self.config)
-        answer = llm.get_llm_model_answer("Test query")
 
-        self.assertEqual(answer, "Test answer")
-        mock_get_answer.assert_called_once()
+def test_get_llm_model_answer(config, mocker):
+    mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
 
-    @patch("embedchain.llm.jina.JinaLlm._get_answer")
-    def test_get_llm_model_answer_with_system_prompt(self, mock_get_answer):
-        self.config.system_prompt = "Custom system prompt"
-        mock_get_answer.return_value = "Test answer"
+    llm = JinaLlm(config)
+    answer = llm.get_llm_model_answer("Test query")
 
-        llm = JinaLlm(self.config)
-        answer = llm.get_llm_model_answer("Test query")
+    assert answer == "Test answer"
+    mocked_get_answer.assert_called_once_with("Test query", config)
 
-        self.assertEqual(answer, "Test answer")
-        mock_get_answer.assert_called_once()
+
+def test_get_llm_model_answer_with_system_prompt(config, mocker):
+    config.system_prompt = "Custom system prompt"
+    mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
+
+    llm = JinaLlm(config)
+    answer = llm.get_llm_model_answer("Test query")
+
+    assert answer == "Test answer"
+    mocked_get_answer.assert_called_once_with("Test query", config)
+
+
+def test_get_llm_model_answer_empty_prompt(config, mocker):
+    mocked_get_answer = mocker.patch("embedchain.llm.jina.JinaLlm._get_answer", return_value="Test answer")
+
+    llm = JinaLlm(config)
+    answer = llm.get_llm_model_answer("")
+
+    assert answer == "Test answer"
+    mocked_get_answer.assert_called_once_with("", config)
+
+
+def test_get_llm_model_answer_with_streaming(config, mocker):
+    config.stream = True
+    mocked_jinachat = mocker.patch("embedchain.llm.jina.JinaChat")
+
+    llm = JinaLlm(config)
+    llm.get_llm_model_answer("Test query")
+
+    mocked_jinachat.assert_called_once()
+    callbacks = [callback[1]["callbacks"] for callback in mocked_jinachat.call_args_list]
+    assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks)
+
+
+def test_get_llm_model_answer_without_system_prompt(config, mocker):
+    config.system_prompt = None
+    mocked_jinachat = mocker.patch("embedchain.llm.jina.JinaChat")
+
+    llm = JinaLlm(config)
+    llm.get_llm_model_answer("Test query")
+
+    mocked_jinachat.assert_called_once_with(
+        temperature=config.temperature,
+        max_tokens=config.max_tokens,
+        model_kwargs={"top_p": config.top_p},
+    )
diff --git a/tests/llm/test_llama2.py b/tests/llm/test_llama2.py
new file mode 100644
index 00000000..688149b1
--- /dev/null
+++ b/tests/llm/test_llama2.py
@@ -0,0 +1,47 @@
+import os
+import pytest
+from embedchain.llm.llama2 import Llama2Llm
+
+
+@pytest.fixture
+def llama2_llm():
+    os.environ["REPLICATE_API_TOKEN"] = "test_api_token"
+    llm = Llama2Llm()
+    return llm
+
+
+def test_init_raises_value_error_without_api_key(mocker):
+    mocker.patch.dict(os.environ, clear=True)
+    with pytest.raises(ValueError):
+        Llama2Llm()
+
+
+def test_get_llm_model_answer_raises_value_error_for_system_prompt(llama2_llm):
+    llama2_llm.config.system_prompt = "system_prompt"
+    with pytest.raises(ValueError):
+        llama2_llm.get_llm_model_answer("prompt")
+
+
+def test_get_llm_model_answer(llama2_llm, mocker):
+    mocked_replicate = mocker.patch("embedchain.llm.llama2.Replicate")
+    mocked_replicate_instance = mocker.MagicMock()
+    mocked_replicate.return_value = mocked_replicate_instance
+    mocked_replicate_instance.return_value = "Test answer"
+
+    llama2_llm.config.model = "test_model"
+    llama2_llm.config.max_tokens = 50
+    llama2_llm.config.temperature = 0.7
+    llama2_llm.config.top_p = 0.8
+
+    answer = llama2_llm.get_llm_model_answer("Test query")
+
+    assert answer == "Test answer"
+    mocked_replicate.assert_called_once_with(
+        model="test_model",
+        input={
+            "temperature": 0.7,
+            "max_length": 50,
+            "top_p": 0.8,
+        },
+    )
+    mocked_replicate_instance.assert_called_once_with("Test query")
diff --git a/tests/llm/test_openai.py b/tests/llm/test_openai.py
new file mode 100644
index 00000000..a1795a6c
--- /dev/null
+++ b/tests/llm/test_openai.py
@@ -0,0 +1,73 @@
+import os
+import pytest
+from embedchain.config import BaseLlmConfig
+from embedchain.llm.openai import OpenAILlm
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+
+
+@pytest.fixture
+def config():
+    os.environ["OPENAI_API_KEY"] = "test_api_key"
+    config = BaseLlmConfig(
+        temperature=0.7, max_tokens=50, top_p=0.8, stream=False, system_prompt="System prompt", model="gpt-3.5-turbo"
+    )
+    yield config
+    os.environ.pop("OPENAI_API_KEY")
+
+
+def test_get_llm_model_answer(config, mocker):
+    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
+
+    llm = OpenAILlm(config)
+    answer = llm.get_llm_model_answer("Test query")
+
+    assert answer == "Test answer"
+    mocked_get_answer.assert_called_once_with("Test query", config)
+
+
+def test_get_llm_model_answer_with_system_prompt(config, mocker):
+    config.system_prompt = "Custom system prompt"
+    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
+
+    llm = OpenAILlm(config)
+    answer = llm.get_llm_model_answer("Test query")
+
+    assert answer == "Test answer"
+    mocked_get_answer.assert_called_once_with("Test query", config)
+
+
+def test_get_llm_model_answer_empty_prompt(config, mocker):
+    mocked_get_answer = mocker.patch("embedchain.llm.openai.OpenAILlm._get_answer", return_value="Test answer")
+
+    llm = OpenAILlm(config)
+    answer = llm.get_llm_model_answer("")
+
+    assert answer == "Test answer"
+    mocked_get_answer.assert_called_once_with("", config)
+
+
+def test_get_llm_model_answer_with_streaming(config, mocker):
+    config.stream = True
+    mocked_jinachat = mocker.patch("embedchain.llm.openai.ChatOpenAI")
+
+    llm = OpenAILlm(config)
+    llm.get_llm_model_answer("Test query")
llm.get_llm_model_answer("Test query") + + mocked_jinachat.assert_called_once() + callbacks = [callback[1]["callbacks"] for callback in mocked_jinachat.call_args_list] + assert any(isinstance(callback[0], StreamingStdOutCallbackHandler) for callback in callbacks) + + +def test_get_llm_model_answer_without_system_prompt(config, mocker): + config.system_prompt = None + mocked_jinachat = mocker.patch("embedchain.llm.openai.ChatOpenAI") + + llm = OpenAILlm(config) + llm.get_llm_model_answer("Test query") + + mocked_jinachat.assert_called_once_with( + model=config.model, + temperature=config.temperature, + max_tokens=config.max_tokens, + model_kwargs={"top_p": config.top_p}, + ) diff --git a/tests/llm/test_query.py b/tests/llm/test_query.py index e3b11afd..b208e00c 100644 --- a/tests/llm/test_query.py +++ b/tests/llm/test_query.py @@ -1,135 +1,85 @@ import os -import unittest +import pytest from unittest.mock import MagicMock, patch - from embedchain import App from embedchain.config import AppConfig, BaseLlmConfig -class TestApp(unittest.TestCase): - os.environ["OPENAI_API_KEY"] = "test_key" +@pytest.fixture +def app(): + os.environ["OPENAI_API_KEY"] = "test_api_key" + app = App(config=AppConfig(collect_metrics=False)) + return app - def setUp(self): - self.app = App(config=AppConfig(collect_metrics=False)) - @patch("chromadb.api.models.Collection.Collection.add", MagicMock) - def test_query(self): - """ - This test checks the functionality of the 'query' method in the App class. - It simulates a scenario where the 'retrieve_from_database' method returns a context list and - 'get_llm_model_answer' returns an expected answer string. - - The 'query' method is expected to call 'retrieve_from_database' and 'get_llm_model_answer' methods - appropriately and return the right answer. - - Key assumptions tested: - - 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of - BaseLlmConfig. - - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test. - - 'query' method returns the value it received from 'get_llm_model_answer'. - - The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and - 'get_llm_model_answer' methods. 
- """ - with patch.object(self.app, "retrieve_from_database") as mock_retrieve: - mock_retrieve.return_value = ["Test context"] - with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer: - mock_answer.return_value = "Test answer" - _answer = self.app.query(input_query="Test query") - - # Ensure retrieve_from_database was called - mock_retrieve.assert_called_once() - - # Check the call arguments - args, kwargs = mock_retrieve.call_args - input_query_arg = kwargs.get("input_query") - self.assertEqual(input_query_arg, "Test query") - mock_answer.assert_called_once() - - @patch("embedchain.llm.openai.OpenAILlm._get_answer") - def test_query_config_app_passing(self, mock_get_answer): - mock_get_answer.return_value = MagicMock() - mock_get_answer.return_value.content = "Test answer" - - config = AppConfig(collect_metrics=False) - chat_config = BaseLlmConfig(system_prompt="Test system prompt") - app = App(config=config, llm_config=chat_config) - answer = app.llm.get_llm_model_answer("Test query") - - self.assertEqual(app.llm.config.system_prompt, "Test system prompt") - self.assertEqual(answer, "Test answer") - - @patch("embedchain.llm.openai.OpenAILlm._get_answer") - def test_app_passing(self, mock_get_answer): - mock_get_answer.return_value = MagicMock() - mock_get_answer.return_value.content = "Test answer" - config = AppConfig(collect_metrics=False) - chat_config = BaseLlmConfig() - app = App(config=config, llm_config=chat_config, system_prompt="Test system prompt") - answer = app.llm.get_llm_model_answer("Test query") - self.assertEqual(app.llm.config.system_prompt, "Test system prompt") - self.assertEqual(answer, "Test answer") - - @patch("chromadb.api.models.Collection.Collection.add", MagicMock) - def test_query_with_where_in_params(self): - """ - This test checks the functionality of the 'query' method in the App class. - It simulates a scenario where the 'retrieve_from_database' method returns a context list based on - a where filter and 'get_llm_model_answer' returns an expected answer string. - - The 'query' method is expected to call 'retrieve_from_database' with the where filter and - 'get_llm_model_answer' methods appropriately and return the right answer. - - Key assumptions tested: - - 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of - BaseLlmConfig. - - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test. - - 'query' method returns the value it received from 'get_llm_model_answer'. - - The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and - 'get_llm_model_answer' methods. - """ - with patch.object(self.app, "retrieve_from_database") as mock_retrieve: - mock_retrieve.return_value = ["Test context"] - with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer: - mock_answer.return_value = "Test answer" - answer = self.app.query("Test query", where={"attribute": "value"}) - - self.assertEqual(answer, "Test answer") - _args, kwargs = mock_retrieve.call_args - self.assertEqual(kwargs.get("input_query"), "Test query") - self.assertEqual(kwargs.get("where"), {"attribute": "value"}) - mock_answer.assert_called_once() - - @patch("chromadb.api.models.Collection.Collection.add", MagicMock) - def test_query_with_where_in_query_config(self): - """ - This test checks the functionality of the 'query' method in the App class. 
-        It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
-        a where filter and 'get_llm_model_answer' returns an expected answer string.
-
-        The 'query' method is expected to call 'retrieve_from_database' with the where filter and
-        'get_llm_model_answer' methods appropriately and return the right answer.
-
-        Key assumptions tested:
-        - 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
-        BaseLlmConfig.
-        - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
-        - 'query' method returns the value it received from 'get_llm_model_answer'.
-
-        The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
-        'get_llm_model_answer' methods.
-        """
-
-        with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
+@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
+def test_query(app):
+    with patch.object(app, "retrieve_from_database") as mock_retrieve:
+        mock_retrieve.return_value = ["Test context"]
+        with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
             mock_answer.return_value = "Test answer"
-            with patch.object(self.app.db, "query") as mock_database_query:
-                mock_database_query.return_value = ["Test context"]
-                llm_config = BaseLlmConfig(where={"attribute": "value"})
-                answer = self.app.query("Test query", llm_config)
+            answer = app.query(input_query="Test query")
+            assert answer == "Test answer"
 
-        self.assertEqual(answer, "Test answer")
-        _args, kwargs = mock_database_query.call_args
-        self.assertEqual(kwargs.get("input_query"), "Test query")
-        self.assertEqual(kwargs.get("where"), {"attribute": "value"})
-        mock_answer.assert_called_once()
+    mock_retrieve.assert_called_once()
+    _, kwargs = mock_retrieve.call_args
+    input_query_arg = kwargs.get("input_query")
+    assert input_query_arg == "Test query"
+    mock_answer.assert_called_once()
+
+
+@patch("embedchain.llm.openai.OpenAILlm._get_answer")
+def test_query_config_app_passing(mock_get_answer):
+    mock_get_answer.return_value = MagicMock()
+    mock_get_answer.return_value = "Test answer"
+
+    config = AppConfig(collect_metrics=False)
+    chat_config = BaseLlmConfig(system_prompt="Test system prompt")
+    app = App(config=config, llm_config=chat_config)
+    answer = app.llm.get_llm_model_answer("Test query")
+
+    assert app.llm.config.system_prompt == "Test system prompt"
+    assert answer == "Test answer"
+
+
+@patch("embedchain.llm.openai.OpenAILlm._get_answer")
+def test_app_passing(mock_get_answer):
+    mock_get_answer.return_value = MagicMock()
+    mock_get_answer.return_value = "Test answer"
+    config = AppConfig(collect_metrics=False)
+    chat_config = BaseLlmConfig()
+    app = App(config=config, llm_config=chat_config, system_prompt="Test system prompt")
+    answer = app.llm.get_llm_model_answer("Test query")
+    assert app.llm.config.system_prompt == "Test system prompt"
+    assert answer == "Test answer"
+
+
+@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
+def test_query_with_where_in_params(app):
+    with patch.object(app, "retrieve_from_database") as mock_retrieve:
+        mock_retrieve.return_value = ["Test context"]
+        with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
+            mock_answer.return_value = "Test answer"
+            answer = app.query("Test query", where={"attribute": "value"})
+
+    assert answer == "Test answer"
+    _, kwargs = mock_retrieve.call_args
+    assert kwargs.get("input_query") == "Test query"
+    assert kwargs.get("where") == {"attribute": "value"}
{"attribute": "value"} + mock_answer.assert_called_once() + + +@patch("chromadb.api.models.Collection.Collection.add", MagicMock) +def test_query_with_where_in_query_config(app): + with patch.object(app.llm, "get_llm_model_answer") as mock_answer: + mock_answer.return_value = "Test answer" + with patch.object(app.db, "query") as mock_database_query: + mock_database_query.return_value = ["Test context"] + llm_config = BaseLlmConfig(where={"attribute": "value"}) + answer = app.query("Test query", llm_config) + + assert answer == "Test answer" + _, kwargs = mock_database_query.call_args + assert kwargs.get("input_query") == "Test query" + assert kwargs.get("where") == {"attribute": "value"} + mock_answer.assert_called_once() diff --git a/tests/models/test_data_type.py b/tests/models/test_data_type.py index 7b2f173e..f0baa588 100644 --- a/tests/models/test_data_type.py +++ b/tests/models/test_data_type.py @@ -1,32 +1,29 @@ -import unittest - -from embedchain.models.data_type import (DataType, DirectDataType, - IndirectDataType, SpecialDataType) +from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType -class TestDataTypeEnums(unittest.TestCase): - def test_subclass_types_in_data_type(self): - """Test that all data type category subclasses are contained in the composite data type""" - # Check if DirectDataType values are in DataType - for data_type in DirectDataType: - self.assertIn(data_type.value, DataType._value2member_map_) +def test_subclass_types_in_data_type(): + """Test that all data type category subclasses are contained in the composite data type""" + # Check if DirectDataType values are in DataType + for data_type in DirectDataType: + assert data_type.value in DataType._value2member_map_ - # Check if IndirectDataType values are in DataType - for data_type in IndirectDataType: - self.assertIn(data_type.value, DataType._value2member_map_) + # Check if IndirectDataType values are in DataType + for data_type in IndirectDataType: + assert data_type.value in DataType._value2member_map_ - # Check if SpecialDataType values are in DataType - for data_type in SpecialDataType: - self.assertIn(data_type.value, DataType._value2member_map_) + # Check if SpecialDataType values are in DataType + for data_type in SpecialDataType: + assert data_type.value in DataType._value2member_map_ - def test_data_type_in_subclasses(self): - """Test that all data types in the composite data type are categorized in a subclass""" - for data_type in DataType: - if data_type.value in DirectDataType._value2member_map_: - self.assertIn(data_type.value, DirectDataType._value2member_map_) - elif data_type.value in IndirectDataType._value2member_map_: - self.assertIn(data_type.value, IndirectDataType._value2member_map_) - elif data_type.value in SpecialDataType._value2member_map_: - self.assertIn(data_type.value, SpecialDataType._value2member_map_) - else: - self.fail(f"{data_type.value} not found in any subclass enums") + +def test_data_type_in_subclasses(): + """Test that all data types in the composite data type are categorized in a subclass""" + for data_type in DataType: + if data_type.value in DirectDataType._value2member_map_: + assert data_type.value in DirectDataType._value2member_map_ + elif data_type.value in IndirectDataType._value2member_map_: + assert data_type.value in IndirectDataType._value2member_map_ + elif data_type.value in SpecialDataType._value2member_map_: + assert data_type.value in SpecialDataType._value2member_map_ + else: + assert False, 
f"{data_type.value} not found in any subclass enums"