feat: add method - detect format / data_type (#380)
This commit is contained in:
@@ -15,6 +15,7 @@ from embedchain.loaders.pdf_file import PdfFileLoader
|
||||
from embedchain.loaders.sitemap import SitemapLoader
|
||||
from embedchain.loaders.web_page import WebPageLoader
|
||||
from embedchain.loaders.youtube_video import YoutubeVideoLoader
|
||||
from embedchain.models.data_type import DataType
|
||||
|
||||
|
||||
class DataFormatter:
|
||||
@@ -24,11 +25,11 @@ class DataFormatter:
|
||||
.add or .add_local method call
|
||||
"""
|
||||
|
||||
def __init__(self, data_type: str, config: AddConfig):
|
||||
def __init__(self, data_type: DataType, config: AddConfig):
|
||||
self.loader = self._get_loader(data_type, config.loader)
|
||||
self.chunker = self._get_chunker(data_type, config.chunker)
|
||||
|
||||
def _get_loader(self, data_type, config):
|
||||
def _get_loader(self, data_type: DataType, config):
|
||||
"""
|
||||
Returns the appropriate data loader for the given data type.
|
||||
|
||||
@@ -37,22 +38,22 @@ class DataFormatter:
|
||||
:raises ValueError: If an unsupported data type is provided.
|
||||
"""
|
||||
loaders = {
|
||||
"youtube_video": YoutubeVideoLoader,
|
||||
"pdf_file": PdfFileLoader,
|
||||
"web_page": WebPageLoader,
|
||||
"qna_pair": LocalQnaPairLoader,
|
||||
"text": LocalTextLoader,
|
||||
"docx": DocxFileLoader,
|
||||
"sitemap": SitemapLoader,
|
||||
"docs_site": DocsSiteLoader,
|
||||
DataType.YOUTUBE_VIDEO: YoutubeVideoLoader,
|
||||
DataType.PDF_FILE: PdfFileLoader,
|
||||
DataType.WEB_PAGE: WebPageLoader,
|
||||
DataType.QNA_PAIR: LocalQnaPairLoader,
|
||||
DataType.TEXT: LocalTextLoader,
|
||||
DataType.DOCX: DocxFileLoader,
|
||||
DataType.SITEMAP: SitemapLoader,
|
||||
DataType.DOCS_SITE: DocsSiteLoader,
|
||||
}
|
||||
lazy_loaders = ("notion",)
|
||||
lazy_loaders = {DataType.NOTION}
|
||||
if data_type in loaders:
|
||||
loader_class = loaders[data_type]
|
||||
loader = loader_class()
|
||||
return loader
|
||||
elif data_type in lazy_loaders:
|
||||
if data_type == "notion":
|
||||
if data_type == DataType.NOTION:
|
||||
from embedchain.loaders.notion import NotionLoader
|
||||
|
||||
return NotionLoader()
|
||||
@@ -61,7 +62,7 @@ class DataFormatter:
|
||||
else:
|
||||
raise ValueError(f"Unsupported data type: {data_type}")
|
||||
|
||||
def _get_chunker(self, data_type, config):
|
||||
def _get_chunker(self, data_type: DataType, config):
|
||||
"""
|
||||
Returns the appropriate chunker for the given data type.
|
||||
|
||||
@@ -70,15 +71,15 @@ class DataFormatter:
|
||||
:raises ValueError: If an unsupported data type is provided.
|
||||
"""
|
||||
chunker_classes = {
|
||||
"youtube_video": YoutubeVideoChunker,
|
||||
"pdf_file": PdfFileChunker,
|
||||
"web_page": WebPageChunker,
|
||||
"qna_pair": QnaPairChunker,
|
||||
"text": TextChunker,
|
||||
"docx": DocxFileChunker,
|
||||
"sitemap": WebPageChunker,
|
||||
"docs_site": DocsSiteChunker,
|
||||
"notion": NotionChunker,
|
||||
DataType.YOUTUBE_VIDEO: YoutubeVideoChunker,
|
||||
DataType.PDF_FILE: PdfFileChunker,
|
||||
DataType.WEB_PAGE: WebPageChunker,
|
||||
DataType.QNA_PAIR: QnaPairChunker,
|
||||
DataType.TEXT: TextChunker,
|
||||
DataType.DOCX: DocxFileChunker,
|
||||
DataType.WEB_PAGE: WebPageChunker,
|
||||
DataType.DOCS_SITE: DocsSiteChunker,
|
||||
DataType.NOTION: NotionChunker,
|
||||
}
|
||||
if data_type in chunker_classes:
|
||||
chunker_class = chunker_classes[data_type]
|
||||
|
||||
Reference in New Issue
Block a user