[Features] Add Github and Youtube Channel loaders (#957)
Co-authored-by: Deven Patel <deven298@yahoo.com> Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
This commit is contained in:
81
embedchain/loaders/github.py
Normal file
81
embedchain/loaders/github.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import concurrent.futures
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.loaders.json import JSONLoader
|
||||
from embedchain.loaders.mdx import MdxLoader
|
||||
from embedchain.loaders.unstructured_file import UnstructuredLoader
|
||||
from embedchain.utils import detect_datatype
|
||||
|
||||
|
||||
class GithubLoader(BaseLoader):
    """Load documents from every file of a Github (git) repository."""

    def load_data(self, repo_url):
        """Load data from a git repo.

        Clones the repository into a per-URL temp directory (or updates it if
        already cloned), walks its files, and loads each one with a loader
        chosen by detected datatype (mdx / json / unstructured fallback).

        :param repo_url: URL of the git repository to load.
        :return: dict with a deterministic ``doc_id`` (hash of the URL plus
            all loaded document URLs) and the aggregated ``data`` list.
        :raises ValueError: if GitPython is not installed.
        """
        import tempfile  # local: only needed here, keeps top-level imports untouched

        try:
            from git import Repo
        except ImportError as e:
            raise ValueError(
                "GithubLoader requires extra dependencies. Install with `pip install --upgrade 'embedchain[git]'`"
            ) from e

        mdx_loader = MdxLoader()
        json_loader = JSONLoader()
        unstructured_loader = UnstructuredLoader()
        data = []
        data_urls = []

        def _fetch_or_clone_repo(repo_url: str, local_path: str):
            """Clone the repo, or update the existing checkout in place."""
            if os.path.exists(local_path):
                logging.info("Repository already exists. Fetching updates...")
                repo = Repo(local_path)
                origin = repo.remotes.origin
                origin.fetch()
                # fetch() alone never touches the working tree, so the files
                # on disk would stay stale; pull to fast-forward the checkout.
                try:
                    origin.pull()
                except Exception as e:
                    logging.warning(f"Could not update working tree at {local_path}: {e}")
                logging.info("Fetch completed.")
            else:
                logging.info("Cloning repository...")
                Repo.clone_from(repo_url, local_path)
                logging.info("Clone completed.")

        def _load_file(file_path: str):
            """Load one file with the datatype-specific loader; fall back to unstructured."""
            try:
                data_type = detect_datatype(file_path).value
            except Exception:
                data_type = "unstructured"

            if data_type == "mdx":
                result = mdx_loader.load_data(file_path)
            elif data_type == "json":
                result = json_loader.load_data(file_path)
            else:
                result = unstructured_loader.load_data(file_path)

            return result.get("data", [])

        def _add_repo_files(repo_path: str):
            """Load every content file of the checkout concurrently into `data`."""
            with concurrent.futures.ThreadPoolExecutor() as executor:
                future_to_file = {}
                for root, dirs, files in os.walk(repo_path):
                    # Prune git internals: .git holds packed objects/refs,
                    # not repository content, and would pollute the index.
                    dirs[:] = [d for d in dirs if d != ".git"]
                    for filename in files:
                        file_path = os.path.join(root, filename)
                        future_to_file[executor.submit(_load_file, file_path)] = file_path
                for future in concurrent.futures.as_completed(future_to_file):
                    file = future_to_file[future]
                    try:
                        results = future.result()
                        if results:
                            data.extend(results)
                            for result in results:
                                # Guard: a loader may omit meta_data or url;
                                # a None here would crash the doc_id join below.
                                url = (result.get("meta_data") or {}).get("url")
                                if url:
                                    data_urls.append(url)
                    except Exception as e:
                        logging.error(f"Failed to process {file}: {e}")

        source_hash = hashlib.sha256(repo_url.encode()).hexdigest()
        # Platform temp dir instead of hard-coded /tmp (portable to Windows).
        repo_path = os.path.join(tempfile.gettempdir(), source_hash)
        _fetch_or_clone_repo(repo_url=repo_url, local_path=repo_path)
        _add_repo_files(repo_path)
        doc_id = hashlib.sha256((repo_url + ", ".join(data_urls)).encode()).hexdigest()
        return {
            "doc_id": doc_id,
            "data": data,
        }
|
||||
Reference in New Issue
Block a user