[Improvement] improve github loader (#984)
@@ -64,7 +64,6 @@ class DataFormatter(JSONSerializable):
             DataType.GMAIL: "embedchain.loaders.gmail.GmailLoader",
             DataType.NOTION: "embedchain.loaders.notion.NotionLoader",
             DataType.SUBSTACK: "embedchain.loaders.substack.SubstackLoader",
-            DataType.GITHUB: "embedchain.loaders.github.GithubLoader",
             DataType.YOUTUBE_CHANNEL: "embedchain.loaders.youtube_channel.YoutubeChannelLoader",
             DataType.DISCORD: "embedchain.loaders.discord.DiscordLoader",
         }
@@ -75,6 +74,7 @@ class DataFormatter(JSONSerializable):
                 DataType.MYSQL,
                 DataType.SLACK,
                 DataType.DISCOURSE,
+                DataType.GITHUB,
             ]
         )
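With DataType.GITHUB removed from the lazy loader map and added to the config-required set, the loader can no longer be constructed implicitly from a bare data type; callers now build a GithubLoader themselves and hand it to the app. A minimal sketch of that call pattern, assuming App.add accepts a loader argument as the embedchain docs describe (the token and repo name are placeholders):

    from embedchain import App
    from embedchain.loaders.github import GithubLoader

    app = App()
    loader = GithubLoader(config={"token": "ghp_xxxx"})  # placeholder personal access token
    app.add("repo:owner/name type:repo", data_type="github", loader=loader)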
@@ -2,116 +2,297 @@ import concurrent.futures
 import hashlib
 import logging
 import os
+import re
+import shlex
+from typing import Any, Dict, Optional

 from tqdm import tqdm

 from embedchain.loaders.base_loader import BaseLoader
-from embedchain.loaders.json import JSONLoader
-from embedchain.loaders.mdx import MdxLoader
-from embedchain.utils import detect_datatype
+from embedchain.utils import clean_string
+
+GITHUB_URL = "https://github.com"
+GITHUB_API_URL = "https://api.github.com"
-def _load_file_data(path):
-    data = []
-    data_content = []
-    try:
-        with open(path, "rb") as f:
-            content = f.read().decode("utf-8")
-    except Exception as e:
-        print(f"Error reading file {path}: {e}")
-        raise ValueError(f"Failed to read file {path}")
-
-    meta_data = {}
-    meta_data["url"] = path
-    data.append(
-        {
-            "content": content,
-            "meta_data": meta_data,
-        }
-    )
-    data_content.append(content)
-    doc_id = hashlib.sha256((" ".join(data_content) + path).encode()).hexdigest()
-    return {
-        "doc_id": doc_id,
-        "data": data,
-    }
+VALID_SEARCH_TYPES = {"code", "repo", "pr", "issue", "discussion"}
 class GithubLoader(BaseLoader):
-    def load_data(self, repo_url):
-        """Load data from a git repo."""
+    """Load data from github search query."""
+
+    def __init__(self, config: Optional[Dict[str, Any]] = None):
+        super().__init__()
+        if not config:
+            raise ValueError(
+                "GithubLoader requires a personal access token to use github api. Check - `https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens#creating-a-personal-access-token-classic`"  # noqa: E501
+            )
+
         try:
-            from git import Repo
+            from github import Github
         except ImportError as e:
             raise ValueError(
-                "GithubLoader requires extra dependencies. Install with `pip install --upgrade 'embedchain[git]'`"
+                "GithubLoader requires extra dependencies. Install with `pip install --upgrade 'embedchain[github]'`"
             ) from e

-        mdx_loader = MdxLoader()
-        json_loader = JSONLoader()
-        data = []
-        data_urls = []
+        self.config = config
+        token = config.get("token")
+        if not token:
+            raise ValueError(
+                "GithubLoader requires a personal access token to use github api. Check - `https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens#creating-a-personal-access-token-classic`"  # noqa: E501
+            )
+
+        try:
+            self.client = Github(token)
+        except Exception as e:
+            logging.error(f"GithubLoader failed to initialize client: {e}")
+            self.client = None
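Note that PyGithub's Github(token) does not contact the API at construction, so a bad token will not land in this except block; it surfaces on the first search call instead. A hedged sketch of an eager check, should one be wanted (the token is a placeholder):

    from github import Github, GithubException

    client = Github("ghp_xxxx")  # placeholder token; no network call happens here
    try:
        client.get_user().login  # first real request; raises on bad credentials
    except GithubException as e:
        raise ValueError(f"GitHub token validation failed: {e}") from e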
+    def _github_search_code(self, query: str):
+        """Search github code."""
+        data = []
+        results = self.client.search_code(query)
+        for result in tqdm(results, total=results.totalCount, desc="Loading code files from github"):
+            url = result.html_url
+            logging.info(f"Added data from url: {url}")
+            content = result.decoded_content.decode("utf-8")
+            metadata = {
+                "url": url,
+            }
+            data.append(
+                {
+                    "content": clean_string(content),
+                    "meta_data": metadata,
+                }
+            )
+        return data
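search_code expects GitHub's code-search qualifier syntax, the same syntax load_data forwards after stripping the type: terms. A standalone sketch of an equivalent direct call (owner/name and the token are placeholders):

    from github import Github

    client = Github("ghp_xxxx")  # placeholder token
    results = client.search_code("repo:owner/name language:python http_status")
    for result in results[:5]:
        print(result.html_url)  # same field the loader stores as meta_data["url"]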
+    def _get_github_repo_data(self, repo_url: str):
+        local_hash = hashlib.sha256(repo_url.encode()).hexdigest()
+        local_path = f"/tmp/{local_hash}"
+        data = []
+
+        def _get_repo_tree(repo_url: str, local_path: str):
+            try:
+                from git import Repo
+            except ImportError as e:
+                raise ValueError(
+                    "GithubLoader requires extra dependencies. Install with `pip install --upgrade 'embedchain[github]'`"  # noqa: E501
+                ) from e
+
-        def _fetch_or_clone_repo(repo_url: str, local_path: str):
             if os.path.exists(local_path):
                 logging.info("Repository already exists. Fetching updates...")
                 repo = Repo(local_path)
                 origin = repo.remotes.origin
                 origin.fetch()
                 logging.info("Fetch completed.")
             else:
                 logging.info("Cloning repository...")
-                Repo.clone_from(repo_url, local_path)
+                repo = Repo.clone_from(repo_url, local_path)
                 logging.info("Clone completed.")
+            return repo.head.commit.tree

-        def _load_file(file_path: str):
-            try:
-                data_type = detect_datatype(file_path).value
-            except Exception:
-                data_type = "unstructured"
-
-            if data_type == "mdx":
-                data = mdx_loader.load_data(file_path)
-            elif data_type == "json":
-                data = json_loader.load_data(file_path)
-            else:
-                data = _load_file_data(file_path)
-
-            return data.get("data", [])
-
-        def _is_file_empty(file_path):
-            return os.path.getsize(file_path) == 0
-
-        def _is_whitelisted(file_path):
-            whitelisted_extensions = ["md", "txt", "html", "json", "py", "js", "jsx", "ts", "tsx", "mdx", "rst"]
-            _, file_extension = os.path.splitext(file_path)
-            return file_extension[1:] in whitelisted_extensions
-
-        def _add_repo_files(repo_path: str):
-            with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
-                future_to_file = {
-                    executor.submit(_load_file, os.path.join(root, filename)): os.path.join(root, filename)
-                    for root, _, files in os.walk(repo_path)
-                    for filename in files
-                    if _is_whitelisted(os.path.join(root, filename))
-                    and not _is_file_empty(os.path.join(root, filename))  # noqa:E501
-                }
-                for future in tqdm(concurrent.futures.as_completed(future_to_file), total=len(future_to_file)):
-                    file = future_to_file[future]
-                    try:
-                        results = future.result()
-                        if results:
-                            data.extend(results)
-                            data_urls.extend([result.get("meta_data").get("url") for result in results])
-                    except Exception as e:
-                        logging.warning(f"Failed to process {file}: {e}")
+        def _get_repo_tree_contents(repo_path, tree, progress_bar):
+            for subtree in tree:
+                if subtree.type == "tree":
+                    _get_repo_tree_contents(repo_path, subtree, progress_bar)
+                else:
+                    assert subtree.type == "blob"
+                    try:
+                        contents = subtree.data_stream.read().decode("utf-8")
+                    except Exception:
+                        logging.warning(f"Failed to read file: {subtree.path}")
+                        if progress_bar is not None:
+                            progress_bar.update(1)
+                        continue
+
+                    # strip a trailing ".git" only; str.rstrip(".git") would eat any
+                    # trailing g/i/t/. characters from the repo name
+                    repo_web_url = re.sub(r"\.git$", "", repo_url)
+                    url = f"{repo_web_url}/blob/main/{subtree.path}"
+                    data.append(
+                        {
+                            "content": clean_string(contents),
+                            "meta_data": {
+                                "url": url,
+                            },
+                        }
+                    )
+                    if progress_bar is not None:
+                        progress_bar.update(1)
+
+        repo_tree = _get_repo_tree(repo_url, local_path)
+        tree_list = list(repo_tree.traverse())
+        with tqdm(total=len(tree_list), desc="Loading files:", unit="item") as progress_bar:
+            _get_repo_tree_contents(local_path, repo_tree, progress_bar)
+
+        return data
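Reading blobs via data_stream pulls file contents straight from git's object database, so no working-tree walk is needed. A self-contained sketch of the same GitPython pattern outside the loader (clone URL and path are placeholders):

    from git import Repo

    # mirrors _get_repo_tree + _get_repo_tree_contents above
    repo = Repo.clone_from("https://github.com/owner/name.git", "/tmp/example-clone")
    for item in repo.head.commit.tree.traverse():
        if item.type == "blob":  # traverse() yields both subtrees and blobs
            text = item.data_stream.read().decode("utf-8", errors="ignore")
            print(item.path, len(text))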
+    def _github_search_repo(self, query: str):
+        """Search github repo."""
+        data = []
+        logging.info(f"Searching github repos with query: {query}")
+        results = self.client.search_repositories(query)
+        # Add repo urls and descriptions
+        urls = list(map(lambda x: x.html_url, results))
+        descriptions = list(map(lambda x: x.description, results))
+        data.extend(
+            {
+                "content": clean_string(desc or ""),
+                "meta_data": {
+                    "url": url,
+                },
+            }
+            for url, desc in zip(urls, descriptions)
+        )
+
+        # Add repo contents
+        for result in results:
+            clone_url = result.clone_url
+            logging.info(f"Cloning repository: {clone_url}")
+            data.extend(self._get_github_repo_data(clone_url))
+        return data
+    def _github_search_issues_and_pr(self, query: str, type: str):
+        """Search github issues and PRs."""
+        data = []
+
+        query = f"{query} is:{type}"
+        logging.info(f"Searching github for query: {query}")
+
+        results = self.client.search_issues(query)
+
+        logging.info(f"Total results: {results.totalCount}")
+        for result in tqdm(results, total=results.totalCount, desc=f"Loading {type} from github"):
+            url = result.html_url
+            title = result.title
+            body = result.body
+            if not body:
+                logging.warning(f"Skipping issue because empty content for: {url}")
+                continue
+            labels = " ".join([label.name for label in result.labels])
+            issue_comments = result.get_comments()
+            comments = []
+            comments_created_at = []
+            for comment in issue_comments:
+                comments_created_at.append(str(comment.created_at))
+                comments.append(f"{comment.user.name}:{comment.body}")
+            content = "\n".join([title, labels, body, *comments])
+            metadata = {
+                "url": url,
+                "created_at": str(result.created_at),
+                "comments_created_at": " ".join(comments_created_at),
+            }
+            data.append(
+                {
+                    "content": clean_string(content),
+                    "meta_data": metadata,
+                }
+            )
+        return data
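search_issues takes GitHub's issue-search qualifiers; the method above appends is:issue or is:pr itself, so the user query only needs scoping terms such as repo: or label:. A direct-call sketch (the token and repo are placeholders):

    from github import Github

    client = Github("ghp_xxxx")  # placeholder token
    results = client.search_issues("repo:owner/name label:bug is:issue")
    print(results.totalCount)  # same count the loader logs before iterating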
+    # need to test more for discussion
+    def _github_search_discussions(self, query: str):
+        """Search github discussions."""
+        data = []
+
+        query = f"{query} is:discussion"
+        logging.info(f"Searching github repo for query: {query}")
+        repos_results = self.client.search_repositories(query)
+        logging.info(f"Total repos found: {repos_results.totalCount}")
+        for repo_result in tqdm(repos_results, total=repos_results.totalCount, desc="Loading discussions from github"):
+            teams = repo_result.get_teams()
+            for team in teams:
+                team_discussions = team.get_discussions()
+                for discussion in team_discussions:
+                    url = discussion.html_url
+                    title = discussion.title
+                    body = discussion.body
+                    if not body:
+                        logging.warning(f"Skipping discussion because empty content for: {url}")
+                        continue
+                    # comment loading is not implemented yet; kept empty on purpose
+                    comments = []
+                    comments_created_at = []
+                    logging.debug(f"Discussion comments url: {discussion.comments_url}")
+                    content = "\n".join([title, body, *comments])
+                    metadata = {
+                        "url": url,
+                        "created_at": str(discussion.created_at),
+                        "comments_created_at": " ".join(comments_created_at),
+                    }
+                    data.append(
+                        {
+                            "content": clean_string(content),
+                            "meta_data": metadata,
+                        }
+                    )
+        return data
+    def _search_github_data(self, search_type: str, query: str):
+        """Search github data."""
+        if search_type == "code":
+            data = self._github_search_code(query)
+        elif search_type == "repo":
+            data = self._github_search_repo(query)
+        elif search_type == "issue":
+            data = self._github_search_issues_and_pr(query, search_type)
+        elif search_type == "pr":
+            data = self._github_search_issues_and_pr(query, search_type)
+        elif search_type == "discussion":
+            raise ValueError("GithubLoader does not support searching discussions yet.")
+        else:
+            # guard against an unhandled type leaving `data` unbound
+            raise ValueError(f"Invalid search type: {search_type}")
+
+        return data
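The if/elif ladder could equally be a dispatch table, which keeps the valid search types and their handlers in one place; a sketch under the assumption such a refactor is wanted (make_dispatch is a hypothetical helper, not part of this commit):

    from typing import Callable, Dict, List

    def make_dispatch(loader) -> Dict[str, Callable[[str], List[dict]]]:
        # maps each search type to the corresponding loader method
        return {
            "code": loader._github_search_code,
            "repo": loader._github_search_repo,
            "issue": lambda q: loader._github_search_issues_and_pr(q, "issue"),
            "pr": lambda q: loader._github_search_issues_and_pr(q, "pr"),
        }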
+    def _get_valid_github_query(self, query: str):
+        """Check if query is valid and return search types and valid github query."""
+        query_terms = shlex.split(query)
+        # query must provide repo to load data from
+        if len(query_terms) < 1 or "repo:" not in query:
+            raise ValueError(
+                "GithubLoader requires a search query with `repo:` term. Refer docs - `https://docs.embedchain.ai/data-sources/github`"  # noqa: E501
+            )
+
+        github_query = []
+        types = set()
+        type_pattern = r"type:([a-zA-Z,]+)"
+        for term in query_terms:
+            term_match = re.search(type_pattern, term)
+            if term_match:
+                search_types = term_match.group(1).split(",")
+                types.update(search_types)
+            else:
+                github_query.append(term)
+
+        # query must provide search type
+        if len(types) == 0:
+            raise ValueError(
+                "GithubLoader requires a search query with `type:` term. Refer docs - `https://docs.embedchain.ai/data-sources/github`"  # noqa: E501
+            )
+
+        # validate all collected types, not just those from the last `type:` term
+        for search_type in types:
+            if search_type not in VALID_SEARCH_TYPES:
+                raise ValueError(
+                    f"Invalid search type: {search_type}. Valid types are: {', '.join(VALID_SEARCH_TYPES)}"
+                )
+
+        query = " ".join(github_query)
+
+        return types, query
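To make the parsing concrete, here is a standalone trace of the same shlex/regex logic on a sample query (owner/name is a placeholder):

    import re
    import shlex

    terms = shlex.split('repo:owner/name type:code,issue "vector store"')
    types, github_query = set(), []
    for term in terms:
        match = re.search(r"type:([a-zA-Z,]+)", term)
        if match:
            types.update(match.group(1).split(","))
        else:
            github_query.append(term)
    print(types)                   # {'code', 'issue'}
    print(" ".join(github_query))  # repo:owner/name vector store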
+    def load_data(self, search_query: str, max_results: int = 1000):
+        """Load data from github search query."""
+        if not self.client:
+            raise ValueError(
+                "GithubLoader client is not initialized, data will not be loaded. Refer docs - `https://docs.embedchain.ai/data-sources/github`"  # noqa: E501
+            )
+
+        search_types, query = self._get_valid_github_query(search_query)
+        logging.info(f"Searching github for query: {query}, with types: {', '.join(search_types)}")
+
+        data = []
+        with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+            futures_map = executor.map(self._search_github_data, search_types, [query] * len(search_types))
+            for search_data in tqdm(futures_map, total=len(search_types), desc="Searching data from github"):
+                data.extend(search_data)
+
-        source_hash = hashlib.sha256(repo_url.encode()).hexdigest()
-        repo_path = f"/tmp/{source_hash}"
-        _fetch_or_clone_repo(repo_url=repo_url, local_path=repo_path)
-        _add_repo_files(repo_path)
-        doc_id = hashlib.sha256((repo_url + ", ".join(data_urls)).encode()).hexdigest()
         return {
-            "doc_id": doc_id,
+            "doc_id": hashlib.sha256(query.encode()).hexdigest(),
             "data": data,
         }
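Taken together, a minimal end-to-end sketch of the new interface (token and repo are placeholders; the output shape follows the return statement above):

    from embedchain.loaders.github import GithubLoader

    loader = GithubLoader(config={"token": "ghp_xxxx"})  # placeholder personal access token
    result = loader.load_data("repo:owner/name type:repo,issue")
    print(result["doc_id"], len(result["data"]))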