From bf3fac56e4a76ad774b1dd8804fda2e00bb89932 Mon Sep 17 00:00:00 2001
From: UnMonsieur
Date: Mon, 13 Nov 2023 22:00:13 +0100
Subject: [PATCH] Refactor: Make it clear what methods are private (#946)

---
 embedchain/embedchain.py                       | 17 ++++++++++-------
 embedchain/helper/json_serializable.py         |  4 ++--
 embedchain/telemetry/posthog.py                |  4 ++--
 tests/embedchain/test_embedchain.py            |  2 +-
 tests/helper_classes/test_json_serializable.py |  5 ++---
 tests/llm/test_chat.py                         | 16 ++++++++--------
 tests/llm/test_query.py                        |  4 ++--
 tests/telemetry/test_posthog.py                |  2 +-
 8 files changed, 28 insertions(+), 26 deletions(-)

diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index 991397cb..67b4770c 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -16,8 +16,7 @@ from embedchain.embedder.base import BaseEmbedder
 from embedchain.helper.json_serializable import JSONSerializable
 from embedchain.llm.base import BaseLlm
 from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import (DataType, DirectDataType,
-                                         IndirectDataType, SpecialDataType)
+from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
 from embedchain.telemetry.posthog import AnonymousTelemetry
 from embedchain.utils import detect_datatype, is_valid_json_string
 from embedchain.vectordb.base import BaseVectorDB
@@ -203,7 +202,7 @@ class EmbedChain(JSONSerializable):
             self.user_asks.append([source, data_type.value, metadata])

         data_formatter = DataFormatter(data_type, config, kwargs)
-        documents, metadatas, _ids, new_chunks = self.load_and_embed(
+        documents, metadatas, _ids, new_chunks = self._load_and_embed(
             data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run
         )
         if data_type in {DataType.DOCS_SITE}:
@@ -340,7 +339,7 @@ class EmbedChain(JSONSerializable):
                 "When it should be DirectDataType, IndirectDataType or SpecialDataType."
             )

-    def load_and_embed(
+    def _load_and_embed(
         self,
         loader: BaseLoader,
         chunker: BaseChunker,
@@ -457,7 +456,7 @@ class EmbedChain(JSONSerializable):
             )
         ]

-    def retrieve_from_database(
+    def _retrieve_from_database(
         self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None, citations: bool = False
     ) -> Union[List[Tuple[str, str, str]], List[str]]:
         """
@@ -537,7 +536,9 @@ class EmbedChain(JSONSerializable):
         :rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
         """
         citations = kwargs.get("citations", False)
-        contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where, citations=citations)
+        contexts = self._retrieve_from_database(
+            input_query=input_query, config=config, where=where, citations=citations
+        )
         if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
             contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
         else:
@@ -588,7 +589,9 @@ class EmbedChain(JSONSerializable):
         :rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
         """
         citations = kwargs.get("citations", False)
-        contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where, citations=citations)
+        contexts = self._retrieve_from_database(
+            input_query=input_query, config=config, where=where, citations=citations
+        )
         if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
             contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
         else:
diff --git a/embedchain/helper/json_serializable.py b/embedchain/helper/json_serializable.py
index b5fb7c41..d63cff74 100644
--- a/embedchain/helper/json_serializable.py
+++ b/embedchain/helper/json_serializable.py
@@ -33,7 +33,7 @@ def register_deserializable(cls: Type[T]) -> Type[T]:
     Returns:
         Type: The same class, after registration.
     """
-    JSONSerializable.register_class_as_deserializable(cls)
+    JSONSerializable._register_class_as_deserializable(cls)
     return cls


@@ -183,7 +183,7 @@ class JSONSerializable:
         return cls.deserialize(json_str)

     @classmethod
-    def register_class_as_deserializable(cls, target_class: Type[T]) -> None:
+    def _register_class_as_deserializable(cls, target_class: Type[T]) -> None:
         """
         Register a class as deserializable. This is a classmethod and globally shared.
diff --git a/embedchain/telemetry/posthog.py b/embedchain/telemetry/posthog.py
index 785e9ed2..174c87ac 100644
--- a/embedchain/telemetry/posthog.py
+++ b/embedchain/telemetry/posthog.py
@@ -20,7 +20,7 @@ class AnonymousTelemetry:
         self.project_api_key = "phc_PHQDA5KwztijnSojsxJ2c1DuJd52QCzJzT2xnSGvjN2"
         self.host = host
         self.posthog = Posthog(project_api_key=self.project_api_key, host=self.host)
-        self.user_id = self.get_user_id()
+        self.user_id = self._get_user_id()
         self.enabled = enabled

         # Check if telemetry tracking is disabled via environment variable
@@ -38,7 +38,7 @@ class AnonymousTelemetry:
         posthog_logger = logging.getLogger("posthog")
         posthog_logger.disabled = True

-    def get_user_id(self):
+    def _get_user_id(self):
         if not os.path.exists(CONFIG_DIR):
             os.makedirs(CONFIG_DIR)

diff --git a/tests/embedchain/test_embedchain.py b/tests/embedchain/test_embedchain.py
index ed1c4f64..a53e6ec3 100644
--- a/tests/embedchain/test_embedchain.py
+++ b/tests/embedchain/test_embedchain.py
@@ -22,7 +22,7 @@ def test_whole_app(app_instance, mocker):
     knowledge = "lorem ipsum dolor sit amet, consectetur adipiscing"

     mocker.patch.object(EmbedChain, "add")
-    mocker.patch.object(EmbedChain, "retrieve_from_database")
+    mocker.patch.object(EmbedChain, "_retrieve_from_database")
     mocker.patch.object(BaseLlm, "get_answer_from_llm", return_value=knowledge)
     mocker.patch.object(BaseLlm, "get_llm_model_answer", return_value=knowledge)
     mocker.patch.object(BaseLlm, "generate_prompt")
diff --git a/tests/helper_classes/test_json_serializable.py b/tests/helper_classes/test_json_serializable.py
index 48345fdb..9d2768b0 100644
--- a/tests/helper_classes/test_json_serializable.py
+++ b/tests/helper_classes/test_json_serializable.py
@@ -4,8 +4,7 @@ from string import Template

 from embedchain import App
 from embedchain.config import AppConfig, BaseLlmConfig
-from embedchain.helper.json_serializable import (JSONSerializable,
-                                                 register_deserializable)
+from embedchain.helper.json_serializable import JSONSerializable, register_deserializable


 class TestJsonSerializable(unittest.TestCase):
@@ -53,7 +52,7 @@ class TestJsonSerializable(unittest.TestCase):
         app: SecondTestClass = SecondTestClass().deserialize(serial)
         self.assertTrue(app.default)
         # If we register and try again with the same serial, it should work
-        SecondTestClass.register_class_as_deserializable(SecondTestClass)
+        SecondTestClass._register_class_as_deserializable(SecondTestClass)
         app: SecondTestClass = SecondTestClass().deserialize(serial)
         self.assertFalse(app.default)

diff --git a/tests/llm/test_chat.py b/tests/llm/test_chat.py
index a70e62ce..0b3e30aa 100644
--- a/tests/llm/test_chat.py
+++ b/tests/llm/test_chat.py
@@ -14,7 +14,7 @@ class TestApp(unittest.TestCase):
         os.environ["OPENAI_API_KEY"] = "test_key"
         self.app = App(config=AppConfig(collect_metrics=False))

-    @patch.object(App, "retrieve_from_database", return_value=["Test context"])
+    @patch.object(App, "_retrieve_from_database", return_value=["Test context"])
     @patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
     def test_chat_with_memory(self, mock_get_answer, mock_retrieve):
         """
@@ -28,7 +28,7 @@ class TestApp(unittest.TestCase):
         - After the first call, 'memory.chat_memory.add_user_message' and 'memory.chat_memory.add_ai_message' are
         - During the second call, the 'chat' method uses the chat history from the first call.

-        The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database', 'get_answer_from_llm' and
+        The test isolates the 'chat' method behavior by mocking out '_retrieve_from_database', 'get_answer_from_llm' and
         'memory' methods.
         """
         config = AppConfig(collect_metrics=False)
@@ -42,7 +42,7 @@ class TestApp(unittest.TestCase):
         self.assertEqual(second_answer, "Test answer")
         mock_history.assert_called_with(app.config.id, "Test query 2", "Test answer")

-    @patch.object(App, "retrieve_from_database", return_value=["Test context"])
+    @patch.object(App, "_retrieve_from_database", return_value=["Test context"])
     @patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
     def test_template_replacement(self, mock_get_answer, mock_retrieve):
         """
@@ -73,7 +73,7 @@ class TestApp(unittest.TestCase):
         """
         Test where filter
         """
-        with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
+        with patch.object(self.app, "_retrieve_from_database") as mock_retrieve:
             mock_retrieve.return_value = ["Test context"]
             with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
                 mock_answer.return_value = "Test answer"
@@ -89,19 +89,19 @@ class TestApp(unittest.TestCase):
     def test_chat_with_where_in_chat_config(self):
         """
         This test checks the functionality of the 'chat' method in the App class.
-        It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
+        It simulates a scenario where the '_retrieve_from_database' method returns a context list based on
         a where filter and 'get_llm_model_answer' returns an expected answer string.

-        The 'chat' method is expected to call 'retrieve_from_database' with the where filter specified
+        The 'chat' method is expected to call '_retrieve_from_database' with the where filter specified
         in the BaseLlmConfig and 'get_llm_model_answer' methods appropriately and return the right answer.

         Key assumptions tested:
-        - 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
+        - '_retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
         BaseLlmConfig.
         - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
         - 'chat' method returns the value it received from 'get_llm_model_answer'.

-        The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database' and
+        The test isolates the 'chat' method behavior by mocking out '_retrieve_from_database' and
         'get_llm_model_answer' methods.
""" with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer: diff --git a/tests/llm/test_query.py b/tests/llm/test_query.py index 9ebbecd4..ca3b17e6 100644 --- a/tests/llm/test_query.py +++ b/tests/llm/test_query.py @@ -16,7 +16,7 @@ def app(): @patch("chromadb.api.models.Collection.Collection.add", MagicMock) def test_query(app): - with patch.object(app, "retrieve_from_database") as mock_retrieve: + with patch.object(app, "_retrieve_from_database") as mock_retrieve: mock_retrieve.return_value = ["Test context"] with patch.object(app.llm, "get_llm_model_answer") as mock_answer: mock_answer.return_value = "Test answer" @@ -58,7 +58,7 @@ def test_app_passing(mock_get_answer): @patch("chromadb.api.models.Collection.Collection.add", MagicMock) def test_query_with_where_in_params(app): - with patch.object(app, "retrieve_from_database") as mock_retrieve: + with patch.object(app, "_retrieve_from_database") as mock_retrieve: mock_retrieve.return_value = ["Test context"] with patch.object(app.llm, "get_llm_model_answer") as mock_answer: mock_answer.return_value = "Test answer" diff --git a/tests/telemetry/test_posthog.py b/tests/telemetry/test_posthog.py index 1aa06740..f430a9c9 100644 --- a/tests/telemetry/test_posthog.py +++ b/tests/telemetry/test_posthog.py @@ -29,7 +29,7 @@ class TestAnonymousTelemetry: mocker.patch("embedchain.telemetry.posthog.CONFIG_FILE", str(config_file)) telemetry = AnonymousTelemetry() - user_id = telemetry.get_user_id() + user_id = telemetry._get_user_id() assert user_id == "unique_user_id" assert config_file.read() == '{"user_id": "unique_user_id"}'