Refactor: Make it clear what methods are private (#946)

This commit is contained in:
UnMonsieur
2023-11-13 22:00:13 +01:00
committed by GitHub
parent a5bf8e9075
commit bf3fac56e4
8 changed files with 28 additions and 26 deletions

View File

@@ -16,8 +16,7 @@ from embedchain.embedder.base import BaseEmbedder
from embedchain.helper.json_serializable import JSONSerializable
from embedchain.llm.base import BaseLlm
from embedchain.loaders.base_loader import BaseLoader
from embedchain.models.data_type import (DataType, DirectDataType,
IndirectDataType, SpecialDataType)
from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
from embedchain.telemetry.posthog import AnonymousTelemetry
from embedchain.utils import detect_datatype, is_valid_json_string
from embedchain.vectordb.base import BaseVectorDB
@@ -203,7 +202,7 @@ class EmbedChain(JSONSerializable):
self.user_asks.append([source, data_type.value, metadata])
data_formatter = DataFormatter(data_type, config, kwargs)
documents, metadatas, _ids, new_chunks = self.load_and_embed(
documents, metadatas, _ids, new_chunks = self._load_and_embed(
data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run
)
if data_type in {DataType.DOCS_SITE}:
@@ -340,7 +339,7 @@ class EmbedChain(JSONSerializable):
"When it should be DirectDataType, IndirectDataType or SpecialDataType."
)
def load_and_embed(
def _load_and_embed(
self,
loader: BaseLoader,
chunker: BaseChunker,
@@ -457,7 +456,7 @@ class EmbedChain(JSONSerializable):
)
]
def retrieve_from_database(
def _retrieve_from_database(
self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None, citations: bool = False
) -> Union[List[Tuple[str, str, str]], List[str]]:
"""
@@ -537,7 +536,9 @@ class EmbedChain(JSONSerializable):
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
"""
citations = kwargs.get("citations", False)
contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where, citations=citations)
contexts = self._retrieve_from_database(
input_query=input_query, config=config, where=where, citations=citations
)
if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
else:
@@ -588,7 +589,9 @@ class EmbedChain(JSONSerializable):
:rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
"""
citations = kwargs.get("citations", False)
contexts = self.retrieve_from_database(input_query=input_query, config=config, where=where, citations=citations)
contexts = self._retrieve_from_database(
input_query=input_query, config=config, where=where, citations=citations
)
if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
else: