import concurrent.futures
import hashlib
import logging
import os

from tqdm import tqdm

from embedchain.loaders.base_loader import BaseLoader
from embedchain.loaders.json import JSONLoader
from embedchain.loaders.mdx import MdxLoader
from embedchain.utils import detect_datatype


def _load_file_data(path):
    """Read a file as UTF-8 text and wrap it in the loader payload format."""
    data = []
    data_content = []
    try:
        with open(path, "rb") as f:
            content = f.read().decode("utf-8")
    except Exception as e:
        logging.error(f"Error reading file {path}: {e}")
        raise ValueError(f"Failed to read file {path}") from e

    meta_data = {}
    meta_data["url"] = path
    data.append(
        {
            "content": content,
            "meta_data": meta_data,
        }
    )
    data_content.append(content)
    # Hash the content together with the path so identical files at
    # different locations still produce distinct doc ids.
    doc_id = hashlib.sha256((" ".join(data_content) + path).encode()).hexdigest()
    return {
        "doc_id": doc_id,
        "data": data,
    }
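
# Shape of the payload returned above, for reference (values are illustrative
# placeholders, not output from a real run):
#
#   {
#       "doc_id": "<64-char sha256 hex digest>",
#       "data": [
#           {"content": "<decoded file text>", "meta_data": {"url": "<file path>"}},
#       ],
#   }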


class GithubLoader(BaseLoader):
    def load_data(self, repo_url):
        """Load data from a git repo."""
        try:
            from git import Repo
        except ImportError as e:
            raise ValueError(
                "GithubLoader requires extra dependencies. Install with `pip install --upgrade 'embedchain[git]'`"
            ) from e

        mdx_loader = MdxLoader()
        json_loader = JSONLoader()
        data = []
        data_urls = []

        def _fetch_or_clone_repo(repo_url: str, local_path: str):
            if os.path.exists(local_path):
                logging.info("Repository already exists. Fetching updates...")
                repo = Repo(local_path)
                origin = repo.remotes.origin
                # Note: fetch only updates remote-tracking refs; it does not
                # modify the checked-out working tree that is read below.
                origin.fetch()
                logging.info("Fetch completed.")
            else:
                logging.info("Cloning repository...")
                Repo.clone_from(repo_url, local_path)
                logging.info("Clone completed.")

        def _load_file(file_path: str):
            try:
                data_type = detect_datatype(file_path).value
            except Exception:
                data_type = "unstructured"

            # Dispatch to a specialized loader where one exists; everything
            # else is read as plain UTF-8 text.
            if data_type == "mdx":
                data = mdx_loader.load_data(file_path)
            elif data_type == "json":
                data = json_loader.load_data(file_path)
            else:
                data = _load_file_data(file_path)

            return data.get("data", [])

        def _is_file_empty(file_path):
            return os.path.getsize(file_path) == 0

        def _is_whitelisted(file_path):
            whitelisted_extensions = ["md", "txt", "html", "json", "py", "js", "jsx", "ts", "tsx", "mdx", "rst"]
            _, file_extension = os.path.splitext(file_path)
            return file_extension[1:] in whitelisted_extensions

        def _add_repo_files(repo_path: str):
            with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
                # Schedule one load per whitelisted, non-empty file in the
                # checked-out tree.
                future_to_file = {
                    executor.submit(_load_file, os.path.join(root, filename)): os.path.join(root, filename)
                    for root, _, files in os.walk(repo_path)
                    for filename in files
                    if _is_whitelisted(os.path.join(root, filename))
                    and not _is_file_empty(os.path.join(root, filename))  # noqa:E501
                }
                for future in tqdm(concurrent.futures.as_completed(future_to_file), total=len(future_to_file)):
                    file = future_to_file[future]
                    try:
                        results = future.result()
                        if results:
                            data.extend(results)
                            data_urls.extend([result.get("meta_data").get("url") for result in results])
                    except Exception as e:
                        logging.warning(f"Failed to process {file}: {e}")

        # Cache the checkout under /tmp, keyed by a hash of the repo URL, so
        # repeated loads of the same repo reuse the existing clone.
        source_hash = hashlib.sha256(repo_url.encode()).hexdigest()
        repo_path = f"/tmp/{source_hash}"
        _fetch_or_clone_repo(repo_url=repo_url, local_path=repo_path)
        _add_repo_files(repo_path)
        doc_id = hashlib.sha256((repo_url + ", ".join(data_urls)).encode()).hexdigest()
        return {
            "doc_id": doc_id,
            "data": data,
        }
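

# Minimal usage sketch (illustrative only; the repo URL below is a placeholder
# assumption). load_data clones or refreshes the repo under /tmp and returns a
# dict with a stable doc_id plus one entry per whitelisted file.
if __name__ == "__main__":
    loader = GithubLoader()
    result = loader.load_data("https://github.com/owner/repo")
    print(result["doc_id"], f"{len(result['data'])} files loaded")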