Refactor: Make it clear what methods are private (#946)
This commit is contained in:
@@ -16,8 +16,7 @@ from embedchain.embedder.base import BaseEmbedder
|
|||||||
from embedchain.helper.json_serializable import JSONSerializable
|
from embedchain.helper.json_serializable import JSONSerializable
|
||||||
from embedchain.llm.base import BaseLlm
|
from embedchain.llm.base import BaseLlm
|
||||||
from embedchain.loaders.base_loader import BaseLoader
|
from embedchain.loaders.base_loader import BaseLoader
|
||||||
from embedchain.models.data_type import (DataType, DirectDataType,
|
from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
|
||||||
IndirectDataType, SpecialDataType)
|
|
||||||
from embedchain.telemetry.posthog import AnonymousTelemetry
|
from embedchain.telemetry.posthog import AnonymousTelemetry
|
||||||
from embedchain.utils import detect_datatype, is_valid_json_string
|
from embedchain.utils import detect_datatype, is_valid_json_string
|
||||||
from embedchain.vectordb.base import BaseVectorDB
|
from embedchain.vectordb.base import BaseVectorDB
|
||||||
@@ -203,7 +202,7 @@ class EmbedChain(JSONSerializable):
|
|||||||
self.user_asks.append([source, data_type.value, metadata])
|
self.user_asks.append([source, data_type.value, metadata])
|
||||||
|
|
||||||
data_formatter = DataFormatter(data_type, config, kwargs)
|
data_formatter = DataFormatter(data_type, config, kwargs)
|
||||||
documents, metadatas, _ids, new_chunks = self.load_and_embed(
|
documents, metadatas, _ids, new_chunks = self._load_and_embed(
|
||||||
data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run
|
data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run
|
||||||
)
|
)
|
||||||
if data_type in {DataType.DOCS_SITE}:
|
if data_type in {DataType.DOCS_SITE}:
|
||||||
@@ -340,7 +339,7 @@ class EmbedChain(JSONSerializable):
|
|||||||
"When it should be DirectDataType, IndirectDataType or SpecialDataType."
|
"When it should be DirectDataType, IndirectDataType or SpecialDataType."
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_and_embed(
|
def _load_and_embed(
|
||||||
self,
|
self,
|
||||||
loader: BaseLoader,
|
loader: BaseLoader,
|
||||||
chunker: BaseChunker,
|
chunker: BaseChunker,
|
||||||
@@ -457,7 +456,7 @@ class EmbedChain(JSONSerializable):
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
def retrieve_from_database(
|
def _retrieve_from_database(
|
||||||
self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None, citations: bool = False
|
self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None, citations: bool = False
|
||||||
) -> Union[List[Tuple[str, str, str]], List[str]]:
|
) -> Union[List[Tuple[str, str, str]], List[str]]:
|
||||||
"""
|
"""
|
||||||
@@ -537,7 +536,9 @@ class EmbedChain(JSONSerializable):
|
|||||||
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
|
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
|
||||||
"""
|
"""
|
||||||
citations = kwargs.get("citations", False)
|
citations = kwargs.get("citations", False)
|
||||||
contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where, citations=citations)
|
contexts = self._retrieve_from_database(
|
||||||
|
input_query=input_query, config=config, where=where, citations=citations
|
||||||
|
)
|
||||||
if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
|
if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
|
||||||
contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
|
contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
|
||||||
else:
|
else:
|
||||||
@@ -588,7 +589,9 @@ class EmbedChain(JSONSerializable):
|
|||||||
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
|
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
|
||||||
"""
|
"""
|
||||||
citations = kwargs.get("citations", False)
|
citations = kwargs.get("citations", False)
|
||||||
contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where, citations=citations)
|
contexts = self._retrieve_from_database(
|
||||||
|
input_query=input_query, config=config, where=where, citations=citations
|
||||||
|
)
|
||||||
if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
|
if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
|
||||||
contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
|
contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ def register_deserializable(cls: Type[T]) -> Type[T]:
|
|||||||
Returns:
|
Returns:
|
||||||
Type: The same class, after registration.
|
Type: The same class, after registration.
|
||||||
"""
|
"""
|
||||||
JSONSerializable.register_class_as_deserializable(cls)
|
JSONSerializable._register_class_as_deserializable(cls)
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
|
||||||
@@ -183,7 +183,7 @@ class JSONSerializable:
|
|||||||
return cls.deserialize(json_str)
|
return cls.deserialize(json_str)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_class_as_deserializable(cls, target_class: Type[T]) -> None:
|
def _register_class_as_deserializable(cls, target_class: Type[T]) -> None:
|
||||||
"""
|
"""
|
||||||
Register a class as deserializable. This is a classmethod and globally shared.
|
Register a class as deserializable. This is a classmethod and globally shared.
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ class AnonymousTelemetry:
|
|||||||
self.project_api_key = "phc_PHQDA5KwztijnSojsxJ2c1DuJd52QCzJzT2xnSGvjN2"
|
self.project_api_key = "phc_PHQDA5KwztijnSojsxJ2c1DuJd52QCzJzT2xnSGvjN2"
|
||||||
self.host = host
|
self.host = host
|
||||||
self.posthog = Posthog(project_api_key=self.project_api_key, host=self.host)
|
self.posthog = Posthog(project_api_key=self.project_api_key, host=self.host)
|
||||||
self.user_id = self.get_user_id()
|
self.user_id = self._get_user_id()
|
||||||
self.enabled = enabled
|
self.enabled = enabled
|
||||||
|
|
||||||
# Check if telemetry tracking is disabled via environment variable
|
# Check if telemetry tracking is disabled via environment variable
|
||||||
@@ -38,7 +38,7 @@ class AnonymousTelemetry:
|
|||||||
posthog_logger = logging.getLogger("posthog")
|
posthog_logger = logging.getLogger("posthog")
|
||||||
posthog_logger.disabled = True
|
posthog_logger.disabled = True
|
||||||
|
|
||||||
def get_user_id(self):
|
def _get_user_id(self):
|
||||||
if not os.path.exists(CONFIG_DIR):
|
if not os.path.exists(CONFIG_DIR):
|
||||||
os.makedirs(CONFIG_DIR)
|
os.makedirs(CONFIG_DIR)
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ def test_whole_app(app_instance, mocker):
|
|||||||
knowledge = "lorem ipsum dolor sit amet, consectetur adipiscing"
|
knowledge = "lorem ipsum dolor sit amet, consectetur adipiscing"
|
||||||
|
|
||||||
mocker.patch.object(EmbedChain, "add")
|
mocker.patch.object(EmbedChain, "add")
|
||||||
mocker.patch.object(EmbedChain, "retrieve_from_database")
|
mocker.patch.object(EmbedChain, "_retrieve_from_database")
|
||||||
mocker.patch.object(BaseLlm, "get_answer_from_llm", return_value=knowledge)
|
mocker.patch.object(BaseLlm, "get_answer_from_llm", return_value=knowledge)
|
||||||
mocker.patch.object(BaseLlm, "get_llm_model_answer", return_value=knowledge)
|
mocker.patch.object(BaseLlm, "get_llm_model_answer", return_value=knowledge)
|
||||||
mocker.patch.object(BaseLlm, "generate_prompt")
|
mocker.patch.object(BaseLlm, "generate_prompt")
|
||||||
|
|||||||
@@ -4,8 +4,7 @@ from string import Template
|
|||||||
|
|
||||||
from embedchain import App
|
from embedchain import App
|
||||||
from embedchain.config import AppConfig, BaseLlmConfig
|
from embedchain.config import AppConfig, BaseLlmConfig
|
||||||
from embedchain.helper.json_serializable import (JSONSerializable,
|
from embedchain.helper.json_serializable import JSONSerializable, register_deserializable
|
||||||
register_deserializable)
|
|
||||||
|
|
||||||
|
|
||||||
class TestJsonSerializable(unittest.TestCase):
|
class TestJsonSerializable(unittest.TestCase):
|
||||||
@@ -53,7 +52,7 @@ class TestJsonSerializable(unittest.TestCase):
|
|||||||
app: SecondTestClass = SecondTestClass().deserialize(serial)
|
app: SecondTestClass = SecondTestClass().deserialize(serial)
|
||||||
self.assertTrue(app.default)
|
self.assertTrue(app.default)
|
||||||
# If we register and try again with the same serial, it should work
|
# If we register and try again with the same serial, it should work
|
||||||
SecondTestClass.register_class_as_deserializable(SecondTestClass)
|
SecondTestClass._register_class_as_deserializable(SecondTestClass)
|
||||||
app: SecondTestClass = SecondTestClass().deserialize(serial)
|
app: SecondTestClass = SecondTestClass().deserialize(serial)
|
||||||
self.assertFalse(app.default)
|
self.assertFalse(app.default)
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ class TestApp(unittest.TestCase):
|
|||||||
os.environ["OPENAI_API_KEY"] = "test_key"
|
os.environ["OPENAI_API_KEY"] = "test_key"
|
||||||
self.app = App(config=AppConfig(collect_metrics=False))
|
self.app = App(config=AppConfig(collect_metrics=False))
|
||||||
|
|
||||||
@patch.object(App, "retrieve_from_database", return_value=["Test context"])
|
@patch.object(App, "_retrieve_from_database", return_value=["Test context"])
|
||||||
@patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
|
@patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
|
||||||
def test_chat_with_memory(self, mock_get_answer, mock_retrieve):
|
def test_chat_with_memory(self, mock_get_answer, mock_retrieve):
|
||||||
"""
|
"""
|
||||||
@@ -28,7 +28,7 @@ class TestApp(unittest.TestCase):
|
|||||||
- After the first call, 'memory.chat_memory.add_user_message' and 'memory.chat_memory.add_ai_message' are
|
- After the first call, 'memory.chat_memory.add_user_message' and 'memory.chat_memory.add_ai_message' are
|
||||||
- During the second call, the 'chat' method uses the chat history from the first call.
|
- During the second call, the 'chat' method uses the chat history from the first call.
|
||||||
|
|
||||||
The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database', 'get_answer_from_llm' and
|
The test isolates the 'chat' method behavior by mocking out '_retrieve_from_database', 'get_answer_from_llm' and
|
||||||
'memory' methods.
|
'memory' methods.
|
||||||
"""
|
"""
|
||||||
config = AppConfig(collect_metrics=False)
|
config = AppConfig(collect_metrics=False)
|
||||||
@@ -42,7 +42,7 @@ class TestApp(unittest.TestCase):
|
|||||||
self.assertEqual(second_answer, "Test answer")
|
self.assertEqual(second_answer, "Test answer")
|
||||||
mock_history.assert_called_with(app.config.id, "Test query 2", "Test answer")
|
mock_history.assert_called_with(app.config.id, "Test query 2", "Test answer")
|
||||||
|
|
||||||
@patch.object(App, "retrieve_from_database", return_value=["Test context"])
|
@patch.object(App, "_retrieve_from_database", return_value=["Test context"])
|
||||||
@patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
|
@patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
|
||||||
def test_template_replacement(self, mock_get_answer, mock_retrieve):
|
def test_template_replacement(self, mock_get_answer, mock_retrieve):
|
||||||
"""
|
"""
|
||||||
@@ -73,7 +73,7 @@ class TestApp(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
Test where filter
|
Test where filter
|
||||||
"""
|
"""
|
||||||
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
with patch.object(self.app, "_retrieve_from_database") as mock_retrieve:
|
||||||
mock_retrieve.return_value = ["Test context"]
|
mock_retrieve.return_value = ["Test context"]
|
||||||
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
|
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
|
||||||
mock_answer.return_value = "Test answer"
|
mock_answer.return_value = "Test answer"
|
||||||
@@ -89,19 +89,19 @@ class TestApp(unittest.TestCase):
|
|||||||
def test_chat_with_where_in_chat_config(self):
|
def test_chat_with_where_in_chat_config(self):
|
||||||
"""
|
"""
|
||||||
This test checks the functionality of the 'chat' method in the App class.
|
This test checks the functionality of the 'chat' method in the App class.
|
||||||
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
|
It simulates a scenario where the '_retrieve_from_database' method returns a context list based on
|
||||||
a where filter and 'get_llm_model_answer' returns an expected answer string.
|
a where filter and 'get_llm_model_answer' returns an expected answer string.
|
||||||
|
|
||||||
The 'chat' method is expected to call 'retrieve_from_database' with the where filter specified
|
The 'chat' method is expected to call '_retrieve_from_database' with the where filter specified
|
||||||
in the BaseLlmConfig and 'get_llm_model_answer' methods appropriately and return the right answer.
|
in the BaseLlmConfig and 'get_llm_model_answer' methods appropriately and return the right answer.
|
||||||
|
|
||||||
Key assumptions tested:
|
Key assumptions tested:
|
||||||
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
|
- '_retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
|
||||||
BaseLlmConfig.
|
BaseLlmConfig.
|
||||||
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
|
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
|
||||||
- 'chat' method returns the value it received from 'get_llm_model_answer'.
|
- 'chat' method returns the value it received from 'get_llm_model_answer'.
|
||||||
|
|
||||||
The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database' and
|
The test isolates the 'chat' method behavior by mocking out '_retrieve_from_database' and
|
||||||
'get_llm_model_answer' methods.
|
'get_llm_model_answer' methods.
|
||||||
"""
|
"""
|
||||||
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
|
with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ def app():
|
|||||||
|
|
||||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||||
def test_query(app):
|
def test_query(app):
|
||||||
with patch.object(app, "retrieve_from_database") as mock_retrieve:
|
with patch.object(app, "_retrieve_from_database") as mock_retrieve:
|
||||||
mock_retrieve.return_value = ["Test context"]
|
mock_retrieve.return_value = ["Test context"]
|
||||||
with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
|
with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
|
||||||
mock_answer.return_value = "Test answer"
|
mock_answer.return_value = "Test answer"
|
||||||
@@ -58,7 +58,7 @@ def test_app_passing(mock_get_answer):
|
|||||||
|
|
||||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||||
def test_query_with_where_in_params(app):
|
def test_query_with_where_in_params(app):
|
||||||
with patch.object(app, "retrieve_from_database") as mock_retrieve:
|
with patch.object(app, "_retrieve_from_database") as mock_retrieve:
|
||||||
mock_retrieve.return_value = ["Test context"]
|
mock_retrieve.return_value = ["Test context"]
|
||||||
with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
|
with patch.object(app.llm, "get_llm_model_answer") as mock_answer:
|
||||||
mock_answer.return_value = "Test answer"
|
mock_answer.return_value = "Test answer"
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ class TestAnonymousTelemetry:
|
|||||||
mocker.patch("embedchain.telemetry.posthog.CONFIG_FILE", str(config_file))
|
mocker.patch("embedchain.telemetry.posthog.CONFIG_FILE", str(config_file))
|
||||||
telemetry = AnonymousTelemetry()
|
telemetry = AnonymousTelemetry()
|
||||||
|
|
||||||
user_id = telemetry.get_user_id()
|
user_id = telemetry._get_user_id()
|
||||||
assert user_id == "unique_user_id"
|
assert user_id == "unique_user_id"
|
||||||
assert config_file.read() == '{"user_id": "unique_user_id"}'
|
assert config_file.read() == '{"user_id": "unique_user_id"}'
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user