diff --git a/docs/components/data-sources/github.mdx b/docs/components/data-sources/github.mdx index bbce3b01..14791aca 100644 --- a/docs/components/data-sources/github.mdx +++ b/docs/components/data-sources/github.mdx @@ -29,11 +29,13 @@ response = app.query("What is Embedchain?") ``` The `add` function of the app will accept any valid github query with qualifiers. It only supports loading github code, repository, issues and pull-requests. -You must provide qualifiers `type:` and `repo:` in the query. The `type:` qualifier can be a combination of `code`, `repo`, `pr`, `issue`. The `repo:` qualifier must be a valid github repository name. +You must provide qualifiers `type:` and `repo:` in the query. The `type:` qualifier can be a combination of `code`, `repo`, `pr`, `issue`, `branch`, `file`. The `repo:` qualifier must be a valid github repository name. - `repo:embedchain/embedchain type:repo` - to load the repository + - `repo:embedchain/embedchain type:branch name:feature_test` - to load the branch of the repository + - `repo:embedchain/embedchain type:file path:README.md` - to load the specific file of the repository - `repo:embedchain/embedchain type:issue,pr` - to load the issues and pull-requests of the repository - `repo:embedchain/embedchain type:issue state:closed` - to load the closed issues of the repository diff --git a/embedchain/loaders/github.py b/embedchain/loaders/github.py index 967dbe77..e388ab9b 100644 --- a/embedchain/loaders/github.py +++ b/embedchain/loaders/github.py @@ -1,7 +1,6 @@ import concurrent.futures import hashlib import logging -import os import re import shlex from typing import Any, Optional @@ -14,7 +13,7 @@ from embedchain.utils.misc import clean_string GITHUB_URL = "https://github.com" GITHUB_API_URL = "https://api.github.com" -VALID_SEARCH_TYPES = set(["code", "repo", "pr", "issue", "discussion"]) +VALID_SEARCH_TYPES = set(["code", "repo", "pr", "issue", "discussion", "branch", "file"]) class GithubLoader(BaseLoader): @@ -66,85 +65,56 @@ class GithubLoader(BaseLoader): ) return data - @staticmethod - def _get_github_repo_data(repo_url: str): - local_hash = hashlib.sha256(repo_url.encode()).hexdigest() - local_path = f"/tmp/{local_hash}" + def _get_github_repo_data(self, repo_name: str, branch_name: str = None, file_path: str = None) -> list[dict]: + """Get file contents from Repo""" data = [] - def _get_repo_tree(repo_url: str, local_path: str): - try: - from git import Repo - except ImportError as e: - raise ValueError( - "GithubLoader requires extra dependencies. Install with `pip install --upgrade 'embedchain[github]'`" # noqa: E501 - ) from e + repo = self.client.get_repo(repo_name) + repo_contents = repo.get_contents("") - if os.path.exists(local_path): - logging.info("Repository already exists. Fetching updates...") - repo = Repo(local_path) - logging.info("Fetch completed.") - else: - logging.info("Cloning repository...") - repo = Repo.clone_from(repo_url, local_path) - logging.info("Clone completed.") - return repo.head.commit.tree + if branch_name: + repo_contents = repo.get_contents("", ref=branch_name) + if file_path: + repo_contents = [repo.get_contents(file_path)] - def _get_repo_tree_contents(repo_path, tree, progress_bar): - for subtree in tree: - if subtree.type == "tree": - _get_repo_tree_contents(repo_path, subtree, progress_bar) - else: - assert subtree.type == "blob" + with tqdm(desc="Loading files:", unit="item") as progress_bar: + while repo_contents: + file_content = repo_contents.pop(0) + if file_content.type == "dir": try: - contents = subtree.data_stream.read().decode("utf-8") + repo_contents.extend(repo.get_contents(file_content.path)) except Exception: - logging.warning(f"Failed to read file: {subtree.path}") - progress_bar.update(1) if progress_bar else None + logging.warning(f"Failed to read directory: {file_content.path}") + progress_bar.update(1) + continue + else: + try: + file_text = file_content.decoded_content.decode() + except Exception: + logging.warning(f"Failed to read file: {file_content.path}") + progress_bar.update(1) continue - url = f"{repo_url.rstrip('.git')}/blob/main/{subtree.path}" + file_path = file_content.path data.append( { - "content": clean_string(contents), + "content": clean_string(file_text), "meta_data": { - "url": url, + "path": file_path, }, } ) - if progress_bar is not None: - progress_bar.update(1) - repo_tree = _get_repo_tree(repo_url, local_path) - tree_list = list(repo_tree.traverse()) - with tqdm(total=len(tree_list), desc="Loading files:", unit="item") as progress_bar: - _get_repo_tree_contents(local_path, repo_tree, progress_bar) + progress_bar.update(1) return data def _github_search_repo(self, query: str) -> list[dict]: """Search GitHub repo.""" - data = [] - logging.info(f"Searching github repos with query: {query}") - results = self.client.search_repositories(query) - # Add repo urls and descriptions - urls = list(map(lambda x: x.html_url, results)) - descriptions = list(map(lambda x: x.description, results)) - data.append( - { - "content": clean_string(desc), - "meta_data": { - "url": url, - }, - } - for url, desc in zip(urls, descriptions) - ) - # Add repo contents - for result in results: - clone_url = result.clone_url - logging.info(f"Cloning repository: {clone_url}") - data = self._get_github_repo_data(clone_url) + logging.info(f"Searching github repos with query: {query}") + updated_query = query.split(":")[-1] + data = self._get_github_repo_data(updated_query) return data def _github_search_issues_and_pr(self, query: str, type: str) -> list[dict]: @@ -222,6 +192,43 @@ class GithubLoader(BaseLoader): ) return data + def _get_github_repo_branch(self, query: str, type: str) -> list[dict]: + """Get file contents for specific branch""" + + logging.info(f"Searching github repo for query: {query} is:{type}") + pattern = r"repo:(\S+) name:(\S+)" + match = re.search(pattern, query) + + if match: + repo_name = match.group(1) + branch_name = match.group(2) + else: + raise ValueError( + f"Repository name and Branch name not found, instead found this \ + Repo: {repo_name}, Branch: {branch_name}" + ) + + data = self._get_github_repo_data(repo_name=repo_name, branch_name=branch_name) + return data + + def _get_github_repo_file(self, query: str, type: str) -> list[dict]: + """Get specific file content""" + + logging.info(f"Searching github repo for query: {query} is:{type}") + pattern = r"repo:(\S+) path:(\S+)" + match = re.search(pattern, query) + + if match: + repo_name = match.group(1) + file_path = match.group(2) + else: + raise ValueError( + f"Repository name and File name not found, instead found this Repo: {repo_name}, File: {file_path}" + ) + + data = self._get_github_repo_data(repo_name=repo_name, file_path=file_path) + return data + def _search_github_data(self, search_type: str, query: str): """Search github data.""" if search_type == "code": @@ -232,6 +239,10 @@ class GithubLoader(BaseLoader): data = self._github_search_issues_and_pr(query, search_type) elif search_type == "pr": data = self._github_search_issues_and_pr(query, search_type) + elif search_type == "branch": + data = self._get_github_repo_branch(query, search_type) + elif search_type == "file": + data = self._get_github_repo_file(query, search_type) elif search_type == "discussion": raise ValueError("GithubLoader does not support searching discussions yet.") else: