[Improvements] Package improvements (#993)

Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
Deven Patel
2023-12-05 23:42:45 -08:00
committed by GitHub
parent 1d4e00ccef
commit 51b4966801
13 changed files with 96 additions and 40 deletions

View File

@@ -1,5 +1,5 @@
from importlib import import_module
from typing import Any, Dict
from typing import Optional
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig
@@ -16,7 +16,13 @@ class DataFormatter(JSONSerializable):
.add or .add_local method call
"""
def __init__(self, data_type: DataType, config: AddConfig, kwargs: Dict[str, Any]):
def __init__(
self,
data_type: DataType,
config: AddConfig,
loader: Optional[BaseLoader] = None,
chunker: Optional[BaseChunker] = None,
):
"""
Initialize a dataformatter, set data type and chunker based on datatype.
@@ -25,15 +31,15 @@ class DataFormatter(JSONSerializable):
:param config: AddConfig instance with nested loader and chunker config attributes.
:type config: AddConfig
"""
self.loader = self._get_loader(data_type=data_type, config=config.loader, kwargs=kwargs)
self.chunker = self._get_chunker(data_type=data_type, config=config.chunker, kwargs=kwargs)
self.loader = self._get_loader(data_type=data_type, config=config.loader, loader=loader)
self.chunker = self._get_chunker(data_type=data_type, config=config.chunker, chunker=chunker)
# Resolve a fully qualified dotted path (for example
# "embedchain.loaders.discord.DiscordLoader") to the object it names.
# The module is imported only when this method is called, so the modules named
# in the path tables are not imported until one of their entries is requested.
#
# :param module_path: dotted path of the form "<module path>.<attribute name>".
# :return: the attribute fetched from the imported module.
# :raises ValueError: if `module_path` contains no "." (the two-way unpack fails).
# :raises ImportError: if the module portion cannot be imported.
# :raises AttributeError: if the module has no attribute with the given name.
def _lazy_load(self, module_path: str):
# Split on the LAST dot only: everything before it is the importable module,
# the final component is the attribute name. Note `module_path` is rebound to
# the module portion here.
module_path, class_name = module_path.rsplit(".", 1)
module = import_module(module_path)
# Return the attribute itself; nothing is instantiated here.
return getattr(module, class_name)
def _get_loader(self, data_type: DataType, config: LoaderConfig, kwargs: Dict[str, Any]) -> BaseLoader:
def _get_loader(self, data_type: DataType, config: LoaderConfig, loader: Optional[BaseLoader]) -> BaseLoader:
"""
Returns the appropriate data loader for the given data type.
@@ -68,8 +74,8 @@ class DataFormatter(JSONSerializable):
DataType.DISCORD: "embedchain.loaders.discord.DiscordLoader",
}
if data_type == DataType.CUSTOM or ("loader" in kwargs):
loader_class: type = kwargs.get("loader", None)
if data_type == DataType.CUSTOM or loader is not None:
loader_class: type = loader
if loader_class:
return loader_class
elif data_type in loaders:
@@ -82,7 +88,7 @@ class DataFormatter(JSONSerializable):
check `https://docs.embedchain.ai/data-sources/overview`."
)
def _get_chunker(self, data_type: DataType, config: ChunkerConfig, kwargs: Dict[str, Any]) -> BaseChunker:
def _get_chunker(self, data_type: DataType, config: ChunkerConfig, chunker: Optional[BaseChunker]) -> BaseChunker:
"""Returns the appropriate chunker for the given data type (updated for lazy loading)."""
chunker_classes = {
DataType.YOUTUBE_VIDEO: "embedchain.chunkers.youtube_video.YoutubeVideoChunker",
@@ -108,12 +114,8 @@ class DataFormatter(JSONSerializable):
DataType.CUSTOM: "embedchain.chunkers.common_chunker.CommonChunker",
}
if "chunker" in kwargs:
chunker_class = kwargs.get("chunker", None)
if chunker_class:
chunker = chunker_class(config)
chunker.set_data_type(data_type)
return chunker
if chunker is not None:
return chunker
elif data_type in chunker_classes:
chunker_class = self._lazy_load(chunker_classes[data_type])
chunker = chunker_class(config)