diff --git a/tests/chunkers/test_text.py b/tests/chunkers/test_text.py
new file mode 100644
index 00000000..e2aae842
--- /dev/null
+++ b/tests/chunkers/test_text.py
@@ -0,0 +1,42 @@
+# ruff: noqa: E501
+
+import unittest
+
+from embedchain.chunkers.text import TextChunker
+
+
+class TestTextChunker(unittest.TestCase):
+    def test_chunks(self):
+        """
+        Test the chunks generated by TextChunker.
+        # TODO: Not a very precise test.
+        """
+        chunker_config = {
+            "chunk_size": 10,
+            "chunk_overlap": 0,
+            "length_function": len,
+        }
+        chunker = TextChunker(config=chunker_config)
+        text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
+
+        result = chunker.create_chunks(MockLoader(), text)
+
+        documents = result["documents"]
+
+        self.assertGreaterEqual(len(documents), 5)
+
+    # Additional test cases can be added to cover different scenarios
+
+
+class MockLoader:
+    def load_data(self, src):
+        """
+        Mock loader that returns a list of data dictionaries.
+        Adjust this method to return different data for testing.
+        """
+        return [
+            {
+                "content": src,
+                "meta_data": {"url": "none"},
+            }
+        ]
diff --git a/tests/embedchain/test_add.py b/tests/embedchain/test_add.py
new file mode 100644
index 00000000..9373afb1
--- /dev/null
+++ b/tests/embedchain/test_add.py
@@ -0,0 +1,26 @@
+import os
+import unittest
+from unittest.mock import MagicMock, patch
+
+from embedchain import App
+
+
+class TestApp(unittest.TestCase):
+    os.environ["OPENAI_API_KEY"] = "test_key"
+
+    def setUp(self):
+        self.app = App()
+
+    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
+    def test_add(self):
+        """
+        This test checks the functionality of the 'add' method in the App class.
+        It begins by simulating the addition of a web page with a specific URL to the application instance.
+        The 'add' method is expected to append the input type and URL to the 'user_asks' attribute of the App instance.
+        By asserting that 'user_asks' is updated correctly after the 'add' method is called, we can confirm that the
+        method is working as intended.
+        The Collection.add method from the chromadb library is mocked during this test to isolate the behavior of the
+        'add' method.
+        """
+        self.app.add("web_page", "https://example.com", {"meta": "meta-data"})
+        self.assertEqual(self.app.user_asks, [["web_page", "https://example.com", {"meta": "meta-data"}]])
diff --git a/tests/embedchain/test_chat.py b/tests/embedchain/test_chat.py
new file mode 100644
index 00000000..29032e4d
--- /dev/null
+++ b/tests/embedchain/test_chat.py
@@ -0,0 +1,48 @@
+import os
+import unittest
+from unittest.mock import patch
+
+from embedchain import App
+
+
+class TestApp(unittest.TestCase):
+    os.environ["OPENAI_API_KEY"] = "test_key"
+
+    def setUp(self):
+        self.app = App()
+
+    @patch("embedchain.embedchain.memory", autospec=True)
+    @patch.object(App, "retrieve_from_database", return_value=["Test context"])
+    @patch.object(App, "get_answer_from_llm", return_value="Test answer")
+    def test_chat_with_memory(self, mock_answer, mock_retrieve, mock_memory):
+        """
+        This test checks the functionality of the 'chat' method in the App class with respect to the chat history
+        memory.
+        The 'chat' method is called twice. The first call initializes the chat history memory.
+        The second call is expected to use the chat history from the first call.
+
+        Key assumptions tested:
+        - After the first call, 'memory.chat_memory.add_user_message' and 'memory.chat_memory.add_ai_message' are
+        called with correct arguments, adding the correct chat history.
+        - During the second call, the 'chat' method uses the chat history from the first call.
+
+        The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database', 'get_answer_from_llm' and
+        'memory' methods.
+        """
+        mock_memory.load_memory_variables.return_value = {"history": []}
+        app = App()
+
+        # First call to chat
+        first_answer = app.chat("Test query 1")
+        self.assertEqual(first_answer, "Test answer")
+        mock_memory.chat_memory.add_user_message.assert_called_once_with("Test query 1")
+        mock_memory.chat_memory.add_ai_message.assert_called_once_with("Test answer")
+
+        mock_memory.chat_memory.add_user_message.reset_mock()
+        mock_memory.chat_memory.add_ai_message.reset_mock()
+
+        # Second call to chat
+        second_answer = app.chat("Test query 2")
+        self.assertEqual(second_answer, "Test answer")
+        mock_memory.chat_memory.add_user_message.assert_called_once_with("Test query 2")
+        mock_memory.chat_memory.add_ai_message.assert_called_once_with("Test answer")
diff --git a/tests/embedchain/test_dryrun.py b/tests/embedchain/test_dryrun.py
new file mode 100644
index 00000000..9be60748
--- /dev/null
+++ b/tests/embedchain/test_dryrun.py
@@ -0,0 +1,52 @@
+import os
+import unittest
+from string import Template
+from unittest.mock import patch
+
+from embedchain import App
+from embedchain.embedchain import QueryConfig
+
+
+class TestApp(unittest.TestCase):
+    os.environ["OPENAI_API_KEY"] = "test_key"
+
+    def setUp(self):
+        self.app = App()
+
+    @patch("logging.info")
+    def test_query_logs_same_prompt_as_dry_run(self, mock_logging_info):
+        """
+        Test that the 'query' method logs the same prompt as the 'dry_run' method.
+        This is the only way I found to test the prompt in query, that's not returned.
+        """
+        with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
+            mock_retrieve.return_value = ["Test context"]
+            input_query = "Test query"
+            config = QueryConfig(
+                number_documents=3,
+                template=Template("Question: $query, context: $context, history: $history"),
+                history=["Past context 1", "Past context 2"],
+            )
+
+            with patch.object(self.app, "get_answer_from_llm"):
+                self.app.dry_run(input_query, config)
+                self.app.query(input_query, config)
+
+            # Access the log messages captured during the execution
+            logged_messages = [call[0][0] for call in mock_logging_info.call_args_list]
+
+            # Extract the prompts from the log messages
+            dry_run_prompt = self.extract_prompt(logged_messages[0])
+            query_prompt = self.extract_prompt(logged_messages[1])
+
+            # Perform assertions on the prompts
+            self.assertEqual(dry_run_prompt, query_prompt)
+
+    def extract_prompt(self, log_message):
+        """
+        Extracts the prompt value from the log message.
+        Adjust this method based on the log message format in your implementation.
+        """
+        # Modify this logic based on your log message format
+        prefix = "Prompt: "
+        return log_message.split(prefix, 1)[1]
diff --git a/tests/embedchain/test_embedchain.py b/tests/embedchain/test_embedchain.py
new file mode 100644
index 00000000..6e2472ed
--- /dev/null
+++ b/tests/embedchain/test_embedchain.py
@@ -0,0 +1,42 @@
+import os
+import unittest
+from unittest.mock import patch
+
+from embedchain import App
+from embedchain.config import InitConfig
+
+
+class TestChromaDbHostsLoglevel(unittest.TestCase):
+    os.environ["OPENAI_API_KEY"] = "test_key"
+
+    @patch("chromadb.api.models.Collection.Collection.add")
+    @patch("chromadb.api.models.Collection.Collection.get")
+    @patch("embedchain.embedchain.EmbedChain.retrieve_from_database")
+    @patch("embedchain.embedchain.EmbedChain.get_answer_from_llm")
+    @patch("embedchain.embedchain.EmbedChain.get_llm_model_answer")
+    def test_whole_app(
+        self,
+        _mock_get_llm_model_answer,
+        _mock_get_answer_from_llm,
+        _mock_retrieve_from_database,
+        _mock_get,
+        mock_add,
+    ):
+        """
+        Test the whole flow: add local text, then query and chat with the LLM and retrieval calls mocked.
+
+        `patch` decorators are applied bottom-up, so the mocks arrive in reverse decorator order; the last
+        argument is the `Collection.add` mock, which must have received the added text as its documents.
+        """
+        config = InitConfig(log_level="DEBUG")
+
+        app = App(config)
+
+        knowledge = "lorem ipsum dolor sit amet, consectetur adipiscing"
+
+        app.add_local("text", knowledge)
+
+        app.query("What text did I give you?")
+        app.chat("What text did I give you?")
+
+        self.assertEqual(mock_add.call_args[1]["documents"], [knowledge])
diff --git a/tests/embedchain/test_generate_prompt.py b/tests/embedchain/test_generate_prompt.py
new file mode 100644
index 00000000..af91635e
--- /dev/null
+++ b/tests/embedchain/test_generate_prompt.py
@@ -0,0 +1,66 @@
+import unittest
+from string import Template
+
+from embedchain import App
+from embedchain.embedchain import QueryConfig
+
+
+class TestGeneratePrompt(unittest.TestCase):
+    def setUp(self):
+        self.app = App()
+
+    def test_generate_prompt_with_template(self):
+        """
+        Tests that the generate_prompt method correctly formats the prompt using
+        a custom template provided in the QueryConfig instance.
+
+        This test sets up a scenario with an input query and a list of contexts,
+        and a custom template, and then calls generate_prompt. It checks that the
+        returned prompt correctly incorporates all the contexts and the query into
+        the format specified by the template.
+        """
+        # Setup
+        input_query = "Test query"
+        contexts = ["Context 1", "Context 2", "Context 3"]
+        template = "You are a bot. Context: ${context} - Query: ${query} - Helpful answer:"
+        config = QueryConfig(template=Template(template))
+
+        # Execute
+        result = self.app.generate_prompt(input_query, contexts, config)
+
+        # Assert
+        expected_result = (
+            "You are a bot. Context: Context 1 | Context 2 | Context 3 - Query: Test query - Helpful answer:"
+        )
+        self.assertEqual(result, expected_result)
+
+    def test_generate_prompt_with_contexts_list(self):
+        """
+        Tests that the generate_prompt method correctly handles a list of contexts.
+
+        This test sets up a scenario with an input query and a list of contexts,
+        and then calls generate_prompt. It checks that the returned prompt
+        correctly includes all the contexts and the query.
+        """
+        # Setup
+        input_query = "Test query"
+        contexts = ["Context 1", "Context 2", "Context 3"]
+        config = QueryConfig()
+
+        # Execute
+        result = self.app.generate_prompt(input_query, contexts, config)
+
+        # Assert
+        expected_result = config.template.substitute(context="Context 1 | Context 2 | Context 3", query=input_query)
+        self.assertEqual(result, expected_result)
+
+    def test_generate_prompt_with_history(self):
+        """
+        Test the 'generate_prompt' method with QueryConfig containing a history attribute.
+        """
+        config = QueryConfig(history=["Past context 1", "Past context 2"])
+        config.template = Template("Context: $context | Query: $query | History: $history")
+        prompt = self.app.generate_prompt("Test query", ["Test context"], config)
+
+        expected_prompt = "Context: Test context | Query: Test query | History: ['Past context 1', 'Past context 2']"
+        self.assertEqual(prompt, expected_prompt)
diff --git a/tests/embedchain/test_query.py b/tests/embedchain/test_query.py
new file mode 100644
index 00000000..da84ffe7
--- /dev/null
+++ b/tests/embedchain/test_query.py
@@ -0,0 +1,43 @@
+import os
+import unittest
+from unittest.mock import MagicMock, patch
+
+from embedchain import App
+from embedchain.embedchain import QueryConfig
+
+
+class TestApp(unittest.TestCase):
+    os.environ["OPENAI_API_KEY"] = "test_key"
+
+    def setUp(self):
+        self.app = App()
+
+    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
+    def test_query(self):
+        """
+        This test checks the functionality of the 'query' method in the App class.
+        It simulates a scenario where the 'retrieve_from_database' method returns a context list and
+        'get_llm_model_answer' returns an expected answer string.
+
+        The 'query' method is expected to call 'retrieve_from_database' and 'get_llm_model_answer' methods
+        appropriately and return the right answer.
+
+        Key assumptions tested:
+        - 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
+        QueryConfig.
+        - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
+        - 'query' method returns the value it received from 'get_llm_model_answer'.
+
+        The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
+        'get_llm_model_answer' methods.
+        """
+        with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
+            mock_retrieve.return_value = ["Test context"]
+            with patch.object(self.app, "get_llm_model_answer") as mock_answer:
+                mock_answer.return_value = "Test answer"
+                answer = self.app.query("Test query")
+
+        self.assertEqual(answer, "Test answer")
+        self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
+        self.assertIsInstance(mock_retrieve.call_args[0][1], QueryConfig)
+        mock_answer.assert_called_once()
diff --git a/tests/test_embedchain.py b/tests/test_embedchain.py
deleted file mode 100644
index 3a60f10c..00000000
--- a/tests/test_embedchain.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import os
-import unittest
-from unittest.mock import MagicMock, patch
-
-from embedchain import App
-
-
-class TestApp(unittest.TestCase):
-    os.environ["OPENAI_API_KEY"] = "test_key"
-
-    def setUp(self):
-        self.app = App()
-
-    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
-    def test_add(self):
-        self.app.add("web_page", "https://example.com")
-        self.assertEqual(self.app.user_asks, [["web_page", "https://example.com"]])
-
-    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
-    def test_query(self):
-        with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
-            mock_retrieve.return_value = "Test context"
-            with patch.object(self.app, "get_llm_model_answer") as mock_answer:
-                mock_answer.return_value = "Test answer"
-                answer = self.app.query("Test query")
-
-        self.assertEqual(answer, "Test answer")
-        mock_retrieve.assert_called_once_with("Test query")
-        mock_answer.assert_called_once()
diff --git a/tests/vectordb/test_chroma_db.py b/tests/vectordb/test_chroma_db.py
new file mode 100644
index 00000000..48d34604
--- /dev/null
+++ b/tests/vectordb/test_chroma_db.py
@@ -0,0 +1,72 @@
+# ruff: noqa: E501
+
+import unittest
+from unittest.mock import patch
+
+from embedchain import App
+from embedchain.config import InitConfig
+from embedchain.vectordb.chroma_db import ChromaDB, chromadb
+
+
+class TestChromaDbHosts(unittest.TestCase):
+    def test_init_with_host_and_port(self):
+        """
+        Test if the `ChromaDB` instance is initialized with the correct host and port values.
+        """
+        host = "test-host"
+        port = "1234"
+
+        with patch.object(chromadb, "Client") as mock_client:
+            _db = ChromaDB(host=host, port=port)
+
+        expected_settings = chromadb.config.Settings(
+            chroma_api_impl="rest",
+            chroma_server_host=host,
+            chroma_server_http_port=port,
+        )
+
+        mock_client.assert_called_once_with(expected_settings)
+
+
+class TestChromaDbHostsInit(unittest.TestCase):
+    @patch("embedchain.vectordb.chroma_db.chromadb.Client")
+    def test_init_with_host_and_port(self, mock_client):
+        """
+        Test if the `App` instance is initialized with the correct host and port values.
+        """
+        host = "test-host"
+        port = "1234"
+
+        config = InitConfig(host=host, port=port)
+
+        _app = App(config)
+
+        self.assertEqual(mock_client.call_args[0][0].chroma_server_host, host)
+        self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, port)
+
+
+class TestChromaDbHostsNone(unittest.TestCase):
+    @patch("embedchain.vectordb.chroma_db.chromadb.Client")
+    def test_init_with_host_and_port(self, mock_client):
+        """
+        Test if the `App` instance is initialized without default hosts and ports.
+        """
+
+        _app = App()
+
+        self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
+        self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)
+
+
+class TestChromaDbHostsLoglevel(unittest.TestCase):
+    @patch("embedchain.vectordb.chroma_db.chromadb.Client")
+    def test_init_with_host_and_port(self, mock_client):
+        """
+        Test if the `App` instance is initialized with a config that does not specify a host or port.
+        """
+        config = InitConfig(log_level="DEBUG")
+
+        _app = App(config)
+
+        self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
+        self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)