[Feature] Add support for directory loader as data source (#1008)
This commit is contained in:
@@ -74,6 +74,7 @@ class DataFormatter(JSONSerializable):
|
|||||||
DataType.DISCORD: "embedchain.loaders.discord.DiscordLoader",
|
DataType.DISCORD: "embedchain.loaders.discord.DiscordLoader",
|
||||||
DataType.RSSFEED: "embedchain.loaders.rss_feed.RSSFeedLoader",
|
DataType.RSSFEED: "embedchain.loaders.rss_feed.RSSFeedLoader",
|
||||||
DataType.BEEHIIV: "embedchain.loaders.beehiiv.BeehiivLoader",
|
DataType.BEEHIIV: "embedchain.loaders.beehiiv.BeehiivLoader",
|
||||||
|
DataType.DIRECTORY: "embedchain.loaders.directory_loader.DirectoryLoader",
|
||||||
}
|
}
|
||||||
|
|
||||||
if data_type == DataType.CUSTOM or loader is not None:
|
if data_type == DataType.CUSTOM or loader is not None:
|
||||||
@@ -116,6 +117,7 @@ class DataFormatter(JSONSerializable):
|
|||||||
DataType.CUSTOM: "embedchain.chunkers.common_chunker.CommonChunker",
|
DataType.CUSTOM: "embedchain.chunkers.common_chunker.CommonChunker",
|
||||||
DataType.RSSFEED: "embedchain.chunkers.rss_feed.RSSFeedChunker",
|
DataType.RSSFEED: "embedchain.chunkers.rss_feed.RSSFeedChunker",
|
||||||
DataType.BEEHIIV: "embedchain.chunkers.beehiiv.BeehiivChunker",
|
DataType.BEEHIIV: "embedchain.chunkers.beehiiv.BeehiivChunker",
|
||||||
|
DataType.DIRECTORY: "embedchain.chunkers.common_chunker.CommonChunker",
|
||||||
}
|
}
|
||||||
|
|
||||||
if chunker is not None:
|
if chunker is not None:
|
||||||
|
|||||||
55
embedchain/loaders/directory_loader.py
Normal file
55
embedchain/loaders/directory_loader.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
from embedchain.utils import detect_datatype
|
||||||
|
from embedchain.helpers.json_serializable import register_deserializable
|
||||||
|
from embedchain.loaders.base_loader import BaseLoader
|
||||||
|
from embedchain.loaders.local_text import LocalTextLoader
|
||||||
|
from embedchain.data_formatter.data_formatter import DataFormatter
|
||||||
|
from embedchain.config import AddConfig
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
class DirectoryLoader(BaseLoader):
    """Load data from the files contained in a directory.

    Each matching file is handed to the loader that ``detect_datatype``
    selects for it, and the per-file results are concatenated into a single
    data list.

    Supported ``config`` keys:
        recursive: when true (the default) sub-directories are walked too.
        extensions: optional iterable of suffixes (e.g. ``[".txt", ".pdf"]``);
            when given, only files whose ``Path.suffix`` equals one of them
            are loaded. The comparison is exact (case-sensitive).
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Store the traversal options read from ``config``."""
        super().__init__()
        config = config or {}
        self.recursive = config.get("recursive", True)
        self.extensions = config.get("extensions", None)
        # Messages collected by _predict_loader during the latest load_data call.
        self.errors = []

    def load_data(self, path: str) -> Dict[str, Any]:
        """Load every matching file below ``path``.

        Args:
            path: filesystem path of the directory to load.

        Returns:
            A dict with ``"doc_id"`` (SHA-256 hex digest of the loaded data
            plus the directory path) and ``"data"`` (the concatenated per-file
            data entries).

        Raises:
            ValueError: if ``path`` is not an existing directory.
        """
        directory_path = Path(path)
        if not directory_path.is_dir():
            raise ValueError(f"Invalid path: {path}")

        # Start from a clean slate so a reused loader instance does not
        # re-log errors that belong to an earlier call.
        self.errors = []

        data_list = self._process_directory(directory_path)
        doc_id = hashlib.sha256((str(data_list) + str(directory_path)).encode()).hexdigest()

        for error in self.errors:
            # logging.warn is a deprecated alias of logging.warning.
            logging.warning(error)

        return {"doc_id": doc_id, "data": data_list}

    def _process_directory(self, directory_path: Path) -> list:
        """Collect the ``"data"`` entries of every matching file.

        Files are visited in sorted path order so that the resulting list,
        and therefore the ``doc_id`` hashed from it, does not depend on the
        order in which the filesystem happens to enumerate entries.
        """
        walk = directory_path.rglob if self.recursive else directory_path.glob
        data_list = []
        for file_path in sorted(walk("*")):
            if not file_path.is_file():
                continue
            if self.extensions and not any(file_path.suffix == ext for ext in self.extensions):
                continue
            loader = self._predict_loader(file_path)
            # NOTE(review): a failure inside the chosen loader still propagates
            # and aborts the whole directory load; only loader *selection*
            # errors are tolerated (see _predict_loader).
            data_list.extend(loader.load_data(str(file_path))["data"])
        return data_list

    def _predict_loader(self, file_path: Path) -> BaseLoader:
        """Return the loader matching the detected data type of ``file_path``.

        If detection or loader construction raises, the problem is recorded in
        ``self.errors`` and a ``LocalTextLoader`` is returned as a fallback.
        """
        try:
            data_type = detect_datatype(str(file_path))
            config = AddConfig()
            return DataFormatter(data_type=data_type, config=config)._get_loader(
                data_type=data_type, config=config.loader, loader=None
            )
        except Exception as e:
            # Deliberate best-effort: remember the failure and fall back to
            # treating the file as plain text.
            self.errors.append(f"Error processing {file_path}: {e}")
            return LocalTextLoader()
|
||||||
@@ -35,6 +35,7 @@ class IndirectDataType(Enum):
|
|||||||
CUSTOM = "custom"
|
CUSTOM = "custom"
|
||||||
RSSFEED = "rss_feed"
|
RSSFEED = "rss_feed"
|
||||||
BEEHIIV = "beehiiv"
|
BEEHIIV = "beehiiv"
|
||||||
|
DIRECTORY = "directory"
|
||||||
|
|
||||||
|
|
||||||
class SpecialDataType(Enum):
|
class SpecialDataType(Enum):
|
||||||
@@ -69,3 +70,4 @@ class DataType(Enum):
|
|||||||
CUSTOM = IndirectDataType.CUSTOM.value
|
CUSTOM = IndirectDataType.CUSTOM.value
|
||||||
RSSFEED = IndirectDataType.RSSFEED.value
|
RSSFEED = IndirectDataType.RSSFEED.value
|
||||||
BEEHIIV = IndirectDataType.BEEHIIV.value
|
BEEHIIV = IndirectDataType.BEEHIIV.value
|
||||||
|
DIRECTORY = IndirectDataType.DIRECTORY.value
|
||||||
|
|||||||
@@ -196,8 +196,7 @@ def detect_datatype(source: Any) -> DataType:
|
|||||||
formatted_source = format_source(str(source), 30)
|
formatted_source = format_source(str(source), 30)
|
||||||
|
|
||||||
if url:
|
if url:
|
||||||
from langchain.document_loaders.youtube import \
|
from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
|
||||||
ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
|
|
||||||
|
|
||||||
if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
|
if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
|
||||||
logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
|
logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
|
||||||
@@ -303,6 +302,14 @@ def detect_datatype(source: Any) -> DataType:
|
|||||||
logging.debug(f"Source of `{formatted_source}` detected as `mdx`.")
|
logging.debug(f"Source of `{formatted_source}` detected as `mdx`.")
|
||||||
return DataType.MDX
|
return DataType.MDX
|
||||||
|
|
||||||
|
if source.endswith(".txt"):
|
||||||
|
logging.debug(f"Source of `{formatted_source}` detected as `text`.")
|
||||||
|
return DataType.TEXT
|
||||||
|
|
||||||
|
if source.endswith(".pdf"):
|
||||||
|
logging.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
|
||||||
|
return DataType.PDF_FILE
|
||||||
|
|
||||||
if source.endswith(".yaml"):
|
if source.endswith(".yaml"):
|
||||||
with open(source, "r") as file:
|
with open(source, "r") as file:
|
||||||
yaml_content = yaml.safe_load(file)
|
yaml_content = yaml.safe_load(file)
|
||||||
|
|||||||
@@ -86,11 +86,9 @@ class TestApp(unittest.TestCase):
|
|||||||
|
|
||||||
@patch("os.path.isfile")
|
@patch("os.path.isfile")
|
||||||
def test_detect_datatype_regular_filesystem_file_not_detected(self, mock_isfile):
|
def test_detect_datatype_regular_filesystem_file_not_detected(self, mock_isfile):
|
||||||
"""Test error if a valid file is referenced, but it isn't a valid data_type"""
|
|
||||||
with tempfile.NamedTemporaryFile(suffix=".txt", delete=True) as tmp:
|
with tempfile.NamedTemporaryFile(suffix=".txt", delete=True) as tmp:
|
||||||
mock_isfile.return_value = True
|
mock_isfile.return_value = True
|
||||||
with self.assertRaises(ValueError):
|
self.assertEqual(detect_datatype(tmp.name), DataType.TEXT)
|
||||||
detect_datatype(tmp.name)
|
|
||||||
|
|
||||||
def test_detect_datatype_regular_filesystem_no_file(self):
|
def test_detect_datatype_regular_filesystem_no_file(self):
|
||||||
"""Test that if a filepath is not actually an existing file, it is not handled as a file path."""
|
"""Test that if a filepath is not actually an existing file, it is not handled as a file path."""
|
||||||
|
|||||||
Reference in New Issue
Block a user