[Refactor] Update dependencies and loaders (#1062)
@@ -1,123 +1,142 @@
import base64
import hashlib
import logging
import os
import quopri
from email import message_from_bytes
from email.utils import parsedate_to_datetime
from textwrap import dedent
from typing import Dict, List, Optional

from bs4 import BeautifulSoup

try:
    from llama_hub.gmail.base import GmailReader
    from google.auth.transport.requests import Request
    from google.oauth2.credentials import Credentials
    from google_auth_oauthlib.flow import InstalledAppFlow
    from googleapiclient.discovery import build
except ImportError:
    raise ImportError("Gmail requires extra dependencies. Install with `pip install embedchain[gmail]`") from None
    raise ImportError(
        'Gmail requires extra dependencies. Install with `pip install --upgrade "embedchain[gmail]"`'
    ) from None

from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils import clean_string


def get_header(text: str, header: str) -> str:
    start_string_position = text.find(header)
    pos_start = text.find(":", start_string_position) + 1
    pos_end = text.find("\n", pos_start)
    header = text[pos_start:pos_end]
    return header.strip()


class GmailReader:
    SCOPES = ["https://www.googleapis.com/auth/gmail.readonly"]

    def __init__(self, query: str, service=None, results_per_page: int = 10):
        self.query = query
        self.service = service or self._initialize_service()
        self.results_per_page = results_per_page

    @staticmethod
    def _initialize_service():
        credentials = GmailReader._get_credentials()
        return build("gmail", "v1", credentials=credentials)

    @staticmethod
    def _get_credentials():
        if not os.path.exists("credentials.json"):
            raise FileNotFoundError("Missing 'credentials.json'. Download it from your Google Developer account.")

        creds = (
            Credentials.from_authorized_user_file("token.json", GmailReader.SCOPES)
            if os.path.exists("token.json")
            else None
        )

        if not creds or not creds.valid:
            if creds and creds.expired and creds.refresh_token:
                creds.refresh(Request())
            else:
                flow = InstalledAppFlow.from_client_secrets_file("credentials.json", GmailReader.SCOPES)
                creds = flow.run_local_server(port=8080)
            with open("token.json", "w") as token:
                token.write(creds.to_json())
        return creds

    def load_emails(self) -> List[Dict]:
        response = self.service.users().messages().list(userId="me", q=self.query).execute()
        messages = response.get("messages", [])

        return [self._parse_email(self._get_email(message["id"])) for message in messages]

    def _get_email(self, message_id: str):
        raw_message = self.service.users().messages().get(userId="me", id=message_id, format="raw").execute()
        return base64.urlsafe_b64decode(raw_message["raw"])

    def _parse_email(self, raw_email) -> Dict:
        mime_msg = message_from_bytes(raw_email)
        return {
            "subject": self._get_header(mime_msg, "Subject"),
            "from": self._get_header(mime_msg, "From"),
            "to": self._get_header(mime_msg, "To"),
            "date": self._format_date(mime_msg),
            "body": self._get_body(mime_msg),
        }

    @staticmethod
    def _get_header(mime_msg, header_name: str) -> str:
        return mime_msg.get(header_name, "")

    @staticmethod
    def _format_date(mime_msg) -> Optional[str]:
        date_header = GmailReader._get_header(mime_msg, "Date")
        return parsedate_to_datetime(date_header).isoformat() if date_header else None

    @staticmethod
    def _get_body(mime_msg) -> str:
        def decode_payload(part):
            charset = part.get_content_charset() or "utf-8"
            try:
                return part.get_payload(decode=True).decode(charset)
            except UnicodeDecodeError:
                return part.get_payload(decode=True).decode(charset, errors="replace")

        if mime_msg.is_multipart():
            for part in mime_msg.walk():
                ctype = part.get_content_type()
                cdispo = str(part.get("Content-Disposition"))

                if ctype == "text/plain" and "attachment" not in cdispo:
                    return decode_payload(part)
                elif ctype == "text/html":
                    return decode_payload(part)
        else:
            return decode_payload(mime_msg)

        return ""


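A brief usage sketch of the reader on its own, assuming `credentials.json` (and `token.json` after the first OAuth consent) are present in the working directory; the query string is hypothetical:

# Sketch: GmailReader.load_emails returns plain dicts with the parsed headers and body.
reader = GmailReader(query="subject:invoice newer_than:30d")  # hypothetical Gmail search query
emails = reader.load_emails()
for email in emails:
    print(email["date"], email["from"], email["subject"])
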
class GmailLoader(BaseLoader):
    def load_data(self, query):
        """Load data from gmail."""
        if not os.path.isfile("credentials.json"):
            raise FileNotFoundError(
                "You must download a valid credentials file from your Google developer "
                "account. Refer to `https://cloud.google.com/docs/authentication/api-keys`."
            )

        loader = GmailReader(query=query, service=None, results_per_page=20)
        documents = loader.load_data()
        logging.info(f"Gmail Loader: {len(documents)} mails found for query- {query}")

    def load_data(self, query: str):
        reader = GmailReader(query=query)
        emails = reader.load_emails()
        logging.info(f"Gmail Loader: {len(emails)} emails found for query '{query}'")

        data = []
        data_contents = []
        logging.info(f"Gmail Loader: {len(documents)} mails found")
        for document in documents:
            original_size = len(document.text)
        for email in emails:
            content = self._process_email(email)
            data.append({"content": content, "meta_data": email})

            snippet = document.metadata.get("snippet")
            meta_data = {
                "url": document.metadata.get("id"),
                "date": get_header(document.text, "Date"),
                "subject": get_header(document.text, "Subject"),
                "from": get_header(document.text, "From"),
                "to": get_header(document.text, "To"),
                "search_query": query,
            }
        return {"doc_id": self._generate_doc_id(query, data), "data": data}

            # Decode
            decoded_bytes = quopri.decodestring(document.text)
            decoded_str = decoded_bytes.decode("utf-8", errors="replace")

    @staticmethod
    def _process_email(email: Dict) -> str:
        content = BeautifulSoup(email["body"], "html.parser").get_text()
        content = clean_string(content)
        return dedent(
            f"""
            Email from '{email['from']}' to '{email['to']}'
            Subject: {email['subject']}
            Date: {email['date']}
            Content: {content}
            """
        )

            # Slice
            mail_start = decoded_str.find("<!DOCTYPE")
            email_data = decoded_str[mail_start:]

            # Web Page HTML Processing
            soup = BeautifulSoup(email_data, "html.parser")

            tags_to_exclude = [
                "nav",
                "aside",
                "form",
                "header",
                "noscript",
                "svg",
                "canvas",
                "footer",
                "script",
                "style",
            ]

            for tag in soup(tags_to_exclude):
                tag.decompose()

            ids_to_exclude = ["sidebar", "main-navigation", "menu-main-menu"]
            for id in ids_to_exclude:
                tags = soup.find_all(id=id)
                for tag in tags:
                    tag.decompose()

            classes_to_exclude = [
                "elementor-location-header",
                "navbar-header",
                "nav",
                "header-sidebar-wrapper",
                "blog-sidebar-wrapper",
                "related-posts",
            ]

            for class_name in classes_to_exclude:
                tags = soup.find_all(class_=class_name)
                for tag in tags:
                    tag.decompose()

            content = soup.get_text()
            content = clean_string(content)

            cleaned_size = len(content)
            if original_size != 0:
                logging.info(
                    f"[{id}] Cleaned page size: {cleaned_size} characters, down from {original_size} (shrunk: {original_size-cleaned_size} chars, {round((1-(cleaned_size/original_size)) * 100, 2)}%)"  # noqa:E501
                )

            result = f"""
            email from '{meta_data.get('from')}' to '{meta_data.get('to')}'
            subject: {meta_data.get('subject')}
            date: {meta_data.get('date')}
            preview: {snippet}
            content: {content}
            """
            data_content = dedent(result)
            data.append({"content": data_content, "meta_data": meta_data})
            data_contents.append(data_content)
        doc_id = hashlib.sha256((query + ", ".join(data_contents)).encode()).hexdigest()
        response_data = {"doc_id": doc_id, "data": data}
        return response_data

    @staticmethod
    def _generate_doc_id(query: str, data: List[Dict]) -> str:
        content_strings = [email["content"] for email in data]
        return hashlib.sha256((query + ", ".join(content_strings)).encode()).hexdigest()

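For context, a minimal sketch of driving the refactored loader end to end, assuming the module lives at `embedchain.loaders.gmail` and the same OAuth files are in place; the query is hypothetical:

# Sketch: the loader wraps GmailReader and returns a {"doc_id", "data"} payload.
from embedchain.loaders.gmail import GmailLoader  # assumed module path

loader = GmailLoader()
result = loader.load_data("from:newsletter@example.com")  # hypothetical query

print(result["doc_id"])  # sha256 over the query plus every processed email
for item in result["data"]:
    print(item["meta_data"]["subject"])  # meta_data is the dict built by _parse_email
    print(item["content"][:200])         # dedented text built by _process_email
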
@@ -2,29 +2,43 @@ import hashlib
import json
import os
import re
from typing import Dict, List, Union

import requests

from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils import clean_string, is_valid_json_string


class JSONReader:
    def __init__(self) -> None:
        """Initialize the JSONReader."""
        pass

    def load_data(self, json_data: Union[Dict, str]) -> List[str]:
        """Load data from a JSON structure.

        Args:
            json_data (Union[Dict, str]): The JSON data to load.

        Returns:
            List[str]: A list of strings representing the leaf nodes of the JSON.
        """
        if isinstance(json_data, str):
            json_data = json.loads(json_data)
        else:
            json_data = json_data

        json_output = json.dumps(json_data, indent=0)
        lines = json_output.split("\n")
        useful_lines = [line for line in lines if not re.match(r"^[{}\[\],]*$", line)]
        return ["\n".join(useful_lines)]


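A small sketch of what `JSONReader.load_data` produces: the input is re-serialized with `indent=0`, lines made up only of structural characters are dropped, and the surviving lines come back joined as a single string:

# Sketch: flattening a tiny document with the JSONReader defined above.
reader = JSONReader()
docs = reader.load_data({"team": "embedchain", "tags": ["loader", "json"]})
print(docs)
# Roughly: ['"team": "embedchain",\n"tags": [\n"loader",\n"json"']
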
VALID_URL_PATTERN = r"^https:\/\/[0-9A-Za-z.]+.[0-9A-Za-z.]+.[a-z]+\/.*\.json$"


class JSONLoader(BaseLoader):
    @staticmethod
    def _get_llama_hub_loader():
        try:
            from llama_hub.jsondata.base import JSONDataReader as LLHUBJSONLoader
        except ImportError as e:
            raise Exception(
                f"Failed to install required packages: {e}, "
                "install them using `pip install --upgrade 'embedchain[json]'`"
            )

        return LLHUBJSONLoader()

    @staticmethod
    def _check_content(content):
        if not isinstance(content, str):
@@ -40,14 +54,13 @@ class JSONLoader(BaseLoader):
        """Load a json file. Each data point is a key value pair."""

        JSONLoader._check_content(content)
        loader = JSONLoader._get_llama_hub_loader()
        loader = JSONReader()

        data = []
        data_content = []

        content_url_str = content

        # Load json data from various sources.
        if os.path.isfile(content):
            with open(content, "r", encoding="utf-8") as json_file:
                json_data = json.load(json_file)
@@ -68,7 +81,8 @@ class JSONLoader(BaseLoader):

        docs = loader.load_data(json_data)
        for doc in docs:
            doc_content = clean_string(doc.text)
            text = doc if isinstance(doc, str) else doc["text"]
            doc_content = clean_string(text)
            data.append({"content": doc_content, "meta_data": {"url": content_url_str}})
            data_content.append(doc_content)

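A hedged sketch of how the refactored JSON loader might be invoked, assuming `load_data` receives the path or URL bound to `content` in the hunk above and returns the same `{"doc_id", "data"}` shape as the other loaders (its return statement falls outside the lines shown); the file path is hypothetical:

# Sketch only; fixtures/sample.json is a hypothetical local file.
loader = JSONLoader()
result = loader.load_data("fixtures/sample.json")
for chunk in result["data"]:
    print(chunk["meta_data"]["url"], chunk["content"][:80])
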
@@ -1,39 +1,111 @@
import hashlib
import logging
import os
from typing import Any, Dict, List, Optional

try:
    from llama_hub.notion.base import NotionPageReader
except ImportError:
    raise ImportError(
        "Notion requires extra dependencies. Install with `pip install --upgrade embedchain[community]`"
    ) from None

import requests

from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils import clean_string


class NotionDocument:
    """
    A simple Document class to hold the text and additional information of a page.
    """

    def __init__(self, text: str, extra_info: Dict[str, Any]):
        self.text = text
        self.extra_info = extra_info


class NotionPageLoader:
    """
    Notion Page Loader.
    Reads a set of Notion pages.
    """

    BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"

    def __init__(self, integration_token: Optional[str] = None) -> None:
        """Initialize with Notion integration token."""
        if integration_token is None:
            integration_token = os.getenv("NOTION_INTEGRATION_TOKEN")
            if integration_token is None:
                raise ValueError(
                    "Must specify `integration_token` or set environment " "variable `NOTION_INTEGRATION_TOKEN`."
                )
        self.token = integration_token
        self.headers = {
            "Authorization": "Bearer " + self.token,
            "Content-Type": "application/json",
            "Notion-Version": "2022-06-28",
        }

    def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
        """Read a block from Notion."""
        done = False
        result_lines_arr = []
        cur_block_id = block_id
        while not done:
            block_url = self.BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
            res = requests.get(block_url, headers=self.headers)
            data = res.json()

            for result in data["results"]:
                result_type = result["type"]
                result_obj = result[result_type]

                cur_result_text_arr = []
                if "rich_text" in result_obj:
                    for rich_text in result_obj["rich_text"]:
                        if "text" in rich_text:
                            text = rich_text["text"]["content"]
                            prefix = "\t" * num_tabs
                            cur_result_text_arr.append(prefix + text)

                result_block_id = result["id"]
                has_children = result["has_children"]
                if has_children:
                    children_text = self._read_block(result_block_id, num_tabs=num_tabs + 1)
                    cur_result_text_arr.append(children_text)

                cur_result_text = "\n".join(cur_result_text_arr)
                result_lines_arr.append(cur_result_text)

            if data["next_cursor"] is None:
                done = True
            else:
                cur_block_id = data["next_cursor"]

        result_lines = "\n".join(result_lines_arr)
        return result_lines

    def load_data(self, page_ids: List[str]) -> List[NotionDocument]:
        """Load data from the given list of page IDs."""
        docs = []
        for page_id in page_ids:
            page_text = self._read_block(page_id)
            docs.append(NotionDocument(text=page_text, extra_info={"page_id": page_id}))
        return docs


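A minimal sketch of using the new `NotionPageLoader` directly, assuming `NOTION_INTEGRATION_TOKEN` is set and the integration has been shared with the target page; the page id is hypothetical:

# Sketch: fetch one page as plain text via the Notion blocks API.
import os

loader = NotionPageLoader(integration_token=os.getenv("NOTION_INTEGRATION_TOKEN"))
docs = loader.load_data(page_ids=["12345678-1234-1234-1234-123456789abc"])  # hypothetical id
for doc in docs:
    print(doc.extra_info["page_id"])
    print(doc.text[:200])  # text assembled recursively by _read_block
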
@register_deserializable
class NotionLoader(BaseLoader):
    def load_data(self, source):
        """Load data from a PDF file."""
        """Load data from a Notion URL."""

        # Reformat Id to match notion expectation
        id = source[-32:]
        formatted_id = f"{id[:8]}-{id[8:12]}-{id[12:16]}-{id[16:20]}-{id[20:]}"
        logging.debug(f"Extracted notion page id as: {formatted_id}")

        # Get page through the notion api
        integration_token = os.getenv("NOTION_INTEGRATION_TOKEN")
        reader = NotionPageReader(integration_token=integration_token)
        reader = NotionPageLoader(integration_token=integration_token)
        documents = reader.load_data(page_ids=[formatted_id])

        # Extract text
        raw_text = documents[0].text

        # Clean text
        text = clean_string(raw_text)
        doc_id = hashlib.sha256((text + source).encode()).hexdigest()
        return {