[Bug fix] Fix typos, static methods and other sanity improvements in the package (#1129)
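The recurring changes in this diff are: logging.warn(...) is replaced with logging.warning(...) (warn is a long-deprecated alias of warning), helper methods that never touch self are converted to static methods, and docstring typos ("github" to "GitHub", "discriptions" to "descriptions") are fixed. A minimal sketch of the first two patterns, using a hypothetical ExampleLoader that is not part of the package:

import logging


class ExampleLoader:
    # Helpers that never read or write instance state become static methods;
    # they can still be called on an instance, so existing call sites keep working.
    @staticmethod
    def _check_query(query):
        if not isinstance(query, str):
            raise ValueError(f"Invalid query: {query}")

    def load_data(self, query):
        self._check_query(query)
        # logging.warn() is an undocumented, deprecated alias;
        # logging.warning() is the supported spelling used after this change.
        logging.warning(f"No data found for query: {query}")
        return []

Because a static method can be invoked either as ExampleLoader._check_query(q) or as self._check_query(q), these conversions are behavior-preserving.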
@@ -5,7 +5,7 @@ class BaseLoader(JSONSerializable):
     def __init__(self):
         pass

-    def load_data():
+    def load_data(self, url):
         """
         Implemented by child classes
         """
@@ -32,7 +32,7 @@ class DirectoryLoader(BaseLoader):
         doc_id = hashlib.sha256((str(data_list) + str(directory_path)).encode()).hexdigest()

         for error in self.errors:
-            logging.warn(error)
+            logging.warning(error)

         return {"doc_id": doc_id, "data": data_list}
@@ -49,7 +49,8 @@ class DocsSiteLoader(BaseLoader):
         urls = [link for link in self.visited_links if urlparse(link).netloc == urlparse(url).netloc]
         return urls

-    def _load_data_from_url(self, url):
+    @staticmethod
+    def _load_data_from_url(url: str) -> list:
         response = requests.get(url)
         if response.status_code != 200:
             logging.info(f"Failed to fetch the website: {response.status_code}")
@@ -18,7 +18,7 @@ VALID_SEARCH_TYPES = set(["code", "repo", "pr", "issue", "discussion"])


 class GithubLoader(BaseLoader):
-    """Load data from github search query."""
+    """Load data from GitHub search query."""

     def __init__(self, config: Optional[Dict[str, Any]] = None):
         super().__init__()
@@ -48,7 +48,7 @@ class GithubLoader(BaseLoader):
         self.client = None

     def _github_search_code(self, query: str):
-        """Search github code."""
+        """Search GitHub code."""
         data = []
         results = self.client.search_code(query)
         for result in tqdm(results, total=results.totalCount, desc="Loading code files from github"):
@@ -66,7 +66,8 @@ class GithubLoader(BaseLoader):
             )
         return data

-    def _get_github_repo_data(self, repo_url: str):
+    @staticmethod
+    def _get_github_repo_data(repo_url: str):
         local_hash = hashlib.sha256(repo_url.encode()).hexdigest()
         local_path = f"/tmp/{local_hash}"
         data = []
@@ -121,14 +122,14 @@ class GithubLoader(BaseLoader):

         return data

-    def _github_search_repo(self, query: str):
-        """Search github repo."""
+    def _github_search_repo(self, query: str) -> list[dict]:
+        """Search GitHub repo."""
         data = []
         logging.info(f"Searching github repos with query: {query}")
         results = self.client.search_repositories(query)
         # Add repo urls and descriptions
         urls = list(map(lambda x: x.html_url, results))
-        discriptions = list(map(lambda x: x.description, results))
+        descriptions = list(map(lambda x: x.description, results))
         data.append(
             {
                 "content": clean_string(desc),
@@ -136,7 +137,7 @@ class GithubLoader(BaseLoader):
                     "url": url,
                 },
             }
-            for url, desc in zip(urls, discriptions)
+            for url, desc in zip(urls, descriptions)
         )

         # Add repo contents
@@ -146,8 +147,8 @@ class GithubLoader(BaseLoader):
             data = self._get_github_repo_data(clone_url)
         return data

-    def _github_search_issues_and_pr(self, query: str, type: str):
-        """Search github issues and PRs."""
+    def _github_search_issues_and_pr(self, query: str, type: str) -> list[dict]:
+        """Search GitHub issues and PRs."""
         data = []

         query = f"{query} is:{type}"
@@ -161,7 +162,7 @@ class GithubLoader(BaseLoader):
             title = result.title
             body = result.body
             if not body:
-                logging.warn(f"Skipping issue because empty content for: {url}")
+                logging.warning(f"Skipping issue because empty content for: {url}")
                 continue
             labels = " ".join([label.name for label in result.labels])
             issue_comments = result.get_comments()
@@ -186,7 +187,7 @@ class GithubLoader(BaseLoader):

     # need to test more for discussion
     def _github_search_discussions(self, query: str):
-        """Search github discussions."""
+        """Search GitHub discussions."""
         data = []

         query = f"{query} is:discussion"
@@ -202,7 +203,7 @@ class GithubLoader(BaseLoader):
             title = discussion.title
             body = discussion.body
             if not body:
-                logging.warn(f"Skipping discussion because empty content for: {url}")
+                logging.warning(f"Skipping discussion because empty content for: {url}")
                 continue
             comments = []
             comments_created_at = []
@@ -233,11 +234,14 @@ class GithubLoader(BaseLoader):
             data = self._github_search_issues_and_pr(query, search_type)
+        elif search_type == "discussion":
+            raise ValueError("GithubLoader does not support searching discussions yet.")
         else:
             raise NotImplementedError(f"{search_type} not supported")

         return data

-    def _get_valid_github_query(self, query: str):
-        """Check if query is valid and return search types and valid github query."""
+    @staticmethod
+    def _get_valid_github_query(query: str):
+        """Check if query is valid and return search types and valid GitHub query."""
         query_terms = shlex.split(query)
         # query must provide repo to load data from
         if len(query_terms) < 1 or "repo:" not in query:
@@ -273,7 +277,7 @@ class GithubLoader(BaseLoader):
         return types, query

     def load_data(self, search_query: str, max_results: int = 1000):
-        """Load data from github search query."""
+        """Load data from GitHub search query."""

         if not self.client:
             raise ValueError(
@@ -20,7 +20,8 @@ class ImageLoader(BaseLoader):
         self.api_key = api_key or os.environ["OPENAI_API_KEY"]
         self.client = OpenAI(api_key=self.api_key)

-    def _encode_image(self, image_path: str):
+    @staticmethod
+    def _encode_image(image_path: str):
         with open(image_path, "rb") as image_file:
             return base64.b64encode(image_file.read()).decode("utf-8")
@@ -15,7 +15,8 @@ class JSONReader:
         """Initialize the JSONReader."""
         pass

-    def load_data(self, json_data: Union[Dict, str]) -> List[str]:
+    @staticmethod
+    def load_data(json_data: Union[Dict, str]) -> List[str]:
         """Load data from a JSON structure.

         Args:
@@ -39,7 +39,8 @@ class MySQLLoader(BaseLoader):
                 Refer `https://docs.embedchain.ai/data-sources/mysql`.",
             )

-    def _check_query(self, query):
+    @staticmethod
+    def _check_query(query):
         if not isinstance(query, str):
             raise ValueError(
                 f"Invalid mysql query: {query}",
@@ -24,7 +24,6 @@ class PostgresLoader(BaseLoader):
                 Run `pip install --upgrade 'embedchain[postgres]'`"
             ) from e

-        config_info = ""
         if "url" in config:
             config_info = config.get("url")
         else:
@@ -37,7 +36,8 @@ class PostgresLoader(BaseLoader):
         self.connection = psycopg.connect(conninfo=config_info)
         self.cursor = self.connection.cursor()

-    def _check_query(self, query):
+    @staticmethod
+    def _check_query(query):
         if not isinstance(query, str):
             raise ValueError(
                 f"Invalid postgres query: {query}. Provide the valid source to add from postgres, make sure you are following `https://docs.embedchain.ai/data-sources/postgres`", # noqa:E501
@@ -56,7 +56,8 @@ class SlackLoader(BaseLoader):
         )
         logging.info("Slack Loader setup successful!")

-    def _check_query(self, query):
+    @staticmethod
+    def _check_query(query):
         if not isinstance(query, str):
             raise ValueError(
                 f"Invalid query passed to Slack loader, found: {query}. Check `https://docs.embedchain.ai/data-sources/slack` to learn more." # noqa:E501
@@ -8,7 +8,7 @@ from embedchain.utils.misc import clean_string
 @register_deserializable
 class UnstructuredLoader(BaseLoader):
     def load_data(self, url):
-        """Load data from a Unstructured file."""
+        """Load data from an Unstructured file."""
         try:
             from langchain.document_loaders import UnstructuredFileLoader
         except ImportError:
@@ -21,7 +21,7 @@ class WebPageLoader(BaseLoader):
     _session = requests.Session()

     def load_data(self, url):
-        """Load data from a web page using a shared requests session."""
+        """Load data from a web page using a shared requests' session."""
         response = self._session.get(url, timeout=30)
         response.raise_for_status()
         data = response.content
@@ -40,7 +40,8 @@ class WebPageLoader(BaseLoader):
             ],
         }

-    def _get_clean_content(self, html, url) -> str:
+    @staticmethod
+    def _get_clean_content(html, url) -> str:
         soup = BeautifulSoup(html, "html.parser")
         original_size = len(str(soup.get_text()))
@@ -60,8 +61,8 @@ class WebPageLoader(BaseLoader):
                 tag.decompose()

         ids_to_exclude = ["sidebar", "main-navigation", "menu-main-menu"]
-        for id in ids_to_exclude:
-            tags = soup.find_all(id=id)
+        for id_ in ids_to_exclude:
+            tags = soup.find_all(id=id_)
             for tag in tags:
                 tag.decompose()
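The hunk above also renames a loop variable that shadowed the built-in id(). A small illustrative sketch (the list literal comes from the hunk; the print calls are only for demonstration):

ids_to_exclude = ["sidebar", "main-navigation", "menu-main-menu"]

# Using id_ as the loop variable leaves the built-in id() untouched,
# so code later in the same scope can still call it.
for id_ in ids_to_exclude:
    print(id_)

print(id(ids_to_exclude))  # the built-in id() is still callable after the loop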