[Features] Add Github and Youtube Channel loaders (#957)
Co-authored-by: Deven Patel <deven298@yahoo.com> Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
import hashlib
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helper.json_serializable import JSONSerializable
|
||||
from embedchain.models.data_type import DataType
|
||||
|
||||
@@ -7,7 +10,15 @@ from embedchain.models.data_type import DataType
|
||||
class BaseChunker(JSONSerializable):
|
||||
def __init__(self, text_splitter):
|
||||
"""Initialize the chunker."""
|
||||
self.text_splitter = text_splitter
|
||||
if text_splitter is None:
|
||||
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
|
||||
self.text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
else:
|
||||
self.text_splitter = text_splitter
|
||||
self.data_type = None
|
||||
|
||||
def create_chunks(self, loader, src, app_id=None):
|
||||
|
||||
@@ -64,6 +64,8 @@ class DataFormatter(JSONSerializable):
|
||||
DataType.GMAIL: "embedchain.loaders.gmail.GmailLoader",
|
||||
DataType.NOTION: "embedchain.loaders.notion.NotionLoader",
|
||||
DataType.SUBSTACK: "embedchain.loaders.substack.SubstackLoader",
|
||||
DataType.GITHUB: "embedchain.loaders.github.GithubLoader",
|
||||
DataType.YOUTUBE_CHANNEL: "embedchain.loaders.youtube_channel.YoutubeChannelLoader",
|
||||
}
|
||||
|
||||
custom_loaders = set(
|
||||
@@ -114,6 +116,8 @@ class DataFormatter(JSONSerializable):
|
||||
DataType.SLACK: "embedchain.chunkers.slack.SlackChunker",
|
||||
DataType.DISCOURSE: "embedchain.chunkers.discourse.DiscourseChunker",
|
||||
DataType.SUBSTACK: "embedchain.chunkers.substack.SubstackChunker",
|
||||
DataType.GITHUB: "embedchain.chunkers.base_chunker.BaseChunker",
|
||||
DataType.YOUTUBE_CHANNEL: "embedchain.chunkers.base_chunker.BaseChunker",
|
||||
}
|
||||
|
||||
if data_type in chunker_classes:
|
||||
|
||||
81
embedchain/loaders/github.py
Normal file
81
embedchain/loaders/github.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import concurrent.futures
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.loaders.json import JSONLoader
|
||||
from embedchain.loaders.mdx import MdxLoader
|
||||
from embedchain.loaders.unstructured_file import UnstructuredLoader
|
||||
from embedchain.utils import detect_datatype
|
||||
|
||||
|
||||
class GithubLoader(BaseLoader):
    """Load every supported file of a git repository as embedchain documents."""

    def load_data(self, repo_url):
        """Load data from a git repo.

        Clones the repository (or fetches updates when a local copy already
        exists) into a per-URL directory under the platform temp directory,
        then loads each file with a loader chosen by its detected data type.

        :param repo_url: URL of the repository to clone/fetch.
        :return: dict with a content-derived ``doc_id`` and a ``data`` list of
            loader results.
        :raises ValueError: if GitPython is not installed.
        """
        import tempfile

        try:
            from git import Repo
        except ImportError as e:
            raise ValueError(
                "GithubLoader requires extra dependencies. Install with `pip install --upgrade 'embedchain[git]'`"
            ) from e

        mdx_loader = MdxLoader()
        json_loader = JSONLoader()
        unstructured_loader = UnstructuredLoader()
        # Accumulated across worker threads by the closures below.
        data = []
        data_urls = []

        def _fetch_or_clone_repo(repo_url: str, local_path: str):
            # Reuse an existing checkout; otherwise clone fresh.
            if os.path.exists(local_path):
                logging.info("Repository already exists. Fetching updates...")
                repo = Repo(local_path)
                origin = repo.remotes.origin
                # NOTE(review): fetch() only updates remote refs — it does not
                # touch the working tree, so a stale checkout keeps old file
                # contents. Consider origin.pull() if that is unintended.
                origin.fetch()
                logging.info("Fetch completed.")
            else:
                logging.info("Cloning repository...")
                Repo.clone_from(repo_url, local_path)
                logging.info("Clone completed.")

        def _load_file(file_path: str):
            # Route the file to the matching loader; unstructured is the catch-all.
            try:
                data_type = detect_datatype(file_path).value
            except Exception:
                data_type = "unstructured"

            if data_type == "mdx":
                result = mdx_loader.load_data(file_path)
            elif data_type == "json":
                result = json_loader.load_data(file_path)
            else:
                result = unstructured_loader.load_data(file_path)

            return result.get("data", [])

        def _add_repo_files(repo_path: str):
            with concurrent.futures.ThreadPoolExecutor() as executor:
                future_to_file = {}
                for root, dirs, files in os.walk(repo_path):
                    # Skip git internals — the original walked .git/ and fed
                    # binary object files to the unstructured loader.
                    if ".git" in dirs:
                        dirs.remove(".git")
                    for filename in files:
                        file_path = os.path.join(root, filename)
                        future_to_file[executor.submit(_load_file, file_path)] = file_path
                for future in concurrent.futures.as_completed(future_to_file):
                    file = future_to_file[future]
                    try:
                        results = future.result()
                        if results:
                            data.extend(results)
                            data_urls.extend([result.get("meta_data").get("url") for result in results])
                    except Exception as e:
                        logging.error(f"Failed to process {file}: {e}")

        # Deterministic, URL-derived checkout location in the platform temp
        # dir (the original hard-coded /tmp, which does not exist on Windows).
        source_hash = hashlib.sha256(repo_url.encode()).hexdigest()
        repo_path = os.path.join(tempfile.gettempdir(), source_hash)
        _fetch_or_clone_repo(repo_url=repo_url, local_path=repo_path)
        _add_repo_files(repo_path)
        # doc_id depends on both the repo URL and the set of loaded file URLs.
        doc_id = hashlib.sha256((repo_url + ", ".join(data_urls)).encode()).hexdigest()
        return {
            "doc_id": doc_id,
            "data": data,
        }
|
||||
@@ -57,8 +57,8 @@ class SitemapLoader(BaseLoader):
|
||||
try:
|
||||
data = future.result()
|
||||
if data:
|
||||
output.append(data)
|
||||
output.extend(data)
|
||||
except Exception as e:
|
||||
logging.error(f"Error loading page {link}: {e}")
|
||||
|
||||
return {"doc_id": doc_id, "data": [data[0] for data in output if data]}
|
||||
return {"doc_id": doc_id, "data": output}
|
||||
|
||||
70
embedchain/loaders/youtube_channel.py
Normal file
70
embedchain/loaders/youtube_channel.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import concurrent.futures
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.loaders.youtube_video import YoutubeVideoLoader
|
||||
|
||||
|
||||
class YoutubeChannelLoader(BaseLoader):
    """Loader for youtube channel."""

    def load_data(self, channel_name):
        """Load data for every video of a youtube channel.

        :param channel_name: channel handle/name, appended to youtube.com to
            form the channel's /videos URL.
        :return: dict with a ``doc_id`` and the combined ``data`` of all videos.
        :raises ValueError: if yt_dlp is not installed.
        """
        try:
            import yt_dlp
        except ImportError as e:
            # Fixed: the message previously named the wrong class ("YoutubeLoader").
            raise ValueError(
                "YoutubeChannelLoader requires extra dependencies. Install with `pip install --upgrade 'embedchain[youtube_channel]'`"  # noqa: E501
            ) from e

        # Accumulated across worker threads by the closures below.
        data = []
        data_urls = []
        youtube_url = f"https://www.youtube.com/{channel_name}/videos"
        youtube_video_loader = YoutubeVideoLoader()

        def _get_yt_video_links():
            try:
                ydl_opts = {
                    "quiet": True,
                    "extract_flat": True,  # list entries without downloading media
                }
                with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                    info_dict = ydl.extract_info(youtube_url, download=False)
                if "entries" in info_dict:
                    return [entry["url"] for entry in info_dict["entries"]]
            except Exception:
                logging.error(f"Failed to fetch youtube videos for channel: {channel_name}")
            # Fixed: always return a list — the original implicitly returned
            # None when extract_info succeeded without an "entries" key,
            # which crashed the iteration in _add_youtube_channel.
            return []

        def _load_yt_video(video_link):
            try:
                each_load_data = youtube_video_loader.load_data(video_link)
                if each_load_data:
                    return each_load_data.get("data")
            except Exception as e:
                logging.error(f"Failed to load youtube video {video_link}: {e}")
            return None

        def _add_youtube_channel():
            video_links = _get_yt_video_links()
            with concurrent.futures.ThreadPoolExecutor() as executor:
                future_to_video = {
                    executor.submit(_load_yt_video, video_link): video_link for video_link in video_links
                }
                for future in concurrent.futures.as_completed(future_to_video):
                    video = future_to_video[future]
                    try:
                        results = future.result()
                        if results:
                            data.extend(results)
                            data_urls.extend([result.get("meta_data").get("url") for result in results])
                    except Exception as e:
                        logging.error(f"Failed to process youtube video {video}: {e}")

        _add_youtube_channel()
        # doc_id depends on the channel URL plus every loaded video URL.
        doc_id = hashlib.sha256((youtube_url + ", ".join(data_urls)).encode()).hexdigest()
        return {
            "doc_id": doc_id,
            "data": data,
        }
|
||||
@@ -34,6 +34,8 @@ class IndirectDataType(Enum):
|
||||
SLACK = "slack"
|
||||
DISCOURSE = "discourse"
|
||||
SUBSTACK = "substack"
|
||||
GITHUB = "github"
|
||||
YOUTUBE_CHANNEL = "youtube_channel"
|
||||
|
||||
|
||||
class SpecialDataType(Enum):
|
||||
@@ -67,3 +69,5 @@ class DataType(Enum):
|
||||
SLACK = IndirectDataType.SLACK.value
|
||||
DISCOURSE = IndirectDataType.DISCOURSE.value
|
||||
SUBSTACK = IndirectDataType.SUBSTACK.value
|
||||
GITHUB = IndirectDataType.GITHUB.value
|
||||
YOUTUBE_CHANNEL = IndirectDataType.YOUTUBE_CHANNEL.value
|
||||
|
||||
@@ -255,6 +255,10 @@ def detect_datatype(source: Any) -> DataType:
|
||||
logging.debug(f"Source of `{formatted_source}` detected as `docs_site`.")
|
||||
return DataType.DOCS_SITE
|
||||
|
||||
if "github.com" in url.netloc:
|
||||
logging.debug(f"Source of `{formatted_source}` detected as `github`.")
|
||||
return DataType.GITHUB
|
||||
|
||||
# If none of the above conditions are met, it's a general web page
|
||||
logging.debug(f"Source of `{formatted_source}` detected as `web_page`.")
|
||||
return DataType.WEB_PAGE
|
||||
|
||||
Reference in New Issue
Block a user