diff --git a/docs/components/data-sources/github.mdx b/docs/components/data-sources/github.mdx
index bbce3b01..14791aca 100644
--- a/docs/components/data-sources/github.mdx
+++ b/docs/components/data-sources/github.mdx
@@ -29,11 +29,13 @@ response = app.query("What is Embedchain?")
```
The `add` function of the app will accept any valid github query with qualifiers. It only supports loading github code, repository, issues and pull-requests.
-You must provide qualifiers `type:` and `repo:` in the query. The `type:` qualifier can be a combination of `code`, `repo`, `pr`, `issue`. The `repo:` qualifier must be a valid github repository name.
+You must provide qualifiers `type:` and `repo:` in the query. The `type:` qualifier can be a combination of `code`, `repo`, `pr`, `issue`, `branch`, `file`. The `repo:` qualifier must be a valid github repository name.
- `repo:embedchain/embedchain type:repo` - to load the repository
+ - `repo:embedchain/embedchain type:branch name:feature_test` - to load a specific branch of the repository (`name:` is the branch name)
+ - `repo:embedchain/embedchain type:file path:README.md` - to load a specific file from the repository (`path:` is the file's path within the repository)
- `repo:embedchain/embedchain type:issue,pr` - to load the issues and pull-requests of the repository
- `repo:embedchain/embedchain type:issue state:closed` - to load the closed issues of the repository
diff --git a/embedchain/loaders/github.py b/embedchain/loaders/github.py
index 967dbe77..e388ab9b 100644
--- a/embedchain/loaders/github.py
+++ b/embedchain/loaders/github.py
@@ -1,7 +1,6 @@
import concurrent.futures
import hashlib
import logging
-import os
import re
import shlex
from typing import Any, Optional
@@ -14,7 +13,7 @@ from embedchain.utils.misc import clean_string
GITHUB_URL = "https://github.com"
GITHUB_API_URL = "https://api.github.com"
-VALID_SEARCH_TYPES = set(["code", "repo", "pr", "issue", "discussion"])
+# Search types recognised in a query's `type:` qualifier.
+VALID_SEARCH_TYPES = {"code", "repo", "pr", "issue", "discussion", "branch", "file"}
class GithubLoader(BaseLoader):
@@ -66,85 +65,56 @@ class GithubLoader(BaseLoader):
)
return data
-    @staticmethod
-    def _get_github_repo_data(repo_url: str):
-        local_hash = hashlib.sha256(repo_url.encode()).hexdigest()
-        local_path = f"/tmp/{local_hash}"
+    def _get_github_repo_data(
+        self, repo_name: str, branch_name: Optional[str] = None, file_path: Optional[str] = None
+    ) -> list[dict]:
+        """Load the text of files in a repository through the GitHub contents API.
+
+        Args:
+            repo_name: Repository identifier passed to `client.get_repo`.
+            branch_name: Branch to read from. When omitted no `ref` is sent and
+                the API chooses the branch.
+            file_path: Load only this path instead of walking from the
+                repository root.
+
+        Returns:
+            One dict per readable file, holding the cleaned text under
+            `content` and the file's repository path under
+            `meta_data["path"]`. Files and directories that cannot be read
+            are logged and skipped.
+        """
         data = []
-        def _get_repo_tree(repo_url: str, local_path: str):
-            try:
-                from git import Repo
-            except ImportError as e:
-                raise ValueError(
-                    "GithubLoader requires extra dependencies. Install with `pip install --upgrade 'embedchain[github]'`"  # noqa: E501
-                ) from e
+        repo = self.client.get_repo(repo_name)
+        # Pin every contents request to the requested branch. Sending `ref` only
+        # for the root listing would read sub-directories from another branch.
+        ref_kwargs = {"ref": branch_name} if branch_name else {}
-            if os.path.exists(local_path):
-                logging.info("Repository already exists. Fetching updates...")
-                repo = Repo(local_path)
-                logging.info("Fetch completed.")
-            else:
-                logging.info("Cloning repository...")
-                repo = Repo.clone_from(repo_url, local_path)
-                logging.info("Clone completed.")
-            return repo.head.commit.tree
+        if file_path:
+            # A directory path comes back as a list, a file path as a single item;
+            # normalise both to a flat work list.
+            contents = repo.get_contents(file_path, **ref_kwargs)
+            repo_contents = contents if isinstance(contents, list) else [contents]
+        else:
+            repo_contents = repo.get_contents("", **ref_kwargs)
-        def _get_repo_tree_contents(repo_path, tree, progress_bar):
-            for subtree in tree:
-                if subtree.type == "tree":
-                    _get_repo_tree_contents(repo_path, subtree, progress_bar)
-                else:
-                    assert subtree.type == "blob"
+        with tqdm(desc="Loading files:", unit="item") as progress_bar:
+            # Breadth-first walk: directories append their children to the work list.
+            while repo_contents:
+                file_content = repo_contents.pop(0)
+                if file_content.type == "dir":
                     try:
-                        contents = subtree.data_stream.read().decode("utf-8")
+                        repo_contents.extend(repo.get_contents(file_content.path, **ref_kwargs))
                     except Exception:
-                        logging.warning(f"Failed to read file: {subtree.path}")
-                        progress_bar.update(1) if progress_bar else None
+                        logging.warning(f"Failed to read directory: {file_content.path}")
+                        progress_bar.update(1)
+                        continue
+                else:
+                    try:
+                        file_text = file_content.decoded_content.decode()
+                    except Exception:
+                        logging.warning(f"Failed to read file: {file_content.path}")
+                        progress_bar.update(1)
                         continue
-                    url = f"{repo_url.rstrip('.git')}/blob/main/{subtree.path}"
+                    file_path = file_content.path
                     data.append(
                         {
-                            "content": clean_string(contents),
+                            "content": clean_string(file_text),
                             "meta_data": {
-                                "url": url,
+                                "path": file_path,
                             },
                         }
                     )
-                if progress_bar is not None:
-                    progress_bar.update(1)
-        repo_tree = _get_repo_tree(repo_url, local_path)
-        tree_list = list(repo_tree.traverse())
-        with tqdm(total=len(tree_list), desc="Loading files:", unit="item") as progress_bar:
-            _get_repo_tree_contents(local_path, repo_tree, progress_bar)
+                progress_bar.update(1)
         return data
def _github_search_repo(self, query: str) -> list[dict]:
"""Search GitHub repo."""
-        data = []
-        logging.info(f"Searching github repos with query: {query}")
-        results = self.client.search_repositories(query)
-        # Add repo urls and descriptions
-        urls = list(map(lambda x: x.html_url, results))
-        descriptions = list(map(lambda x: x.description, results))
-        data.append(
-            {
-                "content": clean_string(desc),
-                "meta_data": {
-                    "url": url,
-                },
-            }
-            for url, desc in zip(urls, descriptions)
-        )
-        # Add repo contents
-        for result in results:
-            clone_url = result.clone_url
-            logging.info(f"Cloning repository: {clone_url}")
-            data = self._get_github_repo_data(clone_url)
+        # No search is performed here: the repository named in the query is
+        # loaded directly and its file contents are returned.
+        logging.info(f"Searching github repos with query: {query}")
+        # Read the name from the `repo:` qualifier itself. Taking the text after
+        # the last colon returns the wrong value as soon as another qualifier
+        # (for example `language:python`) follows `repo:`.
+        repo_match = re.search(r"repo:(\S+)", query)
+        # Without a `repo:` qualifier, fall back to the text after the last colon.
+        repo_name = repo_match.group(1) if repo_match else query.split(":")[-1]
+        data = self._get_github_repo_data(repo_name)
return data
def _github_search_issues_and_pr(self, query: str, type: str) -> list[dict]:
@@ -222,6 +192,43 @@ class GithubLoader(BaseLoader):
)
return data
+    def _get_github_repo_branch(self, query: str, type: str) -> list[dict]:
+        """Load the files of one branch of a repository.
+
+        Args:
+            query: Query containing `repo:<repository> name:<branch>`; the two
+                qualifiers must be adjacent, in this order, separated by one space.
+            type: Search type that routed the query here; used only in the log line.
+
+        Returns:
+            The result of `_get_github_repo_data` for the parsed repository and branch.
+
+        Raises:
+            ValueError: If the query does not contain both qualifiers.
+        """
+        logging.info(f"Searching github repo for query: {query} is:{type}")
+        match = re.search(r"repo:(\S+) name:(\S+)", query)
+        if match is None:
+            # Nothing was parsed on this path, so report the offending query;
+            # the repository and branch names are not available to format.
+            raise ValueError(
+                f"Repository name and Branch name not found in query: {query!r}. "
+                "Expected `repo:<repository> name:<branch>`."
+            )
+
+        repo_name, branch_name = match.groups()
+        return self._get_github_repo_data(repo_name=repo_name, branch_name=branch_name)
+
+    def _get_github_repo_file(self, query: str, type: str) -> list[dict]:
+        """Load the content found at one path of a repository.
+
+        Args:
+            query: Query containing `repo:<repository> path:<file path>`; the two
+                qualifiers must be adjacent, in this order, separated by one space.
+            type: Search type that routed the query here; used only in the log line.
+
+        Returns:
+            The result of `_get_github_repo_data` for the parsed repository and path.
+
+        Raises:
+            ValueError: If the query does not contain both qualifiers.
+        """
+        logging.info(f"Searching github repo for query: {query} is:{type}")
+        match = re.search(r"repo:(\S+) path:(\S+)", query)
+        if match is None:
+            # Nothing was parsed on this path, so report the offending query;
+            # the repository and file names are not available to format.
+            raise ValueError(
+                f"Repository name and File name not found in query: {query!r}. "
+                "Expected `repo:<repository> path:<file path>`."
+            )
+
+        repo_name, file_path = match.groups()
+        return self._get_github_repo_data(repo_name=repo_name, file_path=file_path)
+
def _search_github_data(self, search_type: str, query: str):
"""Search github data."""
if search_type == "code":
@@ -232,6 +239,10 @@ class GithubLoader(BaseLoader):
data = self._github_search_issues_and_pr(query, search_type)
elif search_type == "pr":
data = self._github_search_issues_and_pr(query, search_type)
+ elif search_type == "branch":
+ data = self._get_github_repo_branch(query, search_type)
+ elif search_type == "file":
+ data = self._get_github_repo_file(query, search_type)
elif search_type == "discussion":
raise ValueError("GithubLoader does not support searching discussions yet.")
else: