Add folder and branch to GitHub (#1308)
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import concurrent.futures
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
from typing import Any, Optional
|
||||
@@ -14,7 +13,7 @@ from embedchain.utils.misc import clean_string
|
||||
GITHUB_URL = "https://github.com"
|
||||
GITHUB_API_URL = "https://api.github.com"
|
||||
|
||||
VALID_SEARCH_TYPES = set(["code", "repo", "pr", "issue", "discussion"])
|
||||
VALID_SEARCH_TYPES = set(["code", "repo", "pr", "issue", "discussion", "branch", "file"])
|
||||
|
||||
|
||||
class GithubLoader(BaseLoader):
|
||||
@@ -66,85 +65,56 @@ class GithubLoader(BaseLoader):
|
||||
)
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def _get_github_repo_data(repo_url: str):
|
||||
local_hash = hashlib.sha256(repo_url.encode()).hexdigest()
|
||||
local_path = f"/tmp/{local_hash}"
|
||||
def _get_github_repo_data(self, repo_name: str, branch_name: str = None, file_path: str = None) -> list[dict]:
|
||||
"""Get file contents from Repo"""
|
||||
data = []
|
||||
|
||||
def _get_repo_tree(repo_url: str, local_path: str):
|
||||
try:
|
||||
from git import Repo
|
||||
except ImportError as e:
|
||||
raise ValueError(
|
||||
"GithubLoader requires extra dependencies. Install with `pip install --upgrade 'embedchain[github]'`" # noqa: E501
|
||||
) from e
|
||||
repo = self.client.get_repo(repo_name)
|
||||
repo_contents = repo.get_contents("")
|
||||
|
||||
if os.path.exists(local_path):
|
||||
logging.info("Repository already exists. Fetching updates...")
|
||||
repo = Repo(local_path)
|
||||
logging.info("Fetch completed.")
|
||||
else:
|
||||
logging.info("Cloning repository...")
|
||||
repo = Repo.clone_from(repo_url, local_path)
|
||||
logging.info("Clone completed.")
|
||||
return repo.head.commit.tree
|
||||
if branch_name:
|
||||
repo_contents = repo.get_contents("", ref=branch_name)
|
||||
if file_path:
|
||||
repo_contents = [repo.get_contents(file_path)]
|
||||
|
||||
def _get_repo_tree_contents(repo_path, tree, progress_bar):
|
||||
for subtree in tree:
|
||||
if subtree.type == "tree":
|
||||
_get_repo_tree_contents(repo_path, subtree, progress_bar)
|
||||
else:
|
||||
assert subtree.type == "blob"
|
||||
with tqdm(desc="Loading files:", unit="item") as progress_bar:
|
||||
while repo_contents:
|
||||
file_content = repo_contents.pop(0)
|
||||
if file_content.type == "dir":
|
||||
try:
|
||||
contents = subtree.data_stream.read().decode("utf-8")
|
||||
repo_contents.extend(repo.get_contents(file_content.path))
|
||||
except Exception:
|
||||
logging.warning(f"Failed to read file: {subtree.path}")
|
||||
progress_bar.update(1) if progress_bar else None
|
||||
logging.warning(f"Failed to read directory: {file_content.path}")
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
file_text = file_content.decoded_content.decode()
|
||||
except Exception:
|
||||
logging.warning(f"Failed to read file: {file_content.path}")
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
url = f"{repo_url.rstrip('.git')}/blob/main/{subtree.path}"
|
||||
file_path = file_content.path
|
||||
data.append(
|
||||
{
|
||||
"content": clean_string(contents),
|
||||
"content": clean_string(file_text),
|
||||
"meta_data": {
|
||||
"url": url,
|
||||
"path": file_path,
|
||||
},
|
||||
}
|
||||
)
|
||||
if progress_bar is not None:
|
||||
progress_bar.update(1)
|
||||
|
||||
repo_tree = _get_repo_tree(repo_url, local_path)
|
||||
tree_list = list(repo_tree.traverse())
|
||||
with tqdm(total=len(tree_list), desc="Loading files:", unit="item") as progress_bar:
|
||||
_get_repo_tree_contents(local_path, repo_tree, progress_bar)
|
||||
progress_bar.update(1)
|
||||
|
||||
return data
|
||||
|
||||
def _github_search_repo(self, query: str) -> list[dict]:
|
||||
"""Search GitHub repo."""
|
||||
data = []
|
||||
logging.info(f"Searching github repos with query: {query}")
|
||||
results = self.client.search_repositories(query)
|
||||
# Add repo urls and descriptions
|
||||
urls = list(map(lambda x: x.html_url, results))
|
||||
descriptions = list(map(lambda x: x.description, results))
|
||||
data.append(
|
||||
{
|
||||
"content": clean_string(desc),
|
||||
"meta_data": {
|
||||
"url": url,
|
||||
},
|
||||
}
|
||||
for url, desc in zip(urls, descriptions)
|
||||
)
|
||||
|
||||
# Add repo contents
|
||||
for result in results:
|
||||
clone_url = result.clone_url
|
||||
logging.info(f"Cloning repository: {clone_url}")
|
||||
data = self._get_github_repo_data(clone_url)
|
||||
logging.info(f"Searching github repos with query: {query}")
|
||||
updated_query = query.split(":")[-1]
|
||||
data = self._get_github_repo_data(updated_query)
|
||||
return data
|
||||
|
||||
def _github_search_issues_and_pr(self, query: str, type: str) -> list[dict]:
|
||||
@@ -222,6 +192,43 @@ class GithubLoader(BaseLoader):
|
||||
)
|
||||
return data
|
||||
|
||||
def _get_github_repo_branch(self, query: str, type: str) -> list[dict]:
|
||||
"""Get file contents for specific branch"""
|
||||
|
||||
logging.info(f"Searching github repo for query: {query} is:{type}")
|
||||
pattern = r"repo:(\S+) name:(\S+)"
|
||||
match = re.search(pattern, query)
|
||||
|
||||
if match:
|
||||
repo_name = match.group(1)
|
||||
branch_name = match.group(2)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Repository name and Branch name not found, instead found this \
|
||||
Repo: {repo_name}, Branch: {branch_name}"
|
||||
)
|
||||
|
||||
data = self._get_github_repo_data(repo_name=repo_name, branch_name=branch_name)
|
||||
return data
|
||||
|
||||
def _get_github_repo_file(self, query: str, type: str) -> list[dict]:
|
||||
"""Get specific file content"""
|
||||
|
||||
logging.info(f"Searching github repo for query: {query} is:{type}")
|
||||
pattern = r"repo:(\S+) path:(\S+)"
|
||||
match = re.search(pattern, query)
|
||||
|
||||
if match:
|
||||
repo_name = match.group(1)
|
||||
file_path = match.group(2)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Repository name and File name not found, instead found this Repo: {repo_name}, File: {file_path}"
|
||||
)
|
||||
|
||||
data = self._get_github_repo_data(repo_name=repo_name, file_path=file_path)
|
||||
return data
|
||||
|
||||
def _search_github_data(self, search_type: str, query: str):
|
||||
"""Search github data."""
|
||||
if search_type == "code":
|
||||
@@ -232,6 +239,10 @@ class GithubLoader(BaseLoader):
|
||||
data = self._github_search_issues_and_pr(query, search_type)
|
||||
elif search_type == "pr":
|
||||
data = self._github_search_issues_and_pr(query, search_type)
|
||||
elif search_type == "branch":
|
||||
data = self._get_github_repo_branch(query, search_type)
|
||||
elif search_type == "file":
|
||||
data = self._get_github_repo_file(query, search_type)
|
||||
elif search_type == "discussion":
|
||||
raise ValueError("GithubLoader does not support searching discussions yet.")
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user