[Feature] Improve GitHub loader (#962)

This commit is contained in:
Deven Patel
2023-11-16 22:06:36 -08:00
committed by GitHub
parent e0b73e6a5a
commit 023a61446f
2 changed files with 22 additions and 2 deletions

View File

@@ -3,6 +3,8 @@ import hashlib
import logging
import os
from tqdm import tqdm
from embedchain.loaders.base_loader import BaseLoader
from embedchain.loaders.json import JSONLoader
from embedchain.loaders.mdx import MdxLoader
@@ -53,14 +55,24 @@ class GithubLoader(BaseLoader):
return data.get("data", [])
def _is_file_empty(file_path):
    """Tell whether the file at ``file_path`` holds zero bytes.

    Args:
        file_path: Path of the file to inspect.

    Returns:
        True when the file's size is exactly 0 bytes, otherwise False.

    Raises:
        OSError: If the file does not exist or cannot be stat'ed.
    """
    byte_count = os.stat(file_path).st_size
    return not byte_count
def _is_whitelisted(file_path):
    """Return True when ``file_path`` has an extension the loader ingests.

    The extension is compared case-insensitively, so ``README.MD`` is
    accepted just like ``README.md``.

    Args:
        file_path: Path (or bare file name) to classify.

    Returns:
        True if the final extension, without its leading dot, is one of the
        whitelisted extensions; False otherwise. Names without an extension
        (``Makefile``), names ending in a bare dot, and dotfiles such as
        ``.md`` (which ``os.path.splitext`` treats as having no extension)
        are rejected.
    """
    whitelisted_extensions = ("md", "txt", "html", "json", "py", "js", "jsx", "ts", "tsx", "mdx", "rst")
    _, file_extension = os.path.splitext(file_path)
    # splitext yields either "" or ".ext": drop the dot and normalise the case
    # so upper- or mixed-case extensions are not silently skipped.
    return file_extension[1:].lower() in whitelisted_extensions
def _add_repo_files(repo_path: str):
with concurrent.futures.ThreadPoolExecutor() as executor:
future_to_file = {
executor.submit(_load_file, os.path.join(root, filename)): os.path.join(root, filename)
for root, _, files in os.walk(repo_path)
for filename in files
} # noqa: E501
for future in concurrent.futures.as_completed(future_to_file):
if _is_whitelisted(os.path.join(root, filename))
and not _is_file_empty(os.path.join(root, filename)) # noqa:E501
}
for future in tqdm(concurrent.futures.as_completed(future_to_file), total=len(future_to_file)):
file = future_to_file[future]
try:
results = future.result()

View File

@@ -216,6 +216,10 @@ def detect_datatype(source: Any) -> DataType:
logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
return DataType.CSV
if url.path.endswith(".mdx") or url.path.endswith(".md"):
logging.debug(f"Source of `{formatted_source}` detected as `mdx`.")
return DataType.MDX
if url.path.endswith(".docx"):
logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
return DataType.DOCX
@@ -292,6 +296,10 @@ def detect_datatype(source: Any) -> DataType:
logging.debug(f"Source of `{formatted_source}` detected as `xml`.")
return DataType.XML
if source.endswith(".mdx") or source.endswith(".md"):
logging.debug(f"Source of `{formatted_source}` detected as `mdx`.")
return DataType.MDX
if source.endswith(".yaml"):
with open(source, "r") as file:
yaml_content = yaml.safe_load(file)