feat: add method - detect format / data_type (#380)

This commit is contained in:
cachho
2023-08-16 22:18:24 +02:00
committed by GitHub
parent f92e890aa1
commit 4c8876f032
18 changed files with 472 additions and 121 deletions

View File

@@ -15,6 +15,7 @@ from embedchain.loaders.pdf_file import PdfFileLoader
from embedchain.loaders.sitemap import SitemapLoader
from embedchain.loaders.web_page import WebPageLoader
from embedchain.loaders.youtube_video import YoutubeVideoLoader
from embedchain.models.data_type import DataType
class DataFormatter:
@@ -24,11 +25,11 @@ class DataFormatter:
.add or .add_local method call
"""
def __init__(self, data_type: str, config: AddConfig):
def __init__(self, data_type: DataType, config: AddConfig):
self.loader = self._get_loader(data_type, config.loader)
self.chunker = self._get_chunker(data_type, config.chunker)
def _get_loader(self, data_type, config):
def _get_loader(self, data_type: DataType, config):
"""
Returns the appropriate data loader for the given data type.
@@ -37,22 +38,22 @@ class DataFormatter:
:raises ValueError: If an unsupported data type is provided.
"""
loaders = {
"youtube_video": YoutubeVideoLoader,
"pdf_file": PdfFileLoader,
"web_page": WebPageLoader,
"qna_pair": LocalQnaPairLoader,
"text": LocalTextLoader,
"docx": DocxFileLoader,
"sitemap": SitemapLoader,
"docs_site": DocsSiteLoader,
DataType.YOUTUBE_VIDEO: YoutubeVideoLoader,
DataType.PDF_FILE: PdfFileLoader,
DataType.WEB_PAGE: WebPageLoader,
DataType.QNA_PAIR: LocalQnaPairLoader,
DataType.TEXT: LocalTextLoader,
DataType.DOCX: DocxFileLoader,
DataType.SITEMAP: SitemapLoader,
DataType.DOCS_SITE: DocsSiteLoader,
}
lazy_loaders = ("notion",)
lazy_loaders = {DataType.NOTION}
if data_type in loaders:
loader_class = loaders[data_type]
loader = loader_class()
return loader
elif data_type in lazy_loaders:
if data_type == "notion":
if data_type == DataType.NOTION:
from embedchain.loaders.notion import NotionLoader
return NotionLoader()
@@ -61,7 +62,7 @@ class DataFormatter:
else:
raise ValueError(f"Unsupported data type: {data_type}")
def _get_chunker(self, data_type, config):
def _get_chunker(self, data_type: DataType, config):
"""
Returns the appropriate chunker for the given data type.
@@ -70,15 +71,15 @@ class DataFormatter:
:raises ValueError: If an unsupported data type is provided.
"""
chunker_classes = {
"youtube_video": YoutubeVideoChunker,
"pdf_file": PdfFileChunker,
"web_page": WebPageChunker,
"qna_pair": QnaPairChunker,
"text": TextChunker,
"docx": DocxFileChunker,
"sitemap": WebPageChunker,
"docs_site": DocsSiteChunker,
"notion": NotionChunker,
DataType.YOUTUBE_VIDEO: YoutubeVideoChunker,
DataType.PDF_FILE: PdfFileChunker,
DataType.WEB_PAGE: WebPageChunker,
DataType.QNA_PAIR: QnaPairChunker,
DataType.TEXT: TextChunker,
DataType.DOCX: DocxFileChunker,
DataType.WEB_PAGE: WebPageChunker,
DataType.DOCS_SITE: DocsSiteChunker,
DataType.NOTION: NotionChunker,
}
if data_type in chunker_classes:
chunker_class = chunker_classes[data_type]