[Feature] Add support for directory loader as data source (#1008)

This commit is contained in:
Sidharth Mohanty
2023-12-15 05:24:34 +05:30
committed by GitHub
parent d54cdc5b00
commit 9303a1bf81
5 changed files with 69 additions and 5 deletions

View File

@@ -74,6 +74,7 @@ class DataFormatter(JSONSerializable):
DataType.DISCORD: "embedchain.loaders.discord.DiscordLoader",
DataType.RSSFEED: "embedchain.loaders.rss_feed.RSSFeedLoader",
DataType.BEEHIIV: "embedchain.loaders.beehiiv.BeehiivLoader",
DataType.DIRECTORY: "embedchain.loaders.directory_loader.DirectoryLoader",
}
if data_type == DataType.CUSTOM or loader is not None:
@@ -116,6 +117,7 @@ class DataFormatter(JSONSerializable):
DataType.CUSTOM: "embedchain.chunkers.common_chunker.CommonChunker",
DataType.RSSFEED: "embedchain.chunkers.rss_feed.RSSFeedChunker",
DataType.BEEHIIV: "embedchain.chunkers.beehiiv.BeehiivChunker",
DataType.DIRECTORY: "embedchain.chunkers.common_chunker.CommonChunker",
}
if chunker is not None:

View File

@@ -0,0 +1,55 @@
from pathlib import Path
import hashlib
import logging
from typing import Optional, Dict, Any
from embedchain.utils import detect_datatype
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.loaders.local_text import LocalTextLoader
from embedchain.data_formatter.data_formatter import DataFormatter
from embedchain.config import AddConfig
@register_deserializable
class DirectoryLoader(BaseLoader):
"""Load data from a directory."""
def __init__(self, config: Optional[Dict[str, Any]] = None):
super().__init__()
config = config or {}
self.recursive = config.get("recursive", True)
self.extensions = config.get("extensions", None)
self.errors = []
def load_data(self, path: str):
directory_path = Path(path)
if not directory_path.is_dir():
raise ValueError(f"Invalid path: {path}")
data_list = self._process_directory(directory_path)
doc_id = hashlib.sha256((str(data_list) + str(directory_path)).encode()).hexdigest()
for error in self.errors:
logging.warn(error)
return {"doc_id": doc_id, "data": data_list}
def _process_directory(self, directory_path: Path):
data_list = []
for file_path in directory_path.rglob("*") if self.recursive else directory_path.glob("*"):
if file_path.is_file() and (not self.extensions or any(file_path.suffix == ext for ext in self.extensions)):
loader = self._predict_loader(file_path)
data_list.extend(loader.load_data(str(file_path))["data"])
return data_list
def _predict_loader(self, file_path: Path) -> BaseLoader:
try:
data_type = detect_datatype(str(file_path))
config = AddConfig()
return DataFormatter(data_type=data_type, config=config)._get_loader(
data_type=data_type, config=config.loader, loader=None
)
except Exception as e:
self.errors.append(f"Error processing {file_path}: {e}")
return LocalTextLoader()

View File

@@ -35,6 +35,7 @@ class IndirectDataType(Enum):
CUSTOM = "custom"
RSSFEED = "rss_feed"
BEEHIIV = "beehiiv"
DIRECTORY = "directory"
class SpecialDataType(Enum):
@@ -69,3 +70,4 @@ class DataType(Enum):
CUSTOM = IndirectDataType.CUSTOM.value
RSSFEED = IndirectDataType.RSSFEED.value
BEEHIIV = IndirectDataType.BEEHIIV.value
DIRECTORY = IndirectDataType.DIRECTORY.value

View File

@@ -196,8 +196,7 @@ def detect_datatype(source: Any) -> DataType:
formatted_source = format_source(str(source), 30)
if url:
from langchain.document_loaders.youtube import \
ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
@@ -303,6 +302,14 @@ def detect_datatype(source: Any) -> DataType:
logging.debug(f"Source of `{formatted_source}` detected as `mdx`.")
return DataType.MDX
if source.endswith(".txt"):
logging.debug(f"Source of `{formatted_source}` detected as `text`.")
return DataType.TEXT
if source.endswith(".pdf"):
logging.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
return DataType.PDF_FILE
if source.endswith(".yaml"):
with open(source, "r") as file:
yaml_content = yaml.safe_load(file)

View File

@@ -86,11 +86,9 @@ class TestApp(unittest.TestCase):
@patch("os.path.isfile")
def test_detect_datatype_regular_filesystem_file_not_detected(self, mock_isfile):
"""Test error if a valid file is referenced, but it isn't a valid data_type"""
with tempfile.NamedTemporaryFile(suffix=".txt", delete=True) as tmp:
mock_isfile.return_value = True
with self.assertRaises(ValueError):
detect_datatype(tmp.name)
self.assertEqual(detect_datatype(tmp.name), DataType.TEXT)
def test_detect_datatype_regular_filesystem_no_file(self):
"""Test that if a filepath is not actually an existing file, it is not handled as a file path."""