[Bug fix] Fix typos, static methods and other sanity improvements in the package (#1129)

This commit is contained in:
Sandra Serrano
2024-01-08 19:47:46 +01:00
committed by GitHub
parent 62c0c52e31
commit 2496ed133e
41 changed files with 133 additions and 103 deletions

View File

@@ -5,7 +5,7 @@ class BaseLoader(JSONSerializable):
def __init__(self):
pass
def load_data():
def load_data(self, url):
"""
Implemented by child classes
"""

View File

@@ -32,7 +32,7 @@ class DirectoryLoader(BaseLoader):
doc_id = hashlib.sha256((str(data_list) + str(directory_path)).encode()).hexdigest()
for error in self.errors:
logging.warn(error)
logging.warning(error)
return {"doc_id": doc_id, "data": data_list}

View File

@@ -49,7 +49,8 @@ class DocsSiteLoader(BaseLoader):
urls = [link for link in self.visited_links if urlparse(link).netloc == urlparse(url).netloc]
return urls
def _load_data_from_url(self, url):
@staticmethod
def _load_data_from_url(url: str) -> list:
response = requests.get(url)
if response.status_code != 200:
logging.info(f"Failed to fetch the website: {response.status_code}")

View File

@@ -18,7 +18,7 @@ VALID_SEARCH_TYPES = set(["code", "repo", "pr", "issue", "discussion"])
class GithubLoader(BaseLoader):
"""Load data from github search query."""
"""Load data from GitHub search query."""
def __init__(self, config: Optional[Dict[str, Any]] = None):
super().__init__()
@@ -48,7 +48,7 @@ class GithubLoader(BaseLoader):
self.client = None
def _github_search_code(self, query: str):
"""Search github code."""
"""Search GitHub code."""
data = []
results = self.client.search_code(query)
for result in tqdm(results, total=results.totalCount, desc="Loading code files from github"):
@@ -66,7 +66,8 @@ class GithubLoader(BaseLoader):
)
return data
def _get_github_repo_data(self, repo_url: str):
@staticmethod
def _get_github_repo_data(repo_url: str):
local_hash = hashlib.sha256(repo_url.encode()).hexdigest()
local_path = f"/tmp/{local_hash}"
data = []
@@ -121,14 +122,14 @@ class GithubLoader(BaseLoader):
return data
def _github_search_repo(self, query: str):
"""Search github repo."""
def _github_search_repo(self, query: str) -> list[dict]:
"""Search GitHub repo."""
data = []
logging.info(f"Searching github repos with query: {query}")
results = self.client.search_repositories(query)
# Add repo urls and descriptions
urls = list(map(lambda x: x.html_url, results))
discriptions = list(map(lambda x: x.description, results))
descriptions = list(map(lambda x: x.description, results))
data.append(
{
"content": clean_string(desc),
@@ -136,7 +137,7 @@ class GithubLoader(BaseLoader):
"url": url,
},
}
for url, desc in zip(urls, discriptions)
for url, desc in zip(urls, descriptions)
)
# Add repo contents
@@ -146,8 +147,8 @@ class GithubLoader(BaseLoader):
data = self._get_github_repo_data(clone_url)
return data
def _github_search_issues_and_pr(self, query: str, type: str):
"""Search github issues and PRs."""
def _github_search_issues_and_pr(self, query: str, type: str) -> list[dict]:
"""Search GitHub issues and PRs."""
data = []
query = f"{query} is:{type}"
@@ -161,7 +162,7 @@ class GithubLoader(BaseLoader):
title = result.title
body = result.body
if not body:
logging.warn(f"Skipping issue because empty content for: {url}")
logging.warning(f"Skipping issue because empty content for: {url}")
continue
labels = " ".join([label.name for label in result.labels])
issue_comments = result.get_comments()
@@ -186,7 +187,7 @@ class GithubLoader(BaseLoader):
# need to test more for discussion
def _github_search_discussions(self, query: str):
"""Search github discussions."""
"""Search GitHub discussions."""
data = []
query = f"{query} is:discussion"
@@ -202,7 +203,7 @@ class GithubLoader(BaseLoader):
title = discussion.title
body = discussion.body
if not body:
logging.warn(f"Skipping discussion because empty content for: {url}")
logging.warning(f"Skipping discussion because empty content for: {url}")
continue
comments = []
comments_created_at = []
@@ -233,11 +234,14 @@ class GithubLoader(BaseLoader):
data = self._github_search_issues_and_pr(query, search_type)
elif search_type == "discussion":
raise ValueError("GithubLoader does not support searching discussions yet.")
else:
raise NotImplementedError(f"{search_type} not supported")
return data
def _get_valid_github_query(self, query: str):
"""Check if query is valid and return search types and valid github query."""
@staticmethod
def _get_valid_github_query(query: str):
"""Check if query is valid and return search types and valid GitHub query."""
query_terms = shlex.split(query)
# query must provide repo to load data from
if len(query_terms) < 1 or "repo:" not in query:
@@ -273,7 +277,7 @@ class GithubLoader(BaseLoader):
return types, query
def load_data(self, search_query: str, max_results: int = 1000):
"""Load data from github search query."""
"""Load data from GitHub search query."""
if not self.client:
raise ValueError(

View File

@@ -20,7 +20,8 @@ class ImageLoader(BaseLoader):
self.api_key = api_key or os.environ["OPENAI_API_KEY"]
self.client = OpenAI(api_key=self.api_key)
def _encode_image(self, image_path: str):
@staticmethod
def _encode_image(image_path: str):
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")

View File

@@ -15,7 +15,8 @@ class JSONReader:
"""Initialize the JSONReader."""
pass
def load_data(self, json_data: Union[Dict, str]) -> List[str]:
@staticmethod
def load_data(json_data: Union[Dict, str]) -> List[str]:
"""Load data from a JSON structure.
Args:

View File

@@ -39,7 +39,8 @@ class MySQLLoader(BaseLoader):
Refer `https://docs.embedchain.ai/data-sources/mysql`.",
)
def _check_query(self, query):
@staticmethod
def _check_query(query):
if not isinstance(query, str):
raise ValueError(
f"Invalid mysql query: {query}",

View File

@@ -24,7 +24,6 @@ class PostgresLoader(BaseLoader):
Run `pip install --upgrade 'embedchain[postgres]'`"
) from e
config_info = ""
if "url" in config:
config_info = config.get("url")
else:
@@ -37,7 +36,8 @@ class PostgresLoader(BaseLoader):
self.connection = psycopg.connect(conninfo=config_info)
self.cursor = self.connection.cursor()
def _check_query(self, query):
@staticmethod
def _check_query(query):
if not isinstance(query, str):
raise ValueError(
f"Invalid postgres query: {query}. Provide the valid source to add from postgres, make sure you are following `https://docs.embedchain.ai/data-sources/postgres`", # noqa:E501

View File

@@ -56,7 +56,8 @@ class SlackLoader(BaseLoader):
)
logging.info("Slack Loader setup successful!")
def _check_query(self, query):
@staticmethod
def _check_query(query):
if not isinstance(query, str):
raise ValueError(
f"Invalid query passed to Slack loader, found: {query}. Check `https://docs.embedchain.ai/data-sources/slack` to learn more." # noqa:E501

View File

@@ -8,7 +8,7 @@ from embedchain.utils.misc import clean_string
@register_deserializable
class UnstructuredLoader(BaseLoader):
def load_data(self, url):
"""Load data from a Unstructured file."""
"""Load data from an Unstructured file."""
try:
from langchain.document_loaders import UnstructuredFileLoader
except ImportError:

View File

@@ -21,7 +21,7 @@ class WebPageLoader(BaseLoader):
_session = requests.Session()
def load_data(self, url):
"""Load data from a web page using a shared requests session."""
"""Load data from a web page using a shared requests' session."""
response = self._session.get(url, timeout=30)
response.raise_for_status()
data = response.content
@@ -40,7 +40,8 @@ class WebPageLoader(BaseLoader):
],
}
def _get_clean_content(self, html, url) -> str:
@staticmethod
def _get_clean_content(html, url) -> str:
soup = BeautifulSoup(html, "html.parser")
original_size = len(str(soup.get_text()))
@@ -60,8 +61,8 @@ class WebPageLoader(BaseLoader):
tag.decompose()
ids_to_exclude = ["sidebar", "main-navigation", "menu-main-menu"]
for id in ids_to_exclude:
tags = soup.find_all(id=id)
for id_ in ids_to_exclude:
tags = soup.find_all(id=id_)
for tag in tags:
tag.decompose()