[Feature] Improve GitHub loader (#962)
This commit is contained in:
@@ -3,6 +3,8 @@ import hashlib
|
||||
import logging
|
||||
import os
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.loaders.json import JSONLoader
|
||||
from embedchain.loaders.mdx import MdxLoader
|
||||
@@ -53,14 +55,24 @@ class GithubLoader(BaseLoader):
|
||||
|
||||
return data.get("data", [])
|
||||
|
||||
def _is_file_empty(file_path):
|
||||
return os.path.getsize(file_path) == 0
|
||||
|
||||
def _is_whitelisted(file_path):
|
||||
whitelisted_extensions = ["md", "txt", "html", "json", "py", "js", "jsx", "ts", "tsx", "mdx", "rst"]
|
||||
_, file_extension = os.path.splitext(file_path)
|
||||
return file_extension[1:] in whitelisted_extensions
|
||||
|
||||
def _add_repo_files(repo_path: str):
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future_to_file = {
|
||||
executor.submit(_load_file, os.path.join(root, filename)): os.path.join(root, filename)
|
||||
for root, _, files in os.walk(repo_path)
|
||||
for filename in files
|
||||
} # noqa: E501
|
||||
for future in concurrent.futures.as_completed(future_to_file):
|
||||
if _is_whitelisted(os.path.join(root, filename))
|
||||
and not _is_file_empty(os.path.join(root, filename)) # noqa:E501
|
||||
}
|
||||
for future in tqdm(concurrent.futures.as_completed(future_to_file), total=len(future_to_file)):
|
||||
file = future_to_file[future]
|
||||
try:
|
||||
results = future.result()
|
||||
|
||||
@@ -216,6 +216,10 @@ def detect_datatype(source: Any) -> DataType:
|
||||
logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
|
||||
return DataType.CSV
|
||||
|
||||
if url.path.endswith(".mdx") or url.path.endswith(".md"):
|
||||
logging.debug(f"Source of `{formatted_source}` detected as `mdx`.")
|
||||
return DataType.MDX
|
||||
|
||||
if url.path.endswith(".docx"):
|
||||
logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
|
||||
return DataType.DOCX
|
||||
@@ -292,6 +296,10 @@ def detect_datatype(source: Any) -> DataType:
|
||||
logging.debug(f"Source of `{formatted_source}` detected as `xml`.")
|
||||
return DataType.XML
|
||||
|
||||
if source.endswith(".mdx") or source.endswith(".md"):
|
||||
logging.debug(f"Source of `{formatted_source}` detected as `mdx`.")
|
||||
return DataType.MDX
|
||||
|
||||
if source.endswith(".yaml"):
|
||||
with open(source, "r") as file:
|
||||
yaml_content = yaml.safe_load(file)
|
||||
|
||||
Reference in New Issue
Block a user