From 5ec12212e4c5741505c62ffac27659fac5dc9a82 Mon Sep 17 00:00:00 2001 From: Sidharth Mohanty Date: Sun, 15 Oct 2023 07:46:27 +0530 Subject: [PATCH] Improve tests (#800) --- embedchain/apps/person_app.py | 14 ++-- tests/apps/test_apps.py | 132 +++++++++++++++++----------------- tests/apps/test_person_app.py | 80 +++++++++++++++++++++ tests/bots/test_poe.py | 50 +++++++++++++ tests/embedchain/test_add.py | 89 +++++++++-------------- tests/llm/test_base_llm.py | 52 ++++++++++++++ 6 files changed, 287 insertions(+), 130 deletions(-) create mode 100644 tests/apps/test_person_app.py create mode 100644 tests/bots/test_poe.py create mode 100644 tests/llm/test_base_llm.py diff --git a/embedchain/apps/person_app.py b/embedchain/apps/person_app.py index 02940def..4511e4a6 100644 --- a/embedchain/apps/person_app.py +++ b/embedchain/apps/person_app.py @@ -2,10 +2,8 @@ from string import Template from embedchain.apps.app import App from embedchain.apps.open_source_app import OpenSourceApp -from embedchain.config import BaseLlmConfig -from embedchain.config.apps.base_app_config import BaseAppConfig -from embedchain.config.llm.base import (DEFAULT_PROMPT, - DEFAULT_PROMPT_WITH_HISTORY) +from embedchain.config import BaseLlmConfig, AppConfig +from embedchain.config.llm.base import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY from embedchain.helper.json_serializable import register_deserializable @@ -16,16 +14,16 @@ class EmbedChainPersonApp: This bot behaves and speaks like a person. :param person: name of the person, better if its a well known person. - :param config: BaseAppConfig instance to load as configuration. + :param config: AppConfig instance to load as configuration. """ - def __init__(self, person: str, config: BaseAppConfig = None): + def __init__(self, person: str, config: AppConfig = None): """Initialize a new person app :param person: Name of the person that's imitated. :type person: str :param config: Configuration class instance, defaults to None - :type config: BaseAppConfig, optional + :type config: AppConfig, optional """ self.person = person self.person_prompt = f"You are {person}. Whatever you say, you will always say in {person} style." # noqa:E501 @@ -70,7 +68,7 @@ class PersonApp(EmbedChainPersonApp, App): """ def query(self, input_query, config: BaseLlmConfig = None, dry_run=False): - config = self.add_person_template_to_config(DEFAULT_PROMPT, config, where=None) + config = self.add_person_template_to_config(DEFAULT_PROMPT, config) return super().query(input_query, config, dry_run, where=None) def chat(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None): diff --git a/tests/apps/test_apps.py b/tests/apps/test_apps.py index c754982a..2e311242 100644 --- a/tests/apps/test_apps.py +++ b/tests/apps/test_apps.py @@ -1,108 +1,106 @@ import os -import unittest - +import pytest import yaml from embedchain import App, CustomApp, Llama2App, OpenSourceApp -from embedchain.config import (AddConfig, AppConfig, BaseEmbedderConfig, - BaseLlmConfig, ChromaDbConfig) +from embedchain.config import AddConfig, AppConfig, BaseEmbedderConfig, BaseLlmConfig, ChromaDbConfig from embedchain.embedder.base import BaseEmbedder from embedchain.llm.base import BaseLlm from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig from embedchain.vectordb.chroma import ChromaDB -class TestApps(unittest.TestCase): +@pytest.fixture +def app(): os.environ["OPENAI_API_KEY"] = "test_api_key" - - def test_app(self): - app = App() - self.assertIsInstance(app.llm, BaseLlm) - self.assertIsInstance(app.db, BaseVectorDB) - self.assertIsInstance(app.embedder, BaseEmbedder) - - wrong_llm = "wrong_llm" - with self.assertRaises(TypeError): - App(llm=wrong_llm) - - wrong_db = "wrong_db" - with self.assertRaises(TypeError): - App(db=wrong_db) - - wrong_embedder = "wrong_embedder" - with self.assertRaises(TypeError): - App(embedder=wrong_embedder) - - def test_custom_app(self): - app = CustomApp() - self.assertIsInstance(app.llm, BaseLlm) - self.assertIsInstance(app.db, BaseVectorDB) - self.assertIsInstance(app.embedder, BaseEmbedder) - - def test_opensource_app(self): - app = OpenSourceApp() - self.assertIsInstance(app.llm, BaseLlm) - self.assertIsInstance(app.db, BaseVectorDB) - self.assertIsInstance(app.embedder, BaseEmbedder) - - def test_llama2_app(self): - os.environ["REPLICATE_API_TOKEN"] = "-" - app = Llama2App() - self.assertIsInstance(app.llm, BaseLlm) - self.assertIsInstance(app.db, BaseVectorDB) - self.assertIsInstance(app.embedder, BaseEmbedder) + return App() -class TestConfigForAppComponents(unittest.TestCase): - collection_name = "my-test-collection" +@pytest.fixture +def custom_app(): + os.environ["OPENAI_API_KEY"] = "test_api_key" + return CustomApp() + +@pytest.fixture +def opensource_app(): + os.environ["OPENAI_API_KEY"] = "test_api_key" + return OpenSourceApp() + + +@pytest.fixture +def llama2_app(): + os.environ["OPENAI_API_KEY"] = "test_api_key" + os.environ["REPLICATE_API_TOKEN"] = "-" + return Llama2App() + + +def test_app(app): + assert isinstance(app.llm, BaseLlm) + assert isinstance(app.db, BaseVectorDB) + assert isinstance(app.embedder, BaseEmbedder) + + +def test_custom_app(custom_app): + assert isinstance(custom_app.llm, BaseLlm) + assert isinstance(custom_app.db, BaseVectorDB) + assert isinstance(custom_app.embedder, BaseEmbedder) + + +def test_opensource_app(opensource_app): + assert isinstance(opensource_app.llm, BaseLlm) + assert isinstance(opensource_app.db, BaseVectorDB) + assert isinstance(opensource_app.embedder, BaseEmbedder) + + +def test_llama2_app(llama2_app): + assert isinstance(llama2_app.llm, BaseLlm) + assert isinstance(llama2_app.db, BaseVectorDB) + assert isinstance(llama2_app.embedder, BaseEmbedder) + + +class TestConfigForAppComponents: def test_constructor_config(self): - """ - Test that app can be configured through the app constructor. - """ - app = App(db_config=ChromaDbConfig(collection_name=self.collection_name)) - self.assertEqual(app.db.config.collection_name, self.collection_name) + collection_name = "my-test-collection" + app = App(db_config=ChromaDbConfig(collection_name=collection_name)) + assert app.db.config.collection_name == collection_name def test_component_config(self): - """ - Test that app can also be configured by passing a configured component to the app - """ - database = ChromaDB(config=ChromaDbConfig(collection_name=self.collection_name)) + collection_name = "my-test-collection" + database = ChromaDB(config=ChromaDbConfig(collection_name=collection_name)) app = App(db=database) - self.assertEqual(app.db.config.collection_name, self.collection_name) + assert app.db.config.collection_name == collection_name def test_different_configs_are_proper_instances(self): - config = AppConfig() - wrong_app_config = AddConfig() + app_config = AppConfig() + wrong_config = AddConfig() + with pytest.raises(TypeError): + App(config=wrong_config) - with self.assertRaises(TypeError): - App(config=wrong_app_config) - - self.assertIsInstance(config, AppConfig) + assert isinstance(app_config, AppConfig) llm_config = BaseLlmConfig() wrong_llm_config = "wrong_llm_config" - with self.assertRaises(TypeError): + with pytest.raises(TypeError): App(llm_config=wrong_llm_config) - self.assertIsInstance(llm_config, BaseLlmConfig) + assert isinstance(llm_config, BaseLlmConfig) db_config = BaseVectorDbConfig() wrong_db_config = "wrong_db_config" - with self.assertRaises(TypeError): + with pytest.raises(TypeError): App(db_config=wrong_db_config) - self.assertIsInstance(db_config, BaseVectorDbConfig) + assert isinstance(db_config, BaseVectorDbConfig) embedder_config = BaseEmbedderConfig() wrong_embedder_config = "wrong_embedder_config" - - with self.assertRaises(TypeError): + with pytest.raises(TypeError): App(embedder_config=wrong_embedder_config) - self.assertIsInstance(embedder_config, BaseEmbedderConfig) + assert isinstance(embedder_config, BaseEmbedderConfig) class TestAppFromConfig: diff --git a/tests/apps/test_person_app.py b/tests/apps/test_person_app.py new file mode 100644 index 00000000..dc846508 --- /dev/null +++ b/tests/apps/test_person_app.py @@ -0,0 +1,80 @@ +import pytest +from embedchain.apps.app import App +from embedchain.apps.person_app import PersonApp, PersonOpenSourceApp +from embedchain.config import BaseLlmConfig, AppConfig +from embedchain.config.llm.base import DEFAULT_PROMPT + + +@pytest.fixture +def person_app(): + config = AppConfig() + return PersonApp("John Doe", config) + + +@pytest.fixture +def opensource_person_app(): + config = AppConfig() + return PersonOpenSourceApp("John Doe", config) + + +def test_person_app_initialization(person_app): + assert person_app.person == "John Doe" + assert f"You are {person_app.person}" in person_app.person_prompt + assert isinstance(person_app.config, AppConfig) + + +def test_person_app_add_person_template_to_config_with_invalid_template(): + app = PersonApp("John Doe") + default_prompt = "Input Prompt" + with pytest.raises(ValueError): + # as prompt doesn't contain $context and $query + app.add_person_template_to_config(default_prompt) + + +def test_person_app_add_person_template_to_config_with_valid_template(): + app = PersonApp("John Doe") + config = app.add_person_template_to_config(DEFAULT_PROMPT) + assert ( + config.template.template + == f"You are John Doe. Whatever you say, you will always say in John Doe style. {DEFAULT_PROMPT}" + ) + + +def test_person_app_query(mocker, person_app): + input_query = "Hello, how are you?" + config = BaseLlmConfig() + + mocker.patch.object(App, "query", return_value="Mocked response") + + result = person_app.query(input_query, config) + assert result == "Mocked response" + + +def test_person_app_chat(mocker, person_app): + input_query = "Hello, how are you?" + config = BaseLlmConfig() + + mocker.patch.object(App, "chat", return_value="Mocked chat response") + + result = person_app.chat(input_query, config) + assert result == "Mocked chat response" + + +def test_opensource_person_app_query(mocker, opensource_person_app): + input_query = "Hello, how are you?" + config = BaseLlmConfig() + + mocker.patch.object(App, "query", return_value="Mocked response") + + result = opensource_person_app.query(input_query, config) + assert result == "Mocked response" + + +def test_opensource_person_app_chat(mocker, opensource_person_app): + input_query = "Hello, how are you?" + config = BaseLlmConfig() + + mocker.patch.object(App, "chat", return_value="Mocked chat response") + + result = opensource_person_app.chat(input_query, config) + assert result == "Mocked chat response" diff --git a/tests/bots/test_poe.py b/tests/bots/test_poe.py new file mode 100644 index 00000000..09ae1d6d --- /dev/null +++ b/tests/bots/test_poe.py @@ -0,0 +1,50 @@ +import argparse +import pytest + +from embedchain.bots.poe import PoeBot, start_command +from fastapi_poe.types import QueryRequest, ProtocolMessage + + +@pytest.fixture +def poe_bot(mocker): + bot = PoeBot() + mocker.patch("fastapi_poe.run") + return bot + + +@pytest.mark.asyncio +async def test_poe_bot_get_response(poe_bot, mocker): + query = QueryRequest( + version="test", + type="query", + query=[ProtocolMessage(role="system", content="Test content")], + user_id="test_user_id", + conversation_id="test_conversation_id", + message_id="test_message_id", + ) + + mocker.patch.object(poe_bot.app.llm, "set_history") + + response_generator = poe_bot.get_response(query) + + await response_generator.__anext__() + poe_bot.app.llm.set_history.assert_called_once() + + +def test_poe_bot_handle_message(poe_bot, mocker): + mocker.patch.object(poe_bot, "ask_bot", return_value="Answer from the bot") + + response_ask = poe_bot.handle_message("What is the answer?") + assert response_ask == "Answer from the bot" + + # TODO: This test will fail because the add_data method is commented out. + # mocker.patch.object(poe_bot, 'add_data', return_value="Added data from: some_data") + # response_add = poe_bot.handle_message("/add some_data") + # assert response_add == "Added data from: some_data" + + +def test_start_command(mocker): + mocker.patch("argparse.ArgumentParser.parse_args", return_value=argparse.Namespace(api_key="test_api_key")) + mocker.patch("embedchain.bots.poe.run") + + start_command() diff --git a/tests/embedchain/test_add.py b/tests/embedchain/test_add.py index f63f60d4..7974b9c4 100644 --- a/tests/embedchain/test_add.py +++ b/tests/embedchain/test_add.py @@ -1,70 +1,49 @@ import os -import unittest -from unittest.mock import MagicMock, patch - +import pytest from embedchain import App from embedchain.config import AddConfig, AppConfig, ChunkerConfig from embedchain.models.data_type import DataType +os.environ["OPENAI_API_KEY"] = "test_key" -class TestApp(unittest.TestCase): - os.environ["OPENAI_API_KEY"] = "test_key" - def setUp(self): - self.app = App(config=AppConfig(collect_metrics=False)) +@pytest.fixture +def app(mocker): + mocker.patch("chromadb.api.models.Collection.Collection.add") + return App(config=AppConfig(collect_metrics=False)) - @patch("chromadb.api.models.Collection.Collection.add", MagicMock) - def test_add(self): - """ - This test checks the functionality of the 'add' method in the App class. - It begins by simulating the addition of a web page with a specific URL to the application instance. - The 'add' method is expected to append the input type and URL to the 'user_asks' attribute of the App instance. - By asserting that 'user_asks' is updated correctly after the 'add' method is called, we can confirm that the - method is working as intended. - The Collection.add method from the chromadb library is mocked during this test to isolate the behavior of the - 'add' method. - """ - self.app.add("https://example.com", metadata={"meta": "meta-data"}) - self.assertEqual(self.app.user_asks, [["https://example.com", "web_page", {"meta": "meta-data"}]]) - @patch("chromadb.api.models.Collection.Collection.add", MagicMock) - def test_add_sitemap(self): - """ - In addition to the test_add function, this test checks that sitemaps can be added with the correct data type. - """ - self.app.add("https://www.google.com/sitemap.xml", metadata={"meta": "meta-data"}) - self.assertEqual(self.app.user_asks, [["https://www.google.com/sitemap.xml", "sitemap", {"meta": "meta-data"}]]) +def test_add(app): + app.add("https://example.com", metadata={"meta": "meta-data"}) + assert app.user_asks == [["https://example.com", "web_page", {"meta": "meta-data"}]] - @patch("chromadb.api.models.Collection.Collection.add", MagicMock) - def test_add_forced_type(self): - """ - Test that you can also force a data_type with `add`. - """ - data_type = "text" - self.app.add("https://example.com", data_type=data_type, metadata={"meta": "meta-data"}) - self.assertEqual(self.app.user_asks, [["https://example.com", data_type, {"meta": "meta-data"}]]) - @patch("chromadb.api.models.Collection.Collection.add", MagicMock) - def test_dry_run(self): - """ - Test that if dry_run == True then data chunks are returned. - """ +def test_add_sitemap(app): + app.add("https://www.google.com/sitemap.xml", metadata={"meta": "meta-data"}) + assert app.user_asks == [["https://www.google.com/sitemap.xml", "sitemap", {"meta": "meta-data"}]] - chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0) - # We can't test with lorem ipsum because chunks are deduped, so would be recurring characters. - text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ""" - result = self.app.add(source=text, config=AddConfig(chunker=chunker_config), dry_run=True) +def test_add_forced_type(app): + data_type = "text" + app.add("https://example.com", data_type=data_type, metadata={"meta": "meta-data"}) + assert app.user_asks == [["https://example.com", data_type, {"meta": "meta-data"}]] - chunks = result["chunks"] - metadata = result["metadata"] - count = result["count"] - data_type = result["type"] - self.assertEqual(len(chunks), len(text)) - self.assertEqual(count, len(text)) - self.assertEqual(data_type, DataType.TEXT) - for item in metadata: - self.assertIsInstance(item, dict) - self.assertIn(item["url"], "local") - self.assertIn(item["data_type"], "text") +def test_dry_run(app): + chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0) + text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ""" + + result = app.add(source=text, config=AddConfig(chunker=chunker_config), dry_run=True) + + chunks = result["chunks"] + metadata = result["metadata"] + count = result["count"] + data_type = result["type"] + + assert len(chunks) == len(text) + assert count == len(text) + assert data_type == DataType.TEXT + for item in metadata: + assert isinstance(item, dict) + assert "local" in item["url"] + assert "text" in item["data_type"] diff --git a/tests/llm/test_base_llm.py b/tests/llm/test_base_llm.py new file mode 100644 index 00000000..e74f8a0c --- /dev/null +++ b/tests/llm/test_base_llm.py @@ -0,0 +1,52 @@ +import pytest +from embedchain.llm.base import BaseLlm, BaseLlmConfig + + +@pytest.fixture +def base_llm(): + config = BaseLlmConfig() + return BaseLlm(config=config) + + +def test_is_get_llm_model_answer_not_implemented(base_llm): + with pytest.raises(NotImplementedError): + base_llm.get_llm_model_answer() + + +def test_is_get_llm_model_answer_implemented(): + class TestLlm(BaseLlm): + def get_llm_model_answer(self): + return "Implemented" + + config = BaseLlmConfig() + llm = TestLlm(config=config) + assert llm.get_llm_model_answer() == "Implemented" + + +def test_stream_query_response(base_llm): + answer = ["Chunk1", "Chunk2", "Chunk3"] + result = list(base_llm._stream_query_response(answer)) + assert result == answer + + +def test_stream_chat_response(base_llm): + answer = ["Chunk1", "Chunk2", "Chunk3"] + result = list(base_llm._stream_chat_response(answer)) + assert result == answer + + +def test_append_search_and_context(base_llm): + context = "Context" + web_search_result = "Web Search Result" + result = base_llm._append_search_and_context(context, web_search_result) + expected_result = "Context\nWeb Search Result: Web Search Result" + assert result == expected_result + + +def test_access_search_and_get_results(base_llm, mocker): + base_llm.access_search_and_get_results = mocker.patch.object( + base_llm, "access_search_and_get_results", return_value="Search Results" + ) + input_query = "Test query" + result = base_llm.access_search_and_get_results(input_query) + assert result == "Search Results"