diff --git a/embedchain/data_formatter/data_formatter.py b/embedchain/data_formatter/data_formatter.py
index 47caa0ea..0187bc40 100644
--- a/embedchain/data_formatter/data_formatter.py
+++ b/embedchain/data_formatter/data_formatter.py
@@ -74,6 +74,7 @@ class DataFormatter(JSONSerializable):
             DataType.DISCORD: "embedchain.loaders.discord.DiscordLoader",
             DataType.RSSFEED: "embedchain.loaders.rss_feed.RSSFeedLoader",
             DataType.BEEHIIV: "embedchain.loaders.beehiiv.BeehiivLoader",
+            DataType.DIRECTORY: "embedchain.loaders.directory_loader.DirectoryLoader",
         }
 
         if data_type == DataType.CUSTOM or loader is not None:
@@ -116,6 +117,7 @@ class DataFormatter(JSONSerializable):
             DataType.CUSTOM: "embedchain.chunkers.common_chunker.CommonChunker",
             DataType.RSSFEED: "embedchain.chunkers.rss_feed.RSSFeedChunker",
             DataType.BEEHIIV: "embedchain.chunkers.beehiiv.BeehiivChunker",
+            DataType.DIRECTORY: "embedchain.chunkers.common_chunker.CommonChunker",
         }
 
         if chunker is not None:
diff --git a/embedchain/loaders/directory_loader.py b/embedchain/loaders/directory_loader.py
new file mode 100644
index 00000000..bc670bc6
--- /dev/null
+++ b/embedchain/loaders/directory_loader.py
@@ -0,0 +1,80 @@
+from pathlib import Path
+import hashlib
+import logging
+from typing import Any, Dict, List, Optional
+
+from embedchain.utils import detect_datatype
+from embedchain.helpers.json_serializable import register_deserializable
+from embedchain.loaders.base_loader import BaseLoader
+from embedchain.loaders.local_text import LocalTextLoader
+from embedchain.data_formatter.data_formatter import DataFormatter
+from embedchain.config import AddConfig
+
+
+@register_deserializable
+class DirectoryLoader(BaseLoader):
+    """Load the files found in a directory, choosing a loader for each file."""
+
+    def __init__(self, config: Optional[Dict[str, Any]] = None):
+        """
+        :param config: optional settings; reads the `recursive` and `extensions` keys.
+        """
+        super().__init__()
+        config = config or {}
+        # Walk sub-directories too unless `recursive` is set to a falsy value.
+        self.recursive = config.get("recursive", True)
+        # Suffixes compared against `Path.suffix`; a falsy value keeps every file.
+        self.extensions = config.get("extensions", None)
+        # Messages recorded when no loader could be predicted for a file.
+        self.errors = []
+
+    def load_data(self, path: str) -> Dict[str, Any]:
+        """
+        Load the data of every selected file under a directory.
+
+        :param path: directory to read.
+        :raises ValueError: if `path` is not a directory.
+        :return: dict holding a `doc_id` hash and the combined `data` list.
+        """
+        directory_path = Path(path)
+        if not directory_path.is_dir():
+            raise ValueError(f"Invalid path: {path}")
+
+        # Start fresh so errors from an earlier call are not logged a second time.
+        self.errors = []
+        data_list = self._process_directory(directory_path)
+        doc_id = hashlib.sha256((str(data_list) + str(directory_path)).encode()).hexdigest()
+
+        for error in self.errors:
+            # `logging.warn` is a deprecated alias; `logging.warning` is the supported call.
+            logging.warning(error)
+
+        return {"doc_id": doc_id, "data": data_list}
+
+    def _process_directory(self, directory_path: Path) -> List[Any]:
+        """Collect the `data` entries of each selected file under `directory_path`."""
+        entries = directory_path.rglob("*") if self.recursive else directory_path.glob("*")
+        data_list = []
+        # Sorted so the output, and the doc_id hashed from it, does not depend on
+        # the order in which the file system lists its entries.
+        for file_path in sorted(entries):
+            if not file_path.is_file():
+                continue
+            if self.extensions and not any(file_path.suffix == ext for ext in self.extensions):
+                continue
+            loader = self._predict_loader(file_path)
+            data_list.extend(loader.load_data(str(file_path))["data"])
+        return data_list
+
+    def _predict_loader(self, file_path: Path) -> BaseLoader:
+        """Return the loader for the detected data type, or `LocalTextLoader` on failure."""
+        try:
+            data_type = detect_datatype(str(file_path))
+            config = AddConfig()
+            return DataFormatter(data_type=data_type, config=config)._get_loader(
+                data_type=data_type, config=config.loader, loader=None
+            )
+        except Exception as e:
+            # Best effort: record the failure and fall back to the plain text loader.
+            self.errors.append(f"Error processing {file_path}: {e}")
+            return LocalTextLoader()
diff --git a/embedchain/models/data_type.py b/embedchain/models/data_type.py
index 864ab95f..3b4ce6a0 100644
--- a/embedchain/models/data_type.py
+++ b/embedchain/models/data_type.py
@@ -35,6 +35,7 @@ class IndirectDataType(Enum):
     CUSTOM = "custom"
     RSSFEED = "rss_feed"
     BEEHIIV = "beehiiv"
+    DIRECTORY = "directory"
 
 
 class SpecialDataType(Enum):
@@ -69,3 +70,4 @@ class DataType(Enum):
     CUSTOM = IndirectDataType.CUSTOM.value
     RSSFEED = IndirectDataType.RSSFEED.value
     BEEHIIV = IndirectDataType.BEEHIIV.value
+    DIRECTORY = IndirectDataType.DIRECTORY.value
diff --git a/embedchain/utils.py b/embedchain/utils.py
index ba6efd21..af8a57b8 100644
--- a/embedchain/utils.py
+++ b/embedchain/utils.py
@@ -196,8 +196,7 @@ def detect_datatype(source: Any) -> DataType:
     formatted_source = format_source(str(source), 30)
 
     if url:
-        from langchain.document_loaders.youtube import \
-            ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
+        from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
 
         if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
             logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
@@ -303,6 +302,14 @@ def detect_datatype(source: Any) -> DataType:
             logging.debug(f"Source of `{formatted_source}` detected as `mdx`.")
             return DataType.MDX
 
+        if source.endswith(".txt"):
+            logging.debug(f"Source of `{formatted_source}` detected as `text`.")
+            return DataType.TEXT
+
+        if source.endswith(".pdf"):
+            logging.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
+            return DataType.PDF_FILE
+
         if source.endswith(".yaml"):
             with open(source, "r") as file:
                 yaml_content = yaml.safe_load(file)
diff --git a/tests/embedchain/test_utils.py b/tests/embedchain/test_utils.py
index 819b1a14..18418acb 100644
--- a/tests/embedchain/test_utils.py
+++ b/tests/embedchain/test_utils.py
@@ -86,11 +86,10 @@ class TestApp(unittest.TestCase):
 
     @patch("os.path.isfile")
     def test_detect_datatype_regular_filesystem_file_not_detected(self, mock_isfile):
-        """Test error if a valid file is referenced, but it isn't a valid data_type"""
+        """Test that a regular file with a `.txt` suffix is detected as the text data_type."""
         with tempfile.NamedTemporaryFile(suffix=".txt", delete=True) as tmp:
             mock_isfile.return_value = True
-            with self.assertRaises(ValueError):
-                detect_datatype(tmp.name)
+            self.assertEqual(detect_datatype(tmp.name), DataType.TEXT)
 
     def test_detect_datatype_regular_filesystem_no_file(self):
         """Test that if a filepath is not actually an existing file, it is not handled as a file path."""