diff --git a/embedchain/data_formatter/data_formatter.py b/embedchain/data_formatter/data_formatter.py index 6d3ce57e..3f9a2edd 100644 --- a/embedchain/data_formatter/data_formatter.py +++ b/embedchain/data_formatter/data_formatter.py @@ -76,6 +76,7 @@ class DataFormatter(JSONSerializable): DataType.BEEHIIV: "embedchain.loaders.beehiiv.BeehiivLoader", DataType.DIRECTORY: "embedchain.loaders.directory_loader.DirectoryLoader", DataType.SLACK: "embedchain.loaders.slack.SlackLoader", + DataType.TEXT_FILE: "embedchain.loaders.text_file.TextFileLoader", } if data_type == DataType.CUSTOM or loader is not None: @@ -120,6 +121,7 @@ class DataFormatter(JSONSerializable): DataType.BEEHIIV: "embedchain.chunkers.beehiiv.BeehiivChunker", DataType.DIRECTORY: "embedchain.chunkers.common_chunker.CommonChunker", DataType.SLACK: "embedchain.chunkers.common_chunker.CommonChunker", + DataType.TEXT_FILE: "embedchain.chunkers.common_chunker.CommonChunker", } if chunker is not None: diff --git a/embedchain/loaders/directory_loader.py b/embedchain/loaders/directory_loader.py index bcc225f5..9a4bc48c 100644 --- a/embedchain/loaders/directory_loader.py +++ b/embedchain/loaders/directory_loader.py @@ -7,7 +7,7 @@ from embedchain.config import AddConfig from embedchain.data_formatter.data_formatter import DataFormatter from embedchain.helpers.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader -from embedchain.loaders.local_text import LocalTextLoader +from embedchain.loaders.text_file import TextFileLoader from embedchain.utils import detect_datatype @@ -58,4 +58,4 @@ class DirectoryLoader(BaseLoader): ) except Exception as e: self.errors.append(f"Error processing {file_path}: {e}") - return LocalTextLoader() + return TextFileLoader() diff --git a/embedchain/loaders/text_file.py b/embedchain/loaders/text_file.py new file mode 100644 index 00000000..1c9e02db --- /dev/null +++ b/embedchain/loaders/text_file.py @@ -0,0 +1,30 @@ +import hashlib +import os + +from embedchain.helpers.json_serializable import register_deserializable +from embedchain.loaders.base_loader import BaseLoader + + +@register_deserializable +class TextFileLoader(BaseLoader): + def load_data(self, url: str): + """Load data from a text file located at a local path.""" + if not os.path.exists(url): + raise FileNotFoundError(f"The file at {url} does not exist.") + + with open(url, "r", encoding="utf-8") as file: + content = file.read() + + doc_id = hashlib.sha256((content + url).encode()).hexdigest() + + meta_data = {"url": url, "file_size": os.path.getsize(url), "file_type": url.split(".")[-1]} + + return { + "doc_id": doc_id, + "data": [ + { + "content": content, + "meta_data": meta_data, + } + ], + } diff --git a/embedchain/models/data_type.py b/embedchain/models/data_type.py index 93ed71f4..3ee68c84 100644 --- a/embedchain/models/data_type.py +++ b/embedchain/models/data_type.py @@ -37,6 +37,7 @@ class IndirectDataType(Enum): BEEHIIV = "beehiiv" DIRECTORY = "directory" SLACK = "slack" + TEXT_FILE = "text_file" class SpecialDataType(Enum): @@ -73,3 +74,4 @@ class DataType(Enum): BEEHIIV = IndirectDataType.BEEHIIV.value DIRECTORY = IndirectDataType.DIRECTORY.value SLACK = IndirectDataType.SLACK.value + TEXT_FILE = IndirectDataType.TEXT_FILE.value diff --git a/embedchain/utils.py b/embedchain/utils.py index 0ebb2349..2cfe57a5 100644 --- a/embedchain/utils.py +++ b/embedchain/utils.py @@ -305,7 +305,7 @@ def detect_datatype(source: Any) -> DataType: if source.endswith(".txt"): logging.debug(f"Source of `{formatted_source}` detected as `text`.") - return DataType.TEXT + return DataType.TEXT_FILE if source.endswith(".pdf"): logging.debug(f"Source of `{formatted_source}` detected as `pdf_file`.") @@ -331,6 +331,10 @@ def detect_datatype(source: Any) -> DataType: logging.debug(f"Source of `{formatted_source}` detected as `json`.") return DataType.JSON + if os.path.exists(source) and is_readable(open(source).read()): + logging.debug(f"Source of `{formatted_source}` detected as `text_file`.") + return DataType.TEXT_FILE + # If the source is a valid file, that's not detectable as a type, an error is raised. # It does not fallback to text. raise ValueError( diff --git a/tests/embedchain/test_utils.py b/tests/embedchain/test_utils.py index 18418acb..80365abd 100644 --- a/tests/embedchain/test_utils.py +++ b/tests/embedchain/test_utils.py @@ -85,10 +85,10 @@ class TestApp(unittest.TestCase): detect_datatype(["foo", "bar"]) @patch("os.path.isfile") - def test_detect_datatype_regular_filesystem_file_not_detected(self, mock_isfile): + def test_detect_datatype_regular_filesystem_file_txt(self, mock_isfile): with tempfile.NamedTemporaryFile(suffix=".txt", delete=True) as tmp: mock_isfile.return_value = True - self.assertEqual(detect_datatype(tmp.name), DataType.TEXT) + self.assertEqual(detect_datatype(tmp.name), DataType.TEXT_FILE) def test_detect_datatype_regular_filesystem_no_file(self): """Test that if a filepath is not actually an existing file, it is not handled as a file path."""