From 8674297d1afe70a0477462da814f7857a019474b Mon Sep 17 00:00:00 2001 From: Deshraj Yadav Date: Mon, 10 Jul 2023 05:12:29 -0700 Subject: [PATCH] feat: Add basic unit test for embedchain (#209) --- tests/__init__.py | 0 tests/test_embedchain.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 29 insertions(+) create mode 100644 tests/__init__.py create mode 100644 tests/test_embedchain.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_embedchain.py b/tests/test_embedchain.py new file mode 100644 index 00000000..b0534e03 --- /dev/null +++ b/tests/test_embedchain.py @@ -0,0 +1,29 @@ +import os + +import unittest +from unittest.mock import patch, MagicMock +from embedchain import App + + +class TestApp(unittest.TestCase): + os.environ["OPENAI_API_KEY"] = "test_key" + + def setUp(self): + self.app = App() + + @patch("chromadb.api.models.Collection.Collection.add", MagicMock) + def test_add(self): + self.app.add("web_page", "https://example.com") + self.assertEqual(self.app.user_asks, [["web_page", "https://example.com"]]) + + @patch("chromadb.api.models.Collection.Collection.add", MagicMock) + def test_query(self): + with patch.object(self.app, "retrieve_from_database") as mock_retrieve: + mock_retrieve.return_value = "Test context" + with patch.object(self.app, "get_llm_model_answer") as mock_answer: + mock_answer.return_value = "Test answer" + answer = self.app.query("Test query") + + self.assertEqual(answer, "Test answer") + mock_retrieve.assert_called_once_with("Test query") + mock_answer.assert_called_once()