Refactor: Make it clear what methods are private (#946)
This commit is contained in:
@@ -16,8 +16,7 @@ from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.helper.json_serializable import JSONSerializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.models.data_type import (DataType, DirectDataType,
|
||||
IndirectDataType, SpecialDataType)
|
||||
from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
|
||||
from embedchain.telemetry.posthog import AnonymousTelemetry
|
||||
from embedchain.utils import detect_datatype, is_valid_json_string
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
@@ -203,7 +202,7 @@ class EmbedChain(JSONSerializable):
|
||||
self.user_asks.append([source, data_type.value, metadata])
|
||||
|
||||
data_formatter = DataFormatter(data_type, config, kwargs)
|
||||
documents, metadatas, _ids, new_chunks = self.load_and_embed(
|
||||
documents, metadatas, _ids, new_chunks = self._load_and_embed(
|
||||
data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run
|
||||
)
|
||||
if data_type in {DataType.DOCS_SITE}:
|
||||
@@ -340,7 +339,7 @@ class EmbedChain(JSONSerializable):
|
||||
"When it should be DirectDataType, IndirectDataType or SpecialDataType."
|
||||
)
|
||||
|
||||
def load_and_embed(
|
||||
def _load_and_embed(
|
||||
self,
|
||||
loader: BaseLoader,
|
||||
chunker: BaseChunker,
|
||||
@@ -457,7 +456,7 @@ class EmbedChain(JSONSerializable):
|
||||
)
|
||||
]
|
||||
|
||||
def retrieve_from_database(
|
||||
def _retrieve_from_database(
|
||||
self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None, citations: bool = False
|
||||
) -> Union[List[Tuple[str, str, str]], List[str]]:
|
||||
"""
|
||||
@@ -537,7 +536,9 @@ class EmbedChain(JSONSerializable):
|
||||
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
|
||||
"""
|
||||
citations = kwargs.get("citations", False)
|
||||
contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where, citations=citations)
|
||||
contexts = self._retrieve_from_database(
|
||||
input_query=input_query, config=config, where=where, citations=citations
|
||||
)
|
||||
if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
|
||||
contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
|
||||
else:
|
||||
@@ -588,7 +589,9 @@ class EmbedChain(JSONSerializable):
|
||||
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
|
||||
"""
|
||||
citations = kwargs.get("citations", False)
|
||||
contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where, citations=citations)
|
||||
contexts = self._retrieve_from_database(
|
||||
input_query=input_query, config=config, where=where, citations=citations
|
||||
)
|
||||
if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
|
||||
contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
|
||||
else:
|
||||
|
||||
@@ -33,7 +33,7 @@ def register_deserializable(cls: Type[T]) -> Type[T]:
|
||||
Returns:
|
||||
Type: The same class, after registration.
|
||||
"""
|
||||
JSONSerializable.register_class_as_deserializable(cls)
|
||||
JSONSerializable._register_class_as_deserializable(cls)
|
||||
return cls
|
||||
|
||||
|
||||
@@ -183,7 +183,7 @@ class JSONSerializable:
|
||||
return cls.deserialize(json_str)
|
||||
|
||||
@classmethod
|
||||
def register_class_as_deserializable(cls, target_class: Type[T]) -> None:
|
||||
def _register_class_as_deserializable(cls, target_class: Type[T]) -> None:
|
||||
"""
|
||||
Register a class as deserializable. This is a classmethod and globally shared.
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ class AnonymousTelemetry:
|
||||
self.project_api_key = "phc_PHQDA5KwztijnSojsxJ2c1DuJd52QCzJzT2xnSGvjN2"
|
||||
self.host = host
|
||||
self.posthog = Posthog(project_api_key=self.project_api_key, host=self.host)
|
||||
self.user_id = self.get_user_id()
|
||||
self.user_id = self._get_user_id()
|
||||
self.enabled = enabled
|
||||
|
||||
# Check if telemetry tracking is disabled via environment variable
|
||||
@@ -38,7 +38,7 @@ class AnonymousTelemetry:
|
||||
posthog_logger = logging.getLogger("posthog")
|
||||
posthog_logger.disabled = True
|
||||
|
||||
def get_user_id(self):
|
||||
def _get_user_id(self):
|
||||
if not os.path.exists(CONFIG_DIR):
|
||||
os.makedirs(CONFIG_DIR)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user