Rename embedchain to mem0 and open sourcing code for long term memory (#1474)
Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
embedchain/embedchain/loaders/__init__.py (new file, 0 lines)
embedchain/embedchain/loaders/audio.py (new file, 53 lines)
@@ -0,0 +1,53 @@
import hashlib
import os

import validators

from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader

try:
    from deepgram import DeepgramClient, PrerecordedOptions
except ImportError:
    raise ImportError(
        "Audio file requires extra dependencies. Install with `pip install deepgram-sdk==3.2.7`"
    ) from None


@register_deserializable
class AudioLoader(BaseLoader):
    def __init__(self):
        if not os.environ.get("DEEPGRAM_API_KEY"):
            raise ValueError("DEEPGRAM_API_KEY is not set")

        DG_KEY = os.environ.get("DEEPGRAM_API_KEY")
        self.client = DeepgramClient(DG_KEY)

    def load_data(self, url: str):
        """Load data from an audio file or URL."""

        options = PrerecordedOptions(
            model="nova-2",
            smart_format=True,
        )
        if validators.url(url):
            source = {"url": url}
            response = self.client.listen.prerecorded.v("1").transcribe_url(source, options)
        else:
            with open(url, "rb") as audio:
                source = {"buffer": audio}
                response = self.client.listen.prerecorded.v("1").transcribe_file(source, options)
        content = response["results"]["channels"][0]["alternatives"][0]["transcript"]

        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
        metadata = {"url": url}

        return {
            "doc_id": doc_id,
            "data": [
                {
                    "content": content,
                    "meta_data": metadata,
                }
            ],
        }
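A minimal usage sketch for this loader (illustrative, not part of the diff; it assumes `DEEPGRAM_API_KEY` is exported, the `deepgram-sdk` and `validators` packages are installed, and `sample.mp3` is a hypothetical local file):

    from embedchain.loaders.audio import AudioLoader

    loader = AudioLoader()
    result = loader.load_data("sample.mp3")  # a public audio URL also works
    print(result["doc_id"], result["data"][0]["content"][:80])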
embedchain/embedchain/loaders/base_loader.py (new file, 12 lines)
@@ -0,0 +1,12 @@
from embedchain.helpers.json_serializable import JSONSerializable


class BaseLoader(JSONSerializable):
    def __init__(self):
        pass

    def load_data(self, url):
        """
        Implemented by child classes
        """
        pass
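Every concrete loader in this commit subclasses BaseLoader, overrides `load_data`, and returns a dict with a `doc_id` plus a list of `{"content", "meta_data"}` entries. A minimal sketch of a custom loader following that contract (illustrative; `GreetingLoader` is a hypothetical name, not part of this diff):

    import hashlib

    from embedchain.loaders.base_loader import BaseLoader


    class GreetingLoader(BaseLoader):
        def load_data(self, url):
            # Produce a single document; real loaders would read from `url` instead.
            content = f"Hello from {url}"
            doc_id = hashlib.sha256((content + url).encode()).hexdigest()
            return {"doc_id": doc_id, "data": [{"content": content, "meta_data": {"url": url}}]}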
embedchain/embedchain/loaders/beehiiv.py (new file, 107 lines)
@@ -0,0 +1,107 @@
import hashlib
import logging
import time
from xml.etree import ElementTree

import requests

from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import is_readable

logger = logging.getLogger(__name__)


@register_deserializable
class BeehiivLoader(BaseLoader):
    """
    This loader is used to load data from Beehiiv URLs.
    """

    def load_data(self, url: str):
        try:
            from bs4 import BeautifulSoup
            from bs4.builder import ParserRejectedMarkup
        except ImportError:
            raise ImportError(
                "Beehiiv requires extra dependencies. Install with `pip install beautifulsoup4==4.12.3`"
            ) from None

        if not url.endswith("sitemap.xml"):
            url = url + "/sitemap.xml"

        output = []
        # we need to set this as a header to avoid 403
        headers = {
            "User-Agent": (
                "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) "
                "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 "
                "Safari/537.36"
            ),
        }
        response = requests.get(url, headers=headers)
        try:
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            raise ValueError(
                f"""
                Failed to load {url}: {e}. Please use the root Beehiiv URL. For example, https://example.beehiiv.com
                """
            )

        try:
            ElementTree.fromstring(response.content)
        except ElementTree.ParseError:
            raise ValueError(
                f"""
                Failed to parse {url}. Please use the root Beehiiv URL. For example, https://example.beehiiv.com
                """
            )
        soup = BeautifulSoup(response.text, "xml")
        links = [link.text for link in soup.find_all("loc") if link.parent.name == "url" and "/p/" in link.text]
        if len(links) == 0:
            links = [link.text for link in soup.find_all("loc") if "/p/" in link.text]

        doc_id = hashlib.sha256((" ".join(links) + url).encode()).hexdigest()

        def serialize_response(soup: BeautifulSoup):
            data = {}

            h1_el = soup.find("h1")
            if h1_el is not None:
                data["title"] = h1_el.text

            description_el = soup.find("meta", {"name": "description"})
            if description_el is not None:
                data["description"] = description_el["content"]

            content_el = soup.find("div", {"id": "content-blocks"})
            if content_el is not None:
                data["content"] = content_el.text

            return data

        def load_link(link: str):
            try:
                beehiiv_data = requests.get(link, headers=headers)
                beehiiv_data.raise_for_status()

                soup = BeautifulSoup(beehiiv_data.text, "html.parser")
                data = serialize_response(soup)
                data = str(data)
                if is_readable(data):
                    return data
                else:
                    logger.warning(f"Page is not readable (too many invalid characters): {link}")
            except ParserRejectedMarkup as e:
                logger.error(f"Failed to parse {link}: {e}")
            return None

        for link in links:
            data = load_link(link)
            if data:
                output.append({"content": data, "meta_data": {"url": link}})
            # TODO: allow users to configure this
            time.sleep(1.0)  # added to avoid rate limiting

        return {"doc_id": doc_id, "data": output}
embedchain/embedchain/loaders/csv.py (new file, 49 lines)
@@ -0,0 +1,49 @@
import csv
import hashlib
from io import StringIO
from urllib.parse import urlparse

import requests

from embedchain.loaders.base_loader import BaseLoader


class CsvLoader(BaseLoader):
    @staticmethod
    def _detect_delimiter(first_line):
        delimiters = [",", "\t", ";", "|"]
        counts = {delimiter: first_line.count(delimiter) for delimiter in delimiters}
        return max(counts, key=counts.get)

    @staticmethod
    def _get_file_content(content):
        url = urlparse(content)
        if all([url.scheme, url.netloc]) and url.scheme not in ["file", "http", "https"]:
            raise ValueError("Not a valid URL.")

        if url.scheme in ["http", "https"]:
            response = requests.get(content)
            response.raise_for_status()
            return StringIO(response.text)
        elif url.scheme == "file":
            path = url.path
            return open(path, newline="", encoding="utf-8")  # Open the file using the path from the URI
        else:
            return open(content, newline="", encoding="utf-8")  # Treat content as a regular file path

    @staticmethod
    def load_data(content):
        """Load a CSV file with headers. Each row is a document."""
        result = []
        lines = []
        with CsvLoader._get_file_content(content) as file:
            first_line = file.readline()
            delimiter = CsvLoader._detect_delimiter(first_line)
            file.seek(0)  # Reset the file pointer to the start
            reader = csv.DictReader(file, delimiter=delimiter)
            for i, row in enumerate(reader):
                line = ", ".join([f"{field}: {value}" for field, value in row.items()])
                lines.append(line)
                result.append({"content": line, "meta_data": {"url": content, "row": i + 1}})
        doc_id = hashlib.sha256((content + " ".join(lines)).encode()).hexdigest()
        return {"doc_id": doc_id, "data": result}
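A usage sketch (illustrative; `people.csv` is a hypothetical path — per `_get_file_content`, `http(s)` and `file://` URLs are accepted as well):

    from embedchain.loaders.csv import CsvLoader

    result = CsvLoader.load_data("people.csv")
    for doc in result["data"]:
        print(doc["meta_data"]["row"], doc["content"])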
embedchain/embedchain/loaders/directory_loader.py (new file, 63 lines)
@@ -0,0 +1,63 @@
import hashlib
import logging
from pathlib import Path
from typing import Any, Optional

from embedchain.config import AddConfig
from embedchain.data_formatter.data_formatter import DataFormatter
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.loaders.text_file import TextFileLoader
from embedchain.utils.misc import detect_datatype

logger = logging.getLogger(__name__)


@register_deserializable
class DirectoryLoader(BaseLoader):
    """Load data from a directory."""

    def __init__(self, config: Optional[dict[str, Any]] = None):
        super().__init__()
        config = config or {}
        self.recursive = config.get("recursive", True)
        self.extensions = config.get("extensions", None)
        self.errors = []

    def load_data(self, path: str):
        directory_path = Path(path)
        if not directory_path.is_dir():
            raise ValueError(f"Invalid path: {path}")

        logger.info(f"Loading data from directory: {path}")
        data_list = self._process_directory(directory_path)
        doc_id = hashlib.sha256((str(data_list) + str(directory_path)).encode()).hexdigest()

        for error in self.errors:
            logger.warning(error)

        return {"doc_id": doc_id, "data": data_list}

    def _process_directory(self, directory_path: Path):
        data_list = []
        for file_path in directory_path.rglob("*") if self.recursive else directory_path.glob("*"):
            # don't include dotfiles
            if file_path.name.startswith("."):
                continue
            if file_path.is_file() and (not self.extensions or any(file_path.suffix == ext for ext in self.extensions)):
                loader = self._predict_loader(file_path)
                data_list.extend(loader.load_data(str(file_path))["data"])
            elif file_path.is_dir():
                logger.info(f"Loading data from directory: {file_path}")
        return data_list

    def _predict_loader(self, file_path: Path) -> BaseLoader:
        try:
            data_type = detect_datatype(str(file_path))
            config = AddConfig()
            return DataFormatter(data_type=data_type, config=config)._get_loader(
                data_type=data_type, config=config.loader, loader=None
            )
        except Exception as e:
            self.errors.append(f"Error processing {file_path}: {e}")
            return TextFileLoader()
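A usage sketch (illustrative; the `recursive` and `extensions` keys mirror the config read in `__init__`, and `./docs` is a hypothetical directory):

    from embedchain.loaders.directory_loader import DirectoryLoader

    loader = DirectoryLoader(config={"recursive": True, "extensions": [".md", ".txt"]})
    result = loader.load_data("./docs")
    print(len(result["data"]), "documents loaded")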
embedchain/embedchain/loaders/discord.py (new file, 152 lines)
@@ -0,0 +1,152 @@
import hashlib
import logging
import os

from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader

logger = logging.getLogger(__name__)


@register_deserializable
class DiscordLoader(BaseLoader):
    """
    Load data from a Discord Channel ID.
    """

    def __init__(self):
        if not os.environ.get("DISCORD_TOKEN"):
            raise ValueError("DISCORD_TOKEN is not set")

        self.token = os.environ.get("DISCORD_TOKEN")

    @staticmethod
    def _format_message(message):
        return {
            "message_id": message.id,
            "content": message.content,
            "author": {
                "id": message.author.id,
                "name": message.author.name,
                "discriminator": message.author.discriminator,
            },
            "created_at": message.created_at.isoformat(),
            "attachments": [
                {
                    "id": attachment.id,
                    "filename": attachment.filename,
                    "size": attachment.size,
                    "url": attachment.url,
                    "proxy_url": attachment.proxy_url,
                    "height": attachment.height,
                    "width": attachment.width,
                }
                for attachment in message.attachments
            ],
            "embeds": [
                {
                    "title": embed.title,
                    "type": embed.type,
                    "description": embed.description,
                    "url": embed.url,
                    "timestamp": embed.timestamp.isoformat(),
                    "color": embed.color,
                    "footer": {
                        "text": embed.footer.text,
                        "icon_url": embed.footer.icon_url,
                        "proxy_icon_url": embed.footer.proxy_icon_url,
                    },
                    "image": {
                        "url": embed.image.url,
                        "proxy_url": embed.image.proxy_url,
                        "height": embed.image.height,
                        "width": embed.image.width,
                    },
                    "thumbnail": {
                        "url": embed.thumbnail.url,
                        "proxy_url": embed.thumbnail.proxy_url,
                        "height": embed.thumbnail.height,
                        "width": embed.thumbnail.width,
                    },
                    "video": {
                        "url": embed.video.url,
                        "height": embed.video.height,
                        "width": embed.video.width,
                    },
                    "provider": {
                        "name": embed.provider.name,
                        "url": embed.provider.url,
                    },
                    "author": {
                        "name": embed.author.name,
                        "url": embed.author.url,
                        "icon_url": embed.author.icon_url,
                        "proxy_icon_url": embed.author.proxy_icon_url,
                    },
                    "fields": [
                        {
                            "name": field.name,
                            "value": field.value,
                            "inline": field.inline,
                        }
                        for field in embed.fields
                    ],
                }
                for embed in message.embeds
            ],
        }

    def load_data(self, channel_id: str):
        """Load data from a Discord Channel ID."""
        import discord

        messages = []

        class DiscordClient(discord.Client):
            async def on_ready(self) -> None:
                logger.info("Logged on as {0}!".format(self.user))
                try:
                    channel = self.get_channel(int(channel_id))
                    if not isinstance(channel, discord.TextChannel):
                        raise ValueError(
                            f"Channel {channel_id} is not a text channel. " "Only text channels are supported for now."
                        )
                    threads = {}

                    for thread in channel.threads:
                        threads[thread.id] = thread

                    async for message in channel.history(limit=None):
                        messages.append(DiscordLoader._format_message(message))
                        if message.id in threads:
                            async for thread_message in threads[message.id].history(limit=None):
                                messages.append(DiscordLoader._format_message(thread_message))

                except Exception as e:
                    logger.error(e)
                    await self.close()
                finally:
                    await self.close()

        intents = discord.Intents.default()
        intents.message_content = True
        client = DiscordClient(intents=intents)
        client.run(self.token)

        metadata = {
            "url": channel_id,
        }

        messages = str(messages)

        doc_id = hashlib.sha256((messages + channel_id).encode()).hexdigest()

        return {
            "doc_id": doc_id,
            "data": [
                {
                    "content": messages,
                    "meta_data": metadata,
                }
            ],
        }
embedchain/embedchain/loaders/discourse.py (new file, 79 lines)
@@ -0,0 +1,79 @@
import hashlib
import logging
import time
from typing import Any, Optional

import requests

from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string

logger = logging.getLogger(__name__)


class DiscourseLoader(BaseLoader):
    def __init__(self, config: Optional[dict[str, Any]] = None):
        super().__init__()
        if not config:
            raise ValueError(
                "DiscourseLoader requires a config. Check the documentation for the correct format - `https://docs.embedchain.ai/components/data-sources/discourse`"  # noqa: E501
            )

        self.domain = config.get("domain")
        if not self.domain:
            raise ValueError(
                "DiscourseLoader requires a domain. Check the documentation for the correct format - `https://docs.embedchain.ai/components/data-sources/discourse`"  # noqa: E501
            )

    def _check_query(self, query):
        if not query or not isinstance(query, str):
            raise ValueError(
                "DiscourseLoader requires a query. Check the documentation for the correct format - `https://docs.embedchain.ai/components/data-sources/discourse`"  # noqa: E501
            )

    def _load_post(self, post_id):
        post_url = f"{self.domain}posts/{post_id}.json"
        response = requests.get(post_url)
        try:
            response.raise_for_status()
        except Exception as e:
            logger.error(f"Failed to load post {post_id}: {e}")
            return
        response_data = response.json()
        post_contents = clean_string(response_data.get("raw"))
        metadata = {
            "url": post_url,
            "created_at": response_data.get("created_at", ""),
            "username": response_data.get("username", ""),
            "topic_slug": response_data.get("topic_slug", ""),
            "score": response_data.get("score", ""),
        }
        data = {
            "content": post_contents,
            "meta_data": metadata,
        }
        return data

    def load_data(self, query):
        self._check_query(query)
        data = []
        data_contents = []
        logger.info(f"Searching data on discourse url: {self.domain}, for query: {query}")
        search_url = f"{self.domain}search.json?q={query}"
        response = requests.get(search_url)
        try:
            response.raise_for_status()
        except Exception as e:
            raise ValueError(f"Failed to search query {query}: {e}")
        response_data = response.json()
        post_ids = response_data.get("grouped_search_result").get("post_ids")
        for id in post_ids:
            post_data = self._load_post(id)
            if post_data:
                data.append(post_data)
                data_contents.append(post_data.get("content"))
            # Sleep for 0.4 sec, to avoid rate limiting. Check `https://meta.discourse.org/t/api-rate-limits/208405/6`
            time.sleep(0.4)
        doc_id = hashlib.sha256((query + ", ".join(data_contents)).encode()).hexdigest()
        response_data = {"doc_id": doc_id, "data": data}
        return response_data
embedchain/embedchain/loaders/docs_site_loader.py (new file, 119 lines)
@@ -0,0 +1,119 @@
import hashlib
import logging
from urllib.parse import urljoin, urlparse

import requests

try:
    from bs4 import BeautifulSoup
except ImportError:
    raise ImportError(
        "DocsSite requires extra dependencies. Install with `pip install beautifulsoup4==4.12.3`"
    ) from None


from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader

logger = logging.getLogger(__name__)


@register_deserializable
class DocsSiteLoader(BaseLoader):
    def __init__(self):
        self.visited_links = set()

    def _get_child_links_recursive(self, url):
        if url in self.visited_links:
            return

        parsed_url = urlparse(url)
        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
        current_path = parsed_url.path

        response = requests.get(url)
        if response.status_code != 200:
            logger.info(f"Failed to fetch the website: {response.status_code}")
            return

        soup = BeautifulSoup(response.text, "html.parser")
        all_links = (link.get("href") for link in soup.find_all("a", href=True))

        child_links = (link for link in all_links if link.startswith(current_path) and link != current_path)

        absolute_paths = set(urljoin(base_url, link) for link in child_links)

        self.visited_links.update(absolute_paths)

        [self._get_child_links_recursive(link) for link in absolute_paths if link not in self.visited_links]

    def _get_all_urls(self, url):
        self.visited_links = set()
        self._get_child_links_recursive(url)
        urls = [link for link in self.visited_links if urlparse(link).netloc == urlparse(url).netloc]
        return urls

    @staticmethod
    def _load_data_from_url(url: str) -> list:
        response = requests.get(url)
        if response.status_code != 200:
            logger.info(f"Failed to fetch the website: {response.status_code}")
            return []

        soup = BeautifulSoup(response.content, "html.parser")
        selectors = [
            "article.bd-article",
            'article[role="main"]',
            "div.md-content",
            'div[role="main"]',
            "div.container",
            "div.section",
            "article",
            "main",
        ]

        output = []
        for selector in selectors:
            element = soup.select_one(selector)
            if element:
                content = element.prettify()
                break
        else:
            content = soup.get_text()

        soup = BeautifulSoup(content, "html.parser")
        ignored_tags = [
            "nav",
            "aside",
            "form",
            "header",
            "noscript",
            "svg",
            "canvas",
            "footer",
            "script",
            "style",
        ]
        for tag in soup(ignored_tags):
            tag.decompose()

        content = " ".join(soup.stripped_strings)
        output.append(
            {
                "content": content,
                "meta_data": {"url": url},
            }
        )

        return output

    def load_data(self, url):
        all_urls = self._get_all_urls(url)
        output = []
        for u in all_urls:
            output.extend(self._load_data_from_url(u))
        doc_id = hashlib.sha256((" ".join(all_urls) + url).encode()).hexdigest()
        return {
            "doc_id": doc_id,
            "data": output,
        }
embedchain/embedchain/loaders/docx_file.py (new file, 26 lines)
@@ -0,0 +1,26 @@
import hashlib

try:
    from langchain_community.document_loaders import Docx2txtLoader
except ImportError:
    raise ImportError("Docx file requires extra dependencies. Install with `pip install docx2txt==0.8`") from None
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader


@register_deserializable
class DocxFileLoader(BaseLoader):
    def load_data(self, url):
        """Load data from a .docx file."""
        loader = Docx2txtLoader(url)
        output = []
        data = loader.load()
        content = data[0].page_content
        metadata = data[0].metadata
        metadata["url"] = "local"
        output.append({"content": content, "meta_data": metadata})
        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
        return {
            "doc_id": doc_id,
            "data": output,
        }
embedchain/embedchain/loaders/dropbox.py (new file, 79 lines)
@@ -0,0 +1,79 @@
import hashlib
import os

from dropbox.files import FileMetadata

from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.loaders.directory_loader import DirectoryLoader


@register_deserializable
class DropboxLoader(BaseLoader):
    def __init__(self):
        access_token = os.environ.get("DROPBOX_ACCESS_TOKEN")
        if not access_token:
            raise ValueError("Please set the `DROPBOX_ACCESS_TOKEN` environment variable.")
        try:
            from dropbox import Dropbox, exceptions
        except ImportError:
            raise ImportError("Dropbox requires extra dependencies. Install with `pip install dropbox==11.36.2`")

        try:
            dbx = Dropbox(access_token)
            dbx.users_get_current_account()
            self.dbx = dbx
        except exceptions.AuthError as ex:
            raise ValueError("Invalid Dropbox access token. Please verify your token and try again.") from ex

    def _download_folder(self, path: str, local_root: str) -> list[FileMetadata]:
        """Download a folder from Dropbox and save it preserving the directory structure."""
        entries = self.dbx.files_list_folder(path).entries
        for entry in entries:
            local_path = os.path.join(local_root, entry.name)
            if isinstance(entry, FileMetadata):
                self.dbx.files_download_to_file(local_path, f"{path}/{entry.name}")
            else:
                os.makedirs(local_path, exist_ok=True)
                self._download_folder(f"{path}/{entry.name}", local_path)
        return entries

    def _generate_dir_id_from_all_paths(self, path: str) -> str:
        """Generate a unique ID for a directory based on all of its paths."""
        entries = self.dbx.files_list_folder(path).entries
        paths = [f"{path}/{entry.name}" for entry in entries]
        return hashlib.sha256("".join(paths).encode()).hexdigest()

    def load_data(self, path: str):
        """Load data from a Dropbox URL, preserving the folder structure."""
        root_dir = f"dropbox_{self._generate_dir_id_from_all_paths(path)}"
        os.makedirs(root_dir, exist_ok=True)

        for entry in self.dbx.files_list_folder(path).entries:
            local_path = os.path.join(root_dir, entry.name)
            if isinstance(entry, FileMetadata):
                self.dbx.files_download_to_file(local_path, f"{path}/{entry.name}")
            else:
                os.makedirs(local_path, exist_ok=True)
                self._download_folder(f"{path}/{entry.name}", local_path)

        dir_loader = DirectoryLoader()
        data = dir_loader.load_data(root_dir)["data"]

        # Clean up
        self._clean_directory(root_dir)

        return {
            "doc_id": hashlib.sha256(path.encode()).hexdigest(),
            "data": data,
        }

    def _clean_directory(self, dir_path):
        """Recursively delete a directory and its contents."""
        for item in os.listdir(dir_path):
            item_path = os.path.join(dir_path, item)
            if os.path.isdir(item_path):
                self._clean_directory(item_path)
            else:
                os.remove(item_path)
        os.rmdir(dir_path)
embedchain/embedchain/loaders/excel_file.py (new file, 41 lines)
@@ -0,0 +1,41 @@
import hashlib
import importlib.util

try:
    import unstructured  # noqa: F401
    from langchain_community.document_loaders import UnstructuredExcelLoader
except ImportError:
    raise ImportError(
        'Excel file requires extra dependencies. Install with `pip install "unstructured[local-inference, all-docs]"`'
    ) from None

if importlib.util.find_spec("openpyxl") is None and importlib.util.find_spec("xlrd") is None:
    raise ImportError("Excel file requires extra dependencies. Install with `pip install openpyxl xlrd`") from None

from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string


@register_deserializable
class ExcelFileLoader(BaseLoader):
    def load_data(self, excel_url):
        """Load data from an Excel file."""
        loader = UnstructuredExcelLoader(excel_url)
        pages = loader.load_and_split()

        data = []
        for page in pages:
            content = page.page_content
            content = clean_string(content)

            metadata = page.metadata
            metadata["url"] = excel_url

            data.append({"content": content, "meta_data": metadata})

        doc_id = hashlib.sha256((content + excel_url).encode()).hexdigest()
        return {
            "doc_id": doc_id,
            "data": data,
        }
embedchain/embedchain/loaders/github.py (new file, 312 lines)
@@ -0,0 +1,312 @@
import concurrent.futures
import hashlib
import logging
import re
import shlex
from typing import Any, Optional

from tqdm import tqdm

from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string

GITHUB_URL = "https://github.com"
GITHUB_API_URL = "https://api.github.com"

VALID_SEARCH_TYPES = set(["code", "repo", "pr", "issue", "discussion", "branch", "file"])


class GithubLoader(BaseLoader):
    """Load data from GitHub search query."""

    def __init__(self, config: Optional[dict[str, Any]] = None):
        super().__init__()
        if not config:
            raise ValueError(
                "GithubLoader requires a personal access token to use github api. Check - `https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens#creating-a-personal-access-token-classic`"  # noqa: E501
            )

        try:
            from github import Github
        except ImportError as e:
            raise ValueError(
                "GithubLoader requires extra dependencies. \
                Install with `pip install gitpython==3.1.38 PyGithub==1.59.1`"
            ) from e

        self.config = config
        token = config.get("token")
        if not token:
            raise ValueError(
                "GithubLoader requires a personal access token to use github api. Check - `https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens#creating-a-personal-access-token-classic`"  # noqa: E501
            )

        try:
            self.client = Github(token)
        except Exception as e:
            logging.error(f"GithubLoader failed to initialize client: {e}")
            self.client = None

    def _github_search_code(self, query: str):
        """Search GitHub code."""
        data = []
        results = self.client.search_code(query)
        for result in tqdm(results, total=results.totalCount, desc="Loading code files from github"):
            url = result.html_url
            logging.info(f"Added data from url: {url}")
            content = result.decoded_content.decode("utf-8")
            metadata = {
                "url": url,
            }
            data.append(
                {
                    "content": clean_string(content),
                    "meta_data": metadata,
                }
            )
        return data

    def _get_github_repo_data(self, repo_name: str, branch_name: str = None, file_path: str = None) -> list[dict]:
        """Get file contents from Repo"""
        data = []

        repo = self.client.get_repo(repo_name)
        repo_contents = repo.get_contents("")

        if branch_name:
            repo_contents = repo.get_contents("", ref=branch_name)
        if file_path:
            repo_contents = [repo.get_contents(file_path)]

        with tqdm(desc="Loading files:", unit="item") as progress_bar:
            while repo_contents:
                file_content = repo_contents.pop(0)
                if file_content.type == "dir":
                    try:
                        repo_contents.extend(repo.get_contents(file_content.path))
                    except Exception:
                        logging.warning(f"Failed to read directory: {file_content.path}")
                        progress_bar.update(1)
                        continue
                else:
                    try:
                        file_text = file_content.decoded_content.decode()
                    except Exception:
                        logging.warning(f"Failed to read file: {file_content.path}")
                        progress_bar.update(1)
                        continue

                    file_path = file_content.path
                    data.append(
                        {
                            "content": clean_string(file_text),
                            "meta_data": {
                                "path": file_path,
                            },
                        }
                    )

                progress_bar.update(1)

        return data

    def _github_search_repo(self, query: str) -> list[dict]:
        """Search GitHub repo."""

        logging.info(f"Searching github repos with query: {query}")
        updated_query = query.split(":")[-1]
        data = self._get_github_repo_data(updated_query)
        return data

    def _github_search_issues_and_pr(self, query: str, type: str) -> list[dict]:
        """Search GitHub issues and PRs."""
        data = []

        query = f"{query} is:{type}"
        logging.info(f"Searching github for query: {query}")

        results = self.client.search_issues(query)

        logging.info(f"Total results: {results.totalCount}")
        for result in tqdm(results, total=results.totalCount, desc=f"Loading {type} from github"):
            url = result.html_url
            title = result.title
            body = result.body
            if not body:
                logging.warning(f"Skipping issue because empty content for: {url}")
                continue
            labels = " ".join([label.name for label in result.labels])
            issue_comments = result.get_comments()
            comments = []
            comments_created_at = []
            for comment in issue_comments:
                comments_created_at.append(str(comment.created_at))
                comments.append(f"{comment.user.name}:{comment.body}")
            content = "\n".join([title, labels, body, *comments])
            metadata = {
                "url": url,
                "created_at": str(result.created_at),
                "comments_created_at": " ".join(comments_created_at),
            }
            data.append(
                {
                    "content": clean_string(content),
                    "meta_data": metadata,
                }
            )
        return data

    # need to test more for discussion
    def _github_search_discussions(self, query: str):
        """Search GitHub discussions."""
        data = []

        query = f"{query} is:discussion"
        logging.info(f"Searching github repo for query: {query}")
        repos_results = self.client.search_repositories(query)
        logging.info(f"Total repos found: {repos_results.totalCount}")
        for repo_result in tqdm(repos_results, total=repos_results.totalCount, desc="Loading discussions from github"):
            teams = repo_result.get_teams()
            for team in teams:
                team_discussions = team.get_discussions()
                for discussion in team_discussions:
                    url = discussion.html_url
                    title = discussion.title
                    body = discussion.body
                    if not body:
                        logging.warning(f"Skipping discussion because empty content for: {url}")
                        continue
                    comments = []
                    comments_created_at = []
                    print("Discussion comments: ", discussion.comments_url)
                    content = "\n".join([title, body, *comments])
                    metadata = {
                        "url": url,
                        "created_at": str(discussion.created_at),
                        "comments_created_at": " ".join(comments_created_at),
                    }
                    data.append(
                        {
                            "content": clean_string(content),
                            "meta_data": metadata,
                        }
                    )
        return data

    def _get_github_repo_branch(self, query: str, type: str) -> list[dict]:
        """Get file contents for specific branch"""

        logging.info(f"Searching github repo for query: {query} is:{type}")
        pattern = r"repo:(\S+) name:(\S+)"
        match = re.search(pattern, query)

        if match:
            repo_name = match.group(1)
            branch_name = match.group(2)
        else:
            raise ValueError(f"Repository name and branch name not found in query: {query}")

        data = self._get_github_repo_data(repo_name=repo_name, branch_name=branch_name)
        return data

    def _get_github_repo_file(self, query: str, type: str) -> list[dict]:
        """Get specific file content"""

        logging.info(f"Searching github repo for query: {query} is:{type}")
        pattern = r"repo:(\S+) path:(\S+)"
        match = re.search(pattern, query)

        if match:
            repo_name = match.group(1)
            file_path = match.group(2)
        else:
            raise ValueError(f"Repository name and file path not found in query: {query}")

        data = self._get_github_repo_data(repo_name=repo_name, file_path=file_path)
        return data

    def _search_github_data(self, search_type: str, query: str):
        """Search github data."""
        if search_type == "code":
            data = self._github_search_code(query)
        elif search_type == "repo":
            data = self._github_search_repo(query)
        elif search_type == "issue":
            data = self._github_search_issues_and_pr(query, search_type)
        elif search_type == "pr":
            data = self._github_search_issues_and_pr(query, search_type)
        elif search_type == "branch":
            data = self._get_github_repo_branch(query, search_type)
        elif search_type == "file":
            data = self._get_github_repo_file(query, search_type)
        elif search_type == "discussion":
            raise ValueError("GithubLoader does not support searching discussions yet.")
        else:
            raise NotImplementedError(f"{search_type} not supported")

        return data

    @staticmethod
    def _get_valid_github_query(query: str):
        """Check if query is valid and return search types and valid GitHub query."""
        query_terms = shlex.split(query)
        # query must provide repo to load data from
        if len(query_terms) < 1 or "repo:" not in query:
            raise ValueError(
                "GithubLoader requires a search query with `repo:` term. Refer docs - `https://docs.embedchain.ai/data-sources/github`"  # noqa: E501
            )

        github_query = []
        types = set()
        type_pattern = r"type:([a-zA-Z,]+)"
        for term in query_terms:
            term_match = re.search(type_pattern, term)
            if term_match:
                search_types = term_match.group(1).split(",")
                types.update(search_types)
            else:
                github_query.append(term)

        # query must provide search type
        if len(types) == 0:
            raise ValueError(
                "GithubLoader requires a search query with `type:` term. Refer docs - `https://docs.embedchain.ai/data-sources/github`"  # noqa: E501
            )

        for search_type in types:
            if search_type not in VALID_SEARCH_TYPES:
                raise ValueError(
                    f"Invalid search type: {search_type}. Valid types are: {', '.join(VALID_SEARCH_TYPES)}"
                )

        query = " ".join(github_query)

        return types, query

    def load_data(self, search_query: str, max_results: int = 1000):
        """Load data from GitHub search query."""

        if not self.client:
            raise ValueError(
                "GithubLoader client is not initialized, data will not be loaded. Refer docs - `https://docs.embedchain.ai/data-sources/github`"  # noqa: E501
            )

        search_types, query = self._get_valid_github_query(search_query)
        logging.info(f"Searching github for query: {query}, with types: {', '.join(search_types)}")

        data = []

        with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
            futures_map = executor.map(self._search_github_data, search_types, [query] * len(search_types))
            for search_data in tqdm(futures_map, total=len(search_types), desc="Searching data from github"):
                data.extend(search_data)

        return {
            "doc_id": hashlib.sha256(query.encode()).hexdigest(),
            "data": data,
        }
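A usage sketch (illustrative; it assumes PyGithub is installed, the token is a placeholder with read access, and `owner/repo` is a hypothetical repository — the `repo:`/`type:` terms follow the validation in `_get_valid_github_query`):

    from embedchain.loaders.github import GithubLoader

    loader = GithubLoader(config={"token": "<github personal access token>"})
    result = loader.load_data("repo:owner/repo type:repo,issue")
    print(len(result["data"]), "documents loaded")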
embedchain/embedchain/loaders/gmail.py (new file, 144 lines)
@@ -0,0 +1,144 @@
import base64
import hashlib
import logging
import os
from email import message_from_bytes
from email.utils import parsedate_to_datetime
from textwrap import dedent
from typing import Optional

from bs4 import BeautifulSoup

try:
    from google.auth.transport.requests import Request
    from google.oauth2.credentials import Credentials
    from google_auth_oauthlib.flow import InstalledAppFlow
    from googleapiclient.discovery import build
except ImportError:
    raise ImportError(
        'Gmail requires extra dependencies. Install with `pip install --upgrade "embedchain[gmail]"`'
    ) from None

from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string

logger = logging.getLogger(__name__)


class GmailReader:
    SCOPES = ["https://www.googleapis.com/auth/gmail.readonly"]

    def __init__(self, query: str, service=None, results_per_page: int = 10):
        self.query = query
        self.service = service or self._initialize_service()
        self.results_per_page = results_per_page

    @staticmethod
    def _initialize_service():
        credentials = GmailReader._get_credentials()
        return build("gmail", "v1", credentials=credentials)

    @staticmethod
    def _get_credentials():
        if not os.path.exists("credentials.json"):
            raise FileNotFoundError("Missing 'credentials.json'. Download it from your Google Developer account.")

        creds = (
            Credentials.from_authorized_user_file("token.json", GmailReader.SCOPES)
            if os.path.exists("token.json")
            else None
        )

        if not creds or not creds.valid:
            if creds and creds.expired and creds.refresh_token:
                creds.refresh(Request())
            else:
                flow = InstalledAppFlow.from_client_secrets_file("credentials.json", GmailReader.SCOPES)
                creds = flow.run_local_server(port=8080)
            with open("token.json", "w") as token:
                token.write(creds.to_json())
        return creds

    def load_emails(self) -> list[dict]:
        response = self.service.users().messages().list(userId="me", q=self.query).execute()
        messages = response.get("messages", [])

        return [self._parse_email(self._get_email(message["id"])) for message in messages]

    def _get_email(self, message_id: str):
        raw_message = self.service.users().messages().get(userId="me", id=message_id, format="raw").execute()
        return base64.urlsafe_b64decode(raw_message["raw"])

    def _parse_email(self, raw_email) -> dict:
        mime_msg = message_from_bytes(raw_email)
        return {
            "subject": self._get_header(mime_msg, "Subject"),
            "from": self._get_header(mime_msg, "From"),
            "to": self._get_header(mime_msg, "To"),
            "date": self._format_date(mime_msg),
            "body": self._get_body(mime_msg),
        }

    @staticmethod
    def _get_header(mime_msg, header_name: str) -> str:
        return mime_msg.get(header_name, "")

    @staticmethod
    def _format_date(mime_msg) -> Optional[str]:
        date_header = GmailReader._get_header(mime_msg, "Date")
        return parsedate_to_datetime(date_header).isoformat() if date_header else None

    @staticmethod
    def _get_body(mime_msg) -> str:
        def decode_payload(part):
            charset = part.get_content_charset() or "utf-8"
            try:
                return part.get_payload(decode=True).decode(charset)
            except UnicodeDecodeError:
                return part.get_payload(decode=True).decode(charset, errors="replace")

        if mime_msg.is_multipart():
            for part in mime_msg.walk():
                ctype = part.get_content_type()
                cdispo = str(part.get("Content-Disposition"))

                if ctype == "text/plain" and "attachment" not in cdispo:
                    return decode_payload(part)
                elif ctype == "text/html":
                    return decode_payload(part)
        else:
            return decode_payload(mime_msg)

        return ""


class GmailLoader(BaseLoader):
    def load_data(self, query: str):
        reader = GmailReader(query=query)
        emails = reader.load_emails()
        logger.info(f"Gmail Loader: {len(emails)} emails found for query '{query}'")

        data = []
        for email in emails:
            content = self._process_email(email)
            data.append({"content": content, "meta_data": email})

        return {"doc_id": self._generate_doc_id(query, data), "data": data}

    @staticmethod
    def _process_email(email: dict) -> str:
        content = BeautifulSoup(email["body"], "html.parser").get_text()
        content = clean_string(content)
        return dedent(
            f"""
            Email from '{email['from']}' to '{email['to']}'
            Subject: {email['subject']}
            Date: {email['date']}
            Content: {content}
            """
        )

    @staticmethod
    def _generate_doc_id(query: str, data: list[dict]) -> str:
        content_strings = [email["content"] for email in data]
        return hashlib.sha256((query + ", ".join(content_strings)).encode()).hexdigest()
embedchain/embedchain/loaders/google_drive.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import hashlib
import re

try:
    from googleapiclient.errors import HttpError
except ImportError:
    raise ImportError(
        "Google Drive requires extra dependencies. Install with `pip install embedchain[googledrive]`"
    ) from None

from langchain_community.document_loaders import GoogleDriveLoader as Loader

try:
    import unstructured  # noqa: F401
    from langchain_community.document_loaders import UnstructuredFileIOLoader
except ImportError:
    raise ImportError(
        'Unstructured file requires extra dependencies. Install with `pip install "unstructured[local-inference, all-docs]"`'  # noqa: E501
    ) from None

from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader


@register_deserializable
class GoogleDriveLoader(BaseLoader):
    @staticmethod
    def _get_drive_id_from_url(url: str):
        regex = r"^https:\/\/drive\.google\.com\/drive\/(?:u\/\d+\/)folders\/([a-zA-Z0-9_-]+)$"
        if re.match(regex, url):
            return url.split("/")[-1]
        raise ValueError(
            f"The url provided {url} does not match a google drive folder url. Example drive url: "
            f"https://drive.google.com/drive/u/0/folders/xxxx"
        )

    def load_data(self, url: str):
        """Load data from a Google drive folder."""
        folder_id: str = self._get_drive_id_from_url(url)

        try:
            loader = Loader(
                folder_id=folder_id,
                recursive=True,
                file_loader_cls=UnstructuredFileIOLoader,
            )

            data = []
            all_content = []

            docs = loader.load()
            for doc in docs:
                all_content.append(doc.page_content)
                # renames source to url for later use.
                doc.metadata["url"] = doc.metadata.pop("source")
                data.append({"content": doc.page_content, "meta_data": doc.metadata})

            doc_id = hashlib.sha256((" ".join(all_content) + url).encode()).hexdigest()
            return {"doc_id": doc_id, "data": data}

        except HttpError:
            raise FileNotFoundError("Unable to locate folder or files, check provided drive URL and try again")
embedchain/embedchain/loaders/image.py (new file, 50 lines)
@@ -0,0 +1,50 @@
import base64
import hashlib
import os
from pathlib import Path

from openai import OpenAI

from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader

DESCRIBE_IMAGE_PROMPT = "Describe the image:"


@register_deserializable
class ImageLoader(BaseLoader):
    def __init__(self, max_tokens: int = 500, api_key: str = None, prompt: str = None):
        super().__init__()
        self.custom_prompt = prompt or DESCRIBE_IMAGE_PROMPT
        self.max_tokens = max_tokens
        self.api_key = api_key or os.environ["OPENAI_API_KEY"]
        self.client = OpenAI(api_key=self.api_key)

    @staticmethod
    def _encode_image(image_path: str):
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    def _create_completion_request(self, content: str):
        return self.client.chat.completions.create(
            model="gpt-4o", messages=[{"role": "user", "content": content}], max_tokens=self.max_tokens
        )

    def _process_url(self, url: str):
        if url.startswith("http"):
            return [{"type": "text", "text": self.custom_prompt}, {"type": "image_url", "image_url": {"url": url}}]
        elif Path(url).is_file():
            extension = Path(url).suffix.lstrip(".")
            encoded_image = self._encode_image(url)
            image_data = f"data:image/{extension};base64,{encoded_image}"
            # "image_url" is the content-part type the chat completions API expects for base64 data URIs too.
            return [{"type": "text", "text": self.custom_prompt}, {"type": "image_url", "image_url": {"url": image_data}}]
        else:
            raise ValueError(f"Invalid URL or file path: {url}")

    def load_data(self, url: str):
        content = self._process_url(url)
        response = self._create_completion_request(content)
        content = response.choices[0].message.content

        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
        return {"doc_id": doc_id, "data": [{"content": content, "meta_data": {"url": url, "type": "image"}}]}
embedchain/embedchain/loaders/json.py (new file, 93 lines)
@@ -0,0 +1,93 @@
import hashlib
import json
import os
import re
from typing import Union

import requests

from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string, is_valid_json_string


class JSONReader:
    def __init__(self) -> None:
        """Initialize the JSONReader."""
        pass

    @staticmethod
    def load_data(json_data: Union[dict, str]) -> list[str]:
        """Load data from a JSON structure.

        Args:
            json_data (Union[dict, str]): The JSON data to load.

        Returns:
            list[str]: A list of strings representing the leaf nodes of the JSON.
        """
        if isinstance(json_data, str):
            json_data = json.loads(json_data)
        else:
            json_data = json_data

        json_output = json.dumps(json_data, indent=0)
        lines = json_output.split("\n")
        useful_lines = [line for line in lines if not re.match(r"^[{}\[\],]*$", line)]
        return ["\n".join(useful_lines)]


VALID_URL_PATTERN = (
    r"^https?://(?:www\.)?(?:\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|[a-zA-Z0-9.-]+)(?::\d+)?/(?:[^/\s]+/)*[^/\s]+\.json$"
)


class JSONLoader(BaseLoader):
    @staticmethod
    def _check_content(content):
        if not isinstance(content, str):
            raise ValueError(
                "Invalid content input. \
                If you want to upload (list, dict, etc.), do \
                `json.dumps(data, indent=0)` and add the stringified JSON. \
                Check - `https://docs.embedchain.ai/data-sources/json`"
            )

    @staticmethod
    def load_data(content):
        """Load a json file. Each data point is a key value pair."""

        JSONLoader._check_content(content)
        loader = JSONReader()

        data = []
        data_content = []

        content_url_str = content

        if os.path.isfile(content):
            with open(content, "r", encoding="utf-8") as json_file:
                json_data = json.load(json_file)
        elif re.match(VALID_URL_PATTERN, content):
            response = requests.get(content)
            if response.status_code == 200:
                json_data = response.json()
            else:
                raise ValueError(
                    f"Loading data from the given url: {content} failed. \
                    Make sure the url is working."
                )
        elif is_valid_json_string(content):
            json_data = content
            content_url_str = hashlib.sha256((content).encode("utf-8")).hexdigest()
        else:
            raise ValueError(f"Invalid content to load json data from: {content}")

        docs = loader.load_data(json_data)
        for doc in docs:
            text = doc if isinstance(doc, str) else doc["text"]
            doc_content = clean_string(text)
            data.append({"content": doc_content, "meta_data": {"url": content_url_str}})
            data_content.append(doc_content)

        doc_id = hashlib.sha256((content_url_str + ", ".join(data_content)).encode()).hexdigest()
        return {"doc_id": doc_id, "data": data}
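A usage sketch (illustrative; per `load_data`, the input may be a local `.json` path, a URL ending in `.json`, or a stringified JSON document as shown here):

    import json

    from embedchain.loaders.json import JSONLoader

    payload = json.dumps({"project": "embedchain", "topic": "loaders"}, indent=0)
    result = JSONLoader.load_data(payload)
    print(result["doc_id"], result["data"][0]["content"])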
embedchain/embedchain/loaders/local_qna_pair.py (new file, 24 lines)
@@ -0,0 +1,24 @@
import hashlib

from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader


@register_deserializable
class LocalQnaPairLoader(BaseLoader):
    def load_data(self, content):
        """Load data from a local QnA pair."""
        question, answer = content
        content = f"Q: {question}\nA: {answer}"
        url = "local"
        metadata = {"url": url, "question": question}
        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
        return {
            "doc_id": doc_id,
            "data": [
                {
                    "content": content,
                    "meta_data": metadata,
                }
            ],
        }
embedchain/embedchain/loaders/local_text.py (new file, 24 lines)
@@ -0,0 +1,24 @@
import hashlib

from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader


@register_deserializable
class LocalTextLoader(BaseLoader):
    def load_data(self, content):
        """Load data from a local text file."""
        url = "local"
        metadata = {
            "url": url,
        }
        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
        return {
            "doc_id": doc_id,
            "data": [
                {
                    "content": content,
                    "meta_data": metadata,
                }
            ],
        }
embedchain/embedchain/loaders/mdx.py (new file, 25 lines)
@@ -0,0 +1,25 @@
import hashlib

from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader


@register_deserializable
class MdxLoader(BaseLoader):
    def load_data(self, url):
        """Load data from an mdx file."""
        with open(url, "r", encoding="utf-8") as infile:
            content = infile.read()
        metadata = {
            "url": url,
        }
        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
        return {
            "doc_id": doc_id,
            "data": [
                {
                    "content": content,
                    "meta_data": metadata,
                }
            ],
        }
67
embedchain/embedchain/loaders/mysql.py
Normal file
67
embedchain/embedchain/loaders/mysql.py
Normal file
@@ -0,0 +1,67 @@
import hashlib
import logging
from typing import Any, Optional

from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string

logger = logging.getLogger(__name__)


class MySQLLoader(BaseLoader):
    def __init__(self, config: Optional[dict[str, Any]]):
        super().__init__()
        if not config:
            raise ValueError(
                f"Invalid sql config: {config}.",
                "Provide the correct config, refer `https://docs.embedchain.ai/data-sources/mysql`.",
            )

        self.config = config
        self.connection = None
        self.cursor = None
        self._setup_loader(config=config)

    def _setup_loader(self, config: dict[str, Any]):
        try:
            import mysql.connector as sqlconnector
        except ImportError as e:
            raise ImportError(
                "Unable to import required packages for MySQL loader. Run `pip install --upgrade 'embedchain[mysql]'`."  # noqa: E501
            ) from e

        try:
            self.connection = sqlconnector.connection.MySQLConnection(**config)
            self.cursor = self.connection.cursor()
        except (sqlconnector.Error, IOError) as err:
            logger.info(f"Connection failed: {err}")
            raise ValueError(
                f"Unable to connect with the given config: {config}.",
                "Please provide the correct configuration to load data from your MySQL DB. \
                Refer `https://docs.embedchain.ai/data-sources/mysql`.",
            )

    @staticmethod
    def _check_query(query):
        if not isinstance(query, str):
            raise ValueError(
                f"Invalid mysql query: {query}",
                "Provide the valid query to add from mysql, \
                make sure you are following `https://docs.embedchain.ai/data-sources/mysql`",
            )

    def load_data(self, query):
        self._check_query(query=query)
        data = []
        data_content = []
        self.cursor.execute(query)
        rows = self.cursor.fetchall()
        for row in rows:
            doc_content = clean_string(str(row))
            data.append({"content": doc_content, "meta_data": {"url": query}})
            data_content.append(doc_content)
        doc_id = hashlib.sha256((query + ", ".join(data_content)).encode()).hexdigest()
        return {
            "doc_id": doc_id,
            "data": data,
        }
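A hedged usage sketch, assuming a reachable MySQL server and the `mysql` extra installed. The config dict is passed straight to mysql.connector, so its keys mirror that library's connection arguments; the host, credentials, and query below are placeholders.

from embedchain.loaders.mysql import MySQLLoader  # path assumed from this commit's layout

config = {"host": "localhost", "user": "root", "password": "secret", "database": "demo"}  # placeholder values
loader = MySQLLoader(config=config)
result = loader.load_data("SELECT id, title FROM articles LIMIT 5;")  # each fetched row becomes one entry in result["data"]
print(result["doc_id"], len(result["data"]))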
121
embedchain/embedchain/loaders/notion.py
Normal file
@@ -0,0 +1,121 @@
import hashlib
import logging
import os
from typing import Any, Optional

import requests

from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string

logger = logging.getLogger(__name__)


class NotionDocument:
    """
    A simple Document class to hold the text and additional information of a page.
    """

    def __init__(self, text: str, extra_info: dict[str, Any]):
        self.text = text
        self.extra_info = extra_info


class NotionPageLoader:
    """
    Notion Page Loader.
    Reads a set of Notion pages.
    """

    BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"

    def __init__(self, integration_token: Optional[str] = None) -> None:
        """Initialize with Notion integration token."""
        if integration_token is None:
            integration_token = os.getenv("NOTION_INTEGRATION_TOKEN")
            if integration_token is None:
                raise ValueError(
                    "Must specify `integration_token` or set environment " "variable `NOTION_INTEGRATION_TOKEN`."
                )
        self.token = integration_token
        self.headers = {
            "Authorization": "Bearer " + self.token,
            "Content-Type": "application/json",
            "Notion-Version": "2022-06-28",
        }

    def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
        """Read a block from Notion."""
        done = False
        result_lines_arr = []
        cur_block_id = block_id
        while not done:
            block_url = self.BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
            res = requests.get(block_url, headers=self.headers)
            data = res.json()

            for result in data["results"]:
                result_type = result["type"]
                result_obj = result[result_type]

                cur_result_text_arr = []
                if "rich_text" in result_obj:
                    for rich_text in result_obj["rich_text"]:
                        if "text" in rich_text:
                            text = rich_text["text"]["content"]
                            prefix = "\t" * num_tabs
                            cur_result_text_arr.append(prefix + text)

                result_block_id = result["id"]
                has_children = result["has_children"]
                if has_children:
                    children_text = self._read_block(result_block_id, num_tabs=num_tabs + 1)
                    cur_result_text_arr.append(children_text)

                cur_result_text = "\n".join(cur_result_text_arr)
                result_lines_arr.append(cur_result_text)

            if data["next_cursor"] is None:
                done = True
            else:
                cur_block_id = data["next_cursor"]

        result_lines = "\n".join(result_lines_arr)
        return result_lines

    def load_data(self, page_ids: list[str]) -> list[NotionDocument]:
        """Load data from the given list of page IDs."""
        docs = []
        for page_id in page_ids:
            page_text = self._read_block(page_id)
            docs.append(NotionDocument(text=page_text, extra_info={"page_id": page_id}))
        return docs


@register_deserializable
class NotionLoader(BaseLoader):
    def load_data(self, source):
        """Load data from a Notion URL."""

        id = source[-32:]
        formatted_id = f"{id[:8]}-{id[8:12]}-{id[12:16]}-{id[16:20]}-{id[20:]}"
        logger.debug(f"Extracted notion page id as: {formatted_id}")

        integration_token = os.getenv("NOTION_INTEGRATION_TOKEN")
        reader = NotionPageLoader(integration_token=integration_token)
        documents = reader.load_data(page_ids=[formatted_id])

        raw_text = documents[0].text

        text = clean_string(raw_text)
        doc_id = hashlib.sha256((text + source).encode()).hexdigest()
        return {
            "doc_id": doc_id,
            "data": [
                {
                    "content": text,
                    "meta_data": {"url": f"notion-{formatted_id}"},
                }
            ],
        }
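A usage sketch under the assumption that NOTION_INTEGRATION_TOKEN is exported and the page is shared with that integration. The loader takes the last 32 characters of the URL as the page id, so both the token and the URL below are placeholders.

import os

from embedchain.loaders.notion import NotionLoader  # path assumed from this commit's layout

os.environ.setdefault("NOTION_INTEGRATION_TOKEN", "secret_xxx")  # placeholder token
loader = NotionLoader()
result = loader.load_data("https://www.notion.so/My-Page-0123456789abcdef0123456789abcdef")  # placeholder URL
print(result["data"][0]["meta_data"]["url"])  # e.g. "notion-01234567-89ab-cdef-0123-456789abcdef"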
42
embedchain/embedchain/loaders/openapi.py
Normal file
@@ -0,0 +1,42 @@
import hashlib
from io import StringIO
from urllib.parse import urlparse

import requests
import yaml

from embedchain.loaders.base_loader import BaseLoader


class OpenAPILoader(BaseLoader):
    @staticmethod
    def _get_file_content(content):
        url = urlparse(content)
        if all([url.scheme, url.netloc]) and url.scheme not in ["file", "http", "https"]:
            raise ValueError("Not a valid URL.")

        if url.scheme in ["http", "https"]:
            response = requests.get(content)
            response.raise_for_status()
            return StringIO(response.text)
        elif url.scheme == "file":
            path = url.path
            return open(path)
        else:
            return open(content)

    @staticmethod
    def load_data(content):
        """Load yaml file of openapi. Each pair is a document."""
        data = []
        file_path = content
        data_content = []
        with OpenAPILoader._get_file_content(content=content) as file:
            yaml_data = yaml.load(file, Loader=yaml.SafeLoader)
        for i, (key, value) in enumerate(yaml_data.items()):
            string_data = f"{key}: {value}"
            metadata = {"url": file_path, "row": i + 1}
            data.append({"content": string_data, "meta_data": metadata})
            data_content.append(string_data)
        doc_id = hashlib.sha256((content + ", ".join(data_content)).encode()).hexdigest()
        return {"doc_id": doc_id, "data": data}
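A sketch of how the static load_data above might be called; _get_file_content accepts an http(s) URL, a file:// URL, or a plain local path, and each top-level YAML key becomes one document. The spec path is a placeholder.

from embedchain.loaders.openapi import OpenAPILoader  # path assumed from this commit's layout

result = OpenAPILoader.load_data("specs/petstore.yaml")  # placeholder path to an OpenAPI YAML spec
for item in result["data"][:3]:
    print(item["meta_data"]["row"], item["content"][:60])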
38
embedchain/embedchain/loaders/pdf_file.py
Normal file
@@ -0,0 +1,38 @@
import hashlib

from langchain_community.document_loaders import PyPDFLoader
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string


@register_deserializable
class PdfFileLoader(BaseLoader):
    def load_data(self, url):
        """Load data from a PDF file."""
        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/98.0.4758.102 Safari/537.36",  # noqa:E501
        }
        loader = PyPDFLoader(url, headers=headers)
        data = []
        all_content = []
        pages = loader.load_and_split()
        if not len(pages):
            raise ValueError("No data found")
        for page in pages:
            content = page.page_content
            content = clean_string(content)
            metadata = page.metadata
            metadata["url"] = url
            data.append(
                {
                    "content": content,
                    "meta_data": metadata,
                }
            )
            all_content.append(content)
        doc_id = hashlib.sha256((" ".join(all_content) + url).encode()).hexdigest()
        return {
            "doc_id": doc_id,
            "data": data,
        }
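A minimal sketch, assuming langchain-community and a PDF backend such as pypdf are installed. PyPDFLoader accepts a local path or a URL, and each split page comes back as one entry in data; the path below is a placeholder.

from embedchain.loaders.pdf_file import PdfFileLoader  # path assumed from this commit's layout

loader = PdfFileLoader()
result = loader.load_data("docs/report.pdf")  # placeholder path; an http(s) URL also works
print(len(result["data"]), "pages loaded")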
73
embedchain/embedchain/loaders/postgres.py
Normal file
@@ -0,0 +1,73 @@
import hashlib
import logging
from typing import Any, Optional

from embedchain.loaders.base_loader import BaseLoader

logger = logging.getLogger(__name__)


class PostgresLoader(BaseLoader):
    def __init__(self, config: Optional[dict[str, Any]] = None):
        super().__init__()
        if not config:
            raise ValueError(f"Must provide the valid config. Received: {config}")

        self.connection = None
        self.cursor = None
        self._setup_loader(config=config)

    def _setup_loader(self, config: dict[str, Any]):
        try:
            import psycopg
        except ImportError as e:
            raise ImportError(
                "Unable to import required packages. \
                Run `pip install --upgrade 'embedchain[postgres]'`"
            ) from e

        if "url" in config:
            config_info = config.get("url")
        else:
            conn_params = []
            for key, value in config.items():
                conn_params.append(f"{key}={value}")
            config_info = " ".join(conn_params)

        logger.info(f"Connecting to postgres sql: {config_info}")
        self.connection = psycopg.connect(conninfo=config_info)
        self.cursor = self.connection.cursor()

    @staticmethod
    def _check_query(query):
        if not isinstance(query, str):
            raise ValueError(
                f"Invalid postgres query: {query}. Provide the valid source to add from postgres, make sure you are following `https://docs.embedchain.ai/data-sources/postgres`",  # noqa:E501
            )

    def load_data(self, query):
        self._check_query(query)
        try:
            data = []
            data_content = []
            self.cursor.execute(query)
            results = self.cursor.fetchall()
            for result in results:
                doc_content = str(result)
                data.append({"content": doc_content, "meta_data": {"url": query}})
                data_content.append(doc_content)
            doc_id = hashlib.sha256((query + ", ".join(data_content)).encode()).hexdigest()
            return {
                "doc_id": doc_id,
                "data": data,
            }
        except Exception as e:
            raise ValueError(f"Failed to load data using query={query} with: {e}")

    def close_connection(self):
        if self.cursor:
            self.cursor.close()
            self.cursor = None
        if self.connection:
            self.connection.close()
            self.connection = None
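A usage sketch assuming the `postgres` extra (psycopg) and a reachable database. As the code above shows, a single "url" key is passed through as the conninfo string, while any other keys are joined into key=value pairs; the DSN and query below are placeholders.

from embedchain.loaders.postgres import PostgresLoader  # path assumed from this commit's layout

loader = PostgresLoader(config={"url": "postgresql://user:pass@localhost:5432/demo"})  # placeholder DSN
result = loader.load_data("SELECT id, body FROM posts LIMIT 5;")
print(result["doc_id"], len(result["data"]))
loader.close_connection()  # releases the cursor and connection opened in __init__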
52
embedchain/embedchain/loaders/rss_feed.py
Normal file
@@ -0,0 +1,52 @@
import hashlib

from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader


@register_deserializable
class RSSFeedLoader(BaseLoader):
    """Loader for RSS Feed."""

    def load_data(self, url):
        """Load data from an RSS feed."""
        output = self.get_rss_content(url)
        doc_id = hashlib.sha256((str(output) + url).encode()).hexdigest()
        return {
            "doc_id": doc_id,
            "data": output,
        }

    @staticmethod
    def serialize_metadata(metadata):
        for key, value in metadata.items():
            if not isinstance(value, (str, int, float, bool)):
                metadata[key] = str(value)

        return metadata

    @staticmethod
    def get_rss_content(url: str):
        try:
            from langchain_community.document_loaders import RSSFeedLoader as LangchainRSSFeedLoader
        except ImportError:
            raise ImportError(
                """RSSFeedLoader file requires extra dependencies.
                Install with `pip install feedparser==6.0.10 newspaper3k==0.2.8 listparser==0.19`"""
            ) from None

        output = []
        loader = LangchainRSSFeedLoader(urls=[url])
        data = loader.load()

        for entry in data:
            metadata = RSSFeedLoader.serialize_metadata(entry.metadata)
            metadata.update({"url": url})
            output.append(
                {
                    "content": entry.page_content,
                    "meta_data": metadata,
                }
            )

        return output
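A short sketch, assuming the feedparser/newspaper3k/listparser dependencies named in the ImportError above are installed; the feed URL is a placeholder.

from embedchain.loaders.rss_feed import RSSFeedLoader  # path assumed from this commit's layout

loader = RSSFeedLoader()
result = loader.load_data("https://example.com/feed.xml")  # placeholder feed URL
for entry in result["data"][:3]:
    print(entry["meta_data"]["url"], len(entry["content"]))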
79
embedchain/embedchain/loaders/sitemap.py
Normal file
@@ -0,0 +1,79 @@
import concurrent.futures
import hashlib
import logging
import os
from urllib.parse import urlparse

import requests
from tqdm import tqdm

try:
    from bs4 import BeautifulSoup
    from bs4.builder import ParserRejectedMarkup
except ImportError:
    raise ImportError(
        "Sitemap requires extra dependencies. Install with `pip install beautifulsoup4==4.12.3`"
    ) from None

from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.loaders.web_page import WebPageLoader

logger = logging.getLogger(__name__)


@register_deserializable
class SitemapLoader(BaseLoader):
    """
    This loader takes a sitemap URL or local file path as input and retrieves
    all the URLs, then uses the WebPageLoader to load the content
    of each page.
    """

    def load_data(self, sitemap_source):
        output = []
        web_page_loader = WebPageLoader()
        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/98.0.4758.102 Safari/537.36",  # noqa:E501
        }

        if urlparse(sitemap_source).scheme in ("http", "https"):
            try:
                response = requests.get(sitemap_source, headers=headers)
                response.raise_for_status()
                soup = BeautifulSoup(response.text, "xml")
            except requests.RequestException as e:
                logger.error(f"Error fetching sitemap from URL: {e}")
                return
        elif os.path.isfile(sitemap_source):
            with open(sitemap_source, "r") as file:
                soup = BeautifulSoup(file, "xml")
        else:
            raise ValueError("Invalid sitemap source. Please provide a valid URL or local file path.")

        links = [link.text for link in soup.find_all("loc") if link.parent.name == "url"]
        if len(links) == 0:
            links = [link.text for link in soup.find_all("loc")]

        doc_id = hashlib.sha256((" ".join(links) + sitemap_source).encode()).hexdigest()

        def load_web_page(link):
            try:
                loader_data = web_page_loader.load_data(link)
                return loader_data.get("data")
            except ParserRejectedMarkup as e:
                logger.error(f"Failed to parse {link}: {e}")
                return None

        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_to_link = {executor.submit(load_web_page, link): link for link in links}
            for future in tqdm(concurrent.futures.as_completed(future_to_link), total=len(links), desc="Loading pages"):
                link = future_to_link[future]
                try:
                    data = future.result()
                    if data:
                        output.extend(data)
                except Exception as e:
                    logger.error(f"Error loading page {link}: {e}")

        return {"doc_id": doc_id, "data": output}
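A usage sketch; the loader accepts either a sitemap URL or a local sitemap file and fans the page loads out over a thread pool, so the source below is a placeholder.

from embedchain.loaders.sitemap import SitemapLoader  # path assumed from this commit's layout

loader = SitemapLoader()
result = loader.load_data("https://example.com/sitemap.xml")  # placeholder; a local sitemap path also works
print(result["doc_id"], len(result["data"]), "pages collected")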
115
embedchain/embedchain/loaders/slack.py
Normal file
@@ -0,0 +1,115 @@
import hashlib
import logging
import os
import ssl
from typing import Any, Optional

import certifi

from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string

SLACK_API_BASE_URL = "https://www.slack.com/api/"

logger = logging.getLogger(__name__)


class SlackLoader(BaseLoader):
    def __init__(self, config: Optional[dict[str, Any]] = None):
        super().__init__()

        self.config = config if config else {}

        if "base_url" not in self.config:
            self.config["base_url"] = SLACK_API_BASE_URL

        self.client = None
        self._setup_loader(self.config)

    def _setup_loader(self, config: dict[str, Any]):
        try:
            from slack_sdk import WebClient
        except ImportError as e:
            raise ImportError(
                "Slack loader requires extra dependencies. \
                Install with `pip install --upgrade embedchain[slack]`"
            ) from e

        if os.getenv("SLACK_USER_TOKEN") is None:
            raise ValueError(
                "SLACK_USER_TOKEN environment variable not provided. Check `https://docs.embedchain.ai/data-sources/slack` to learn more."  # noqa:E501
            )

        logger.info(f"Creating Slack Loader with config: {config}")
        # get slack client config params
        slack_bot_token = os.getenv("SLACK_USER_TOKEN")
        ssl_cert = ssl.create_default_context(cafile=certifi.where())
        base_url = config.get("base_url", SLACK_API_BASE_URL)
        headers = config.get("headers")
        # for Org-Wide App
        team_id = config.get("team_id")

        self.client = WebClient(
            token=slack_bot_token,
            base_url=base_url,
            ssl=ssl_cert,
            headers=headers,
            team_id=team_id,
        )
        logger.info("Slack Loader setup successful!")

    @staticmethod
    def _check_query(query):
        if not isinstance(query, str):
            raise ValueError(
                f"Invalid query passed to Slack loader, found: {query}. Check `https://docs.embedchain.ai/data-sources/slack` to learn more."  # noqa:E501
            )

    def load_data(self, query):
        self._check_query(query)
        try:
            data = []
            data_content = []

            logger.info(f"Searching slack conversations for query: {query}")
            results = self.client.search_messages(
                query=query,
                sort="timestamp",
                sort_dir="desc",
                count=self.config.get("count", 100),
            )

            messages = results.get("messages")
            num_message = len(messages.get("matches", []))
            logger.info(f"Found {num_message} messages for query: {query}")

            matches = messages.get("matches", [])
            for message in matches:
                url = message.get("permalink")
                text = message.get("text")
                content = clean_string(text)

                message_meta_data_keys = ["iid", "team", "ts", "type", "user", "username"]
                metadata = {}
                for key in message.keys():
                    if key in message_meta_data_keys:
                        metadata[key] = message.get(key)
                metadata.update({"url": url})

                data.append(
                    {
                        "content": content,
                        "meta_data": metadata,
                    }
                )
                data_content.append(content)
            doc_id = hashlib.md5((query + ", ".join(data_content)).encode()).hexdigest()
            return {
                "doc_id": doc_id,
                "data": data,
            }
        except Exception as e:
            logger.warning(f"Error in loading slack data: {e}")
            raise ValueError(
                f"Error in loading slack data: {e}. Check `https://docs.embedchain.ai/data-sources/slack` to learn more."  # noqa:E501
            ) from e
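A sketch assuming the `slack` extra is installed and SLACK_USER_TOKEN is exported with search permissions; besides base_url, headers, and team_id, `count` is the only extra config key the code above reads. The token and search query below are placeholders.

import os

from embedchain.loaders.slack import SlackLoader  # path assumed from this commit's layout

os.environ.setdefault("SLACK_USER_TOKEN", "xoxp-placeholder")  # placeholder user token with search scopes
loader = SlackLoader(config={"count": 20})
result = loader.load_data("in:#general deployment")  # placeholder Slack search query
print(len(result["data"]), "messages loaded")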
107
embedchain/embedchain/loaders/substack.py
Normal file
@@ -0,0 +1,107 @@
import hashlib
import logging
import time
from xml.etree import ElementTree

import requests

from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import is_readable

logger = logging.getLogger(__name__)


@register_deserializable
class SubstackLoader(BaseLoader):
    """
    This loader is used to load data from Substack URLs.
    """

    def load_data(self, url: str):
        try:
            from bs4 import BeautifulSoup
            from bs4.builder import ParserRejectedMarkup
        except ImportError:
            raise ImportError(
                "Substack requires extra dependencies. Install with `pip install beautifulsoup4==4.12.3`"
            ) from None

        if not url.endswith("sitemap.xml"):
            url = url + "/sitemap.xml"

        output = []
        response = requests.get(url)

        try:
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            raise ValueError(
                f"""
                Failed to load {url}: {e}. Please use the root substack URL. For example, https://example.substack.com
                """
            )

        try:
            ElementTree.fromstring(response.content)
        except ElementTree.ParseError:
            raise ValueError(
                f"""
                Failed to parse {url}. Please use the root substack URL. For example, https://example.substack.com
                """
            )

        soup = BeautifulSoup(response.text, "xml")
        links = [link.text for link in soup.find_all("loc") if link.parent.name == "url" and "/p/" in link.text]
        if len(links) == 0:
            links = [link.text for link in soup.find_all("loc") if "/p/" in link.text]

        doc_id = hashlib.sha256((" ".join(links) + url).encode()).hexdigest()

        def serialize_response(soup: BeautifulSoup):
            data = {}

            h1_els = soup.find_all("h1")
            if h1_els is not None and len(h1_els) > 1:
                data["title"] = h1_els[1].text

            description_el = soup.find("meta", {"name": "description"})
            if description_el is not None:
                data["description"] = description_el["content"]

            content_el = soup.find("div", {"class": "available-content"})
            if content_el is not None:
                data["content"] = content_el.text

            like_btn = soup.find("div", {"class": "like-button-container"})
            if like_btn is not None:
                no_of_likes_div = like_btn.find("div", {"class": "label"})
                if no_of_likes_div is not None:
                    data["no_of_likes"] = no_of_likes_div.text

            return data

        def load_link(link: str):
            try:
                substack_data = requests.get(link)
                substack_data.raise_for_status()

                soup = BeautifulSoup(substack_data.text, "html.parser")
                data = serialize_response(soup)
                data = str(data)
                if is_readable(data):
                    return data
                else:
                    logger.warning(f"Page is not readable (too many invalid characters): {link}")
            except ParserRejectedMarkup as e:
                logger.error(f"Failed to parse {link}: {e}")
            return None

        for link in links:
            data = load_link(link)
            if data:
                output.append({"content": data, "meta_data": {"url": link}})
            # TODO: allow users to configure this
            time.sleep(1.0)  # added to avoid rate limiting

        return {"doc_id": doc_id, "data": output}
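A usage sketch; the loader expects the root publication URL (it appends /sitemap.xml itself and keeps only /p/ post links), so the URL below is a placeholder. Note the fixed one-second sleep per post flagged in the TODO above.

from embedchain.loaders.substack import SubstackLoader  # path assumed from this commit's layout

loader = SubstackLoader()
result = loader.load_data("https://example.substack.com")  # placeholder root publication URL
print(len(result["data"]), "posts scraped")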
30
embedchain/embedchain/loaders/text_file.py
Normal file
@@ -0,0 +1,30 @@
import hashlib
import os

from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader


@register_deserializable
class TextFileLoader(BaseLoader):
    def load_data(self, url: str):
        """Load data from a text file located at a local path."""
        if not os.path.exists(url):
            raise FileNotFoundError(f"The file at {url} does not exist.")

        with open(url, "r", encoding="utf-8") as file:
            content = file.read()

        doc_id = hashlib.sha256((content + url).encode()).hexdigest()

        metadata = {"url": url, "file_size": os.path.getsize(url), "file_type": url.split(".")[-1]}

        return {
            "doc_id": doc_id,
            "data": [
                {
                    "content": content,
                    "meta_data": metadata,
                }
            ],
        }
43
embedchain/embedchain/loaders/unstructured_file.py
Normal file
@@ -0,0 +1,43 @@
import hashlib

from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string


@register_deserializable
class UnstructuredLoader(BaseLoader):
    def load_data(self, url):
        """Load data from an Unstructured file."""
        try:
            import unstructured  # noqa: F401
            from langchain_community.document_loaders import UnstructuredFileLoader
        except ImportError:
            raise ImportError(
                'Unstructured file requires extra dependencies. Install with `pip install "unstructured[local-inference, all-docs]"`'  # noqa: E501
            ) from None

        loader = UnstructuredFileLoader(url)
        data = []
        all_content = []
        pages = loader.load_and_split()
        if not len(pages):
            raise ValueError("No data found")
        for page in pages:
            content = page.page_content
            content = clean_string(content)
            metadata = page.metadata
            metadata["url"] = url
            data.append(
                {
                    "content": content,
                    "meta_data": metadata,
                }
            )
            all_content.append(content)
        doc_id = hashlib.sha256((" ".join(all_content) + url).encode()).hexdigest()
        return {
            "doc_id": doc_id,
            "data": data,
        }
100
embedchain/embedchain/loaders/web_page.py
Normal file
@@ -0,0 +1,100 @@
import hashlib
import logging

import requests

try:
    from bs4 import BeautifulSoup
except ImportError:
    raise ImportError(
        "Webpage requires extra dependencies. Install with `pip install beautifulsoup4==4.12.3`"
    ) from None

from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string

logger = logging.getLogger(__name__)


@register_deserializable
class WebPageLoader(BaseLoader):
    # Shared session for all instances
    _session = requests.Session()

    def load_data(self, url):
        """Load data from a web page using a shared requests' session."""
        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/98.0.4758.102 Safari/537.36",  # noqa:E501
        }
        response = self._session.get(url, headers=headers, timeout=30)
        response.raise_for_status()
        data = response.content
        content = self._get_clean_content(data, url)

        metadata = {"url": url}

        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
        return {
            "doc_id": doc_id,
            "data": [
                {
                    "content": content,
                    "meta_data": metadata,
                }
            ],
        }

    @staticmethod
    def _get_clean_content(html, url) -> str:
        soup = BeautifulSoup(html, "html.parser")
        original_size = len(str(soup.get_text()))

        tags_to_exclude = [
            "nav",
            "aside",
            "form",
            "header",
            "noscript",
            "svg",
            "canvas",
            "footer",
            "script",
            "style",
        ]
        for tag in soup(tags_to_exclude):
            tag.decompose()

        ids_to_exclude = ["sidebar", "main-navigation", "menu-main-menu"]
        for id_ in ids_to_exclude:
            tags = soup.find_all(id=id_)
            for tag in tags:
                tag.decompose()

        classes_to_exclude = [
            "elementor-location-header",
            "navbar-header",
            "nav",
            "header-sidebar-wrapper",
            "blog-sidebar-wrapper",
            "related-posts",
        ]
        for class_name in classes_to_exclude:
            tags = soup.find_all(class_=class_name)
            for tag in tags:
                tag.decompose()

        content = soup.get_text()
        content = clean_string(content)

        cleaned_size = len(content)
        if original_size != 0:
            logger.info(
                f"[{url}] Cleaned page size: {cleaned_size} characters, down from {original_size} (shrunk: {original_size-cleaned_size} chars, {round((1-(cleaned_size/original_size)) * 100, 2)}%)"  # noqa:E501
            )

        return content

    @classmethod
    def close_session(cls):
        cls._session.close()
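A small sketch of the shared-session behaviour above: every instance reuses one requests.Session, and close_session() is a classmethod that shuts it down. The URL is a placeholder.

from embedchain.loaders.web_page import WebPageLoader  # path assumed from this commit's layout

loader = WebPageLoader()
result = loader.load_data("https://example.com/blog/post")  # placeholder URL
print(result["data"][0]["content"][:200])
WebPageLoader.close_session()  # closes the session shared by every WebPageLoader instance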
31
embedchain/embedchain/loaders/xml.py
Normal file
@@ -0,0 +1,31 @@
import hashlib

try:
    import unstructured  # noqa: F401
    from langchain_community.document_loaders import UnstructuredXMLLoader
except ImportError:
    raise ImportError(
        'XML file requires extra dependencies. Install with `pip install "unstructured[local-inference, all-docs]"`'
    ) from None
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string


@register_deserializable
class XmlLoader(BaseLoader):
    def load_data(self, xml_url):
        """Load data from an XML file."""
        loader = UnstructuredXMLLoader(xml_url)
        data = loader.load()
        content = data[0].page_content
        content = clean_string(content)
        metadata = data[0].metadata
        metadata["url"] = metadata["source"]
        del metadata["source"]
        output = [{"content": content, "meta_data": metadata}]
        doc_id = hashlib.sha256((content + xml_url).encode()).hexdigest()
        return {
            "doc_id": doc_id,
            "data": output,
        }
79
embedchain/embedchain/loaders/youtube_channel.py
Normal file
@@ -0,0 +1,79 @@
import concurrent.futures
import hashlib
import logging

from tqdm import tqdm

from embedchain.loaders.base_loader import BaseLoader
from embedchain.loaders.youtube_video import YoutubeVideoLoader

logger = logging.getLogger(__name__)


class YoutubeChannelLoader(BaseLoader):
    """Loader for youtube channel."""

    def load_data(self, channel_name):
        try:
            import yt_dlp
        except ImportError as e:
            raise ValueError(
                "YoutubeChannelLoader requires extra dependencies. Install with `pip install yt_dlp==2023.11.14 youtube-transcript-api==0.6.1`"  # noqa: E501
            ) from e

        data = []
        data_urls = []
        youtube_url = f"https://www.youtube.com/{channel_name}/videos"
        youtube_video_loader = YoutubeVideoLoader()

        def _get_yt_video_links():
            try:
                ydl_opts = {
                    "quiet": True,
                    "extract_flat": True,
                }
                with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                    info_dict = ydl.extract_info(youtube_url, download=False)
                    if "entries" in info_dict:
                        videos = [entry["url"] for entry in info_dict["entries"]]
                        return videos
            except Exception:
                logger.error(f"Failed to fetch youtube videos for channel: {channel_name}")
            return []

        def _load_yt_video(video_link):
            try:
                each_load_data = youtube_video_loader.load_data(video_link)
                if each_load_data:
                    return each_load_data.get("data")
            except Exception as e:
                logger.error(f"Failed to load youtube video {video_link}: {e}")
            return None

        def _add_youtube_channel():
            video_links = _get_yt_video_links()
            logger.info("Loading videos from youtube channel...")
            with concurrent.futures.ThreadPoolExecutor() as executor:
                # Submitting all tasks and storing the future object with the video link
                future_to_video = {
                    executor.submit(_load_yt_video, video_link): video_link for video_link in video_links
                }

                for future in tqdm(
                    concurrent.futures.as_completed(future_to_video), total=len(video_links), desc="Processing videos"
                ):
                    video = future_to_video[future]
                    try:
                        results = future.result()
                        if results:
                            data.extend(results)
                            data_urls.extend([result.get("meta_data").get("url") for result in results])
                    except Exception as e:
                        logger.error(f"Failed to process youtube video {video}: {e}")

        _add_youtube_channel()
        doc_id = hashlib.sha256((youtube_url + ", ".join(data_urls)).encode()).hexdigest()
        return {
            "doc_id": doc_id,
            "data": data,
        }
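A sketch assuming yt_dlp and youtube-transcript-api are installed; load_data takes a channel handle that is interpolated into https://www.youtube.com/{channel_name}/videos, so the handle below is a placeholder.

from embedchain.loaders.youtube_channel import YoutubeChannelLoader  # path assumed from this commit's layout

loader = YoutubeChannelLoader()
result = loader.load_data("@example_channel")  # placeholder handle; expands to youtube.com/@example_channel/videos
print(len(result["data"]), "video documents loaded")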
57
embedchain/embedchain/loaders/youtube_video.py
Normal file
@@ -0,0 +1,57 @@
import hashlib
import json
import logging

try:
    from youtube_transcript_api import YouTubeTranscriptApi
except ImportError:
    raise ImportError("YouTube video requires extra dependencies. Install with `pip install youtube-transcript-api`")
try:
    from langchain_community.document_loaders import YoutubeLoader
    from langchain_community.document_loaders.youtube import _parse_video_id
except ImportError:
    raise ImportError("YouTube video requires extra dependencies. Install with `pip install pytube==15.0.0`") from None
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string


@register_deserializable
class YoutubeVideoLoader(BaseLoader):
    def load_data(self, url):
        """Load data from a Youtube video."""
        video_id = _parse_video_id(url)

        languages = ["en"]
        try:
            # Fetching transcript data
            languages = [transcript.language_code for transcript in YouTubeTranscriptApi.list_transcripts(video_id)]
            transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=languages)
            # convert transcript to json to avoid unicode symbols
            transcript = json.dumps(transcript, ensure_ascii=True)
        except Exception:
            logging.exception(f"Failed to fetch transcript for video {url}")
            transcript = "Unavailable"

        loader = YoutubeLoader.from_youtube_url(url, add_video_info=True, language=languages)
        doc = loader.load()
        output = []
        if not len(doc):
            raise ValueError(f"No data found for url: {url}")
        content = doc[0].page_content
        content = clean_string(content)
        metadata = doc[0].metadata
        metadata["url"] = url
        metadata["transcript"] = transcript

        output.append(
            {
                "content": content,
                "meta_data": metadata,
            }
        )
        doc_id = hashlib.sha256((content + url).encode()).hexdigest()
        return {
            "doc_id": doc_id,
            "data": output,
        }
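A usage sketch assuming youtube-transcript-api and pytube are installed; the metadata of the single returned document carries the raw transcript JSON alongside the video info. The URL is a placeholder.

from embedchain.loaders.youtube_video import YoutubeVideoLoader  # path assumed from this commit's layout

loader = YoutubeVideoLoader()
result = loader.load_data("https://www.youtube.com/watch?v=dQw4w9WgXcQ")  # placeholder video URL
meta = result["data"][0]["meta_data"]
print(meta["url"], len(meta["transcript"]))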