Rename embedchain to mem0 and open source code for long-term memory (#1474)

Co-authored-by: Deshraj Yadav <deshrajdry@gmail.com>
Taranjeet Singh authored on 2024-07-12 07:51:33 -07:00; committed by GitHub
parent 83e8c97295
commit f842a92e25
665 changed files with 9427 additions and 6592 deletions

View File

@@ -0,0 +1,53 @@
import hashlib
import os
import validators
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
try:
from deepgram import DeepgramClient, PrerecordedOptions
except ImportError:
raise ImportError(
"Audio file requires extra dependencies. Install with `pip install deepgram-sdk==3.2.7`"
) from None
@register_deserializable
class AudioLoader(BaseLoader):
def __init__(self):
api_key = os.environ.get("DEEPGRAM_API_KEY")
if not api_key:
raise ValueError("DEEPGRAM_API_KEY is not set")
self.client = DeepgramClient(api_key)
def load_data(self, url: str):
"""Load data from a audio file or URL."""
options = PrerecordedOptions(
model="nova-2",
smart_format=True,
)
if validators.url(url):
source = {"url": url}
response = self.client.listen.prerecorded.v("1").transcribe_url(source, options)
else:
with open(url, "rb") as audio:
source = {"buffer": audio}
response = self.client.listen.prerecorded.v("1").transcribe_file(source, options)
content = response["results"]["channels"][0]["alternatives"][0]["transcript"]
doc_id = hashlib.sha256((content + url).encode()).hexdigest()
metadata = {"url": url}
return {
"doc_id": doc_id,
"data": [
{
"content": content,
"meta_data": metadata,
}
],
}
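
A minimal usage sketch for the loader above. The module path (`embedchain.loaders.audio`), the audio URL, and the placeholder key are illustrative assumptions, and the pinned deepgram-sdk must be installed:

import os

from embedchain.loaders.audio import AudioLoader  # assumed module path

os.environ.setdefault("DEEPGRAM_API_KEY", "dg-your-key")  # placeholder key
loader = AudioLoader()
result = loader.load_data("https://example.com/podcast-episode.mp3")  # remote URL or local file path
print(result["doc_id"])
print(result["data"][0]["content"][:200])  # first 200 characters of the transcript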

View File

@@ -0,0 +1,12 @@
from embedchain.helpers.json_serializable import JSONSerializable
class BaseLoader(JSONSerializable):
def __init__(self):
pass
def load_data(self, url):
"""
Implemented by child classes
"""
pass

View File

@@ -0,0 +1,107 @@
import hashlib
import logging
import time
from xml.etree import ElementTree
import requests
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import is_readable
logger = logging.getLogger(__name__)
@register_deserializable
class BeehiivLoader(BaseLoader):
"""
This loader is used to load data from Beehiiv URLs.
"""
def load_data(self, url: str):
try:
from bs4 import BeautifulSoup
from bs4.builder import ParserRejectedMarkup
except ImportError:
raise ImportError(
"Beehiiv requires extra dependencies. Install with `pip install beautifulsoup4==4.12.3`"
) from None
if not url.endswith("sitemap.xml"):
url = url + "/sitemap.xml"
output = []
# we need to set this as a header to avoid 403
headers = {
"User-Agent": (
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) "
"AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 "
"Safari/537.36"
),
}
response = requests.get(url, headers=headers)
try:
response.raise_for_status()
except requests.exceptions.HTTPError as e:
raise ValueError(
f"""
Failed to load {url}: {e}. Please use the root Beehiiv URL. For example, https://example.beehiiv.com
"""
)
try:
ElementTree.fromstring(response.content)
except ElementTree.ParseError:
raise ValueError(
f"""
Failed to parse {url}. Please use the root Beehiiv URL. For example, https://example.beehiiv.com
"""
)
soup = BeautifulSoup(response.text, "xml")
links = [link.text for link in soup.find_all("loc") if link.parent.name == "url" and "/p/" in link.text]
if len(links) == 0:
links = [link.text for link in soup.find_all("loc") if "/p/" in link.text]
doc_id = hashlib.sha256((" ".join(links) + url).encode()).hexdigest()
def serialize_response(soup: BeautifulSoup):
data = {}
h1_el = soup.find("h1")
if h1_el is not None:
data["title"] = h1_el.text
description_el = soup.find("meta", {"name": "description"})
if description_el is not None:
data["description"] = description_el["content"]
content_el = soup.find("div", {"id": "content-blocks"})
if content_el is not None:
data["content"] = content_el.text
return data
def load_link(link: str):
try:
beehiiv_data = requests.get(link, headers=headers)
beehiiv_data.raise_for_status()
soup = BeautifulSoup(beehiiv_data.text, "html.parser")
data = serialize_response(soup)
data = str(data)
if is_readable(data):
return data
else:
logger.warning(f"Page is not readable (too many invalid characters): {link}")
except ParserRejectedMarkup as e:
logger.error(f"Failed to parse {link}: {e}")
return None
for link in links:
data = load_link(link)
if data:
output.append({"content": data, "meta_data": {"url": link}})
# TODO: allow users to configure this
time.sleep(1.0) # added to avoid rate limiting
return {"doc_id": doc_id, "data": output}

View File

@@ -0,0 +1,49 @@
import csv
import hashlib
from io import StringIO
from urllib.parse import urlparse
import requests
from embedchain.loaders.base_loader import BaseLoader
class CsvLoader(BaseLoader):
@staticmethod
def _detect_delimiter(first_line):
delimiters = [",", "\t", ";", "|"]
counts = {delimiter: first_line.count(delimiter) for delimiter in delimiters}
return max(counts, key=counts.get)
@staticmethod
def _get_file_content(content):
url = urlparse(content)
if all([url.scheme, url.netloc]) and url.scheme not in ["file", "http", "https"]:
raise ValueError("Not a valid URL.")
if url.scheme in ["http", "https"]:
response = requests.get(content)
response.raise_for_status()
return StringIO(response.text)
elif url.scheme == "file":
path = url.path
return open(path, newline="", encoding="utf-8") # Open the file using the path from the URI
else:
return open(content, newline="", encoding="utf-8") # Treat content as a regular file path
@staticmethod
def load_data(content):
"""Load a csv file with headers. Each line is a document"""
result = []
lines = []
with CsvLoader._get_file_content(content) as file:
first_line = file.readline()
delimiter = CsvLoader._detect_delimiter(first_line)
file.seek(0) # Reset the file pointer to the start
reader = csv.DictReader(file, delimiter=delimiter)
for i, row in enumerate(reader):
line = ", ".join([f"{field}: {value}" for field, value in row.items()])
lines.append(line)
result.append({"content": line, "meta_data": {"url": content, "row": i + 1}})
doc_id = hashlib.sha256((content + " ".join(lines)).encode()).hexdigest()
return {"doc_id": doc_id, "data": result}

View File

@@ -0,0 +1,63 @@
import hashlib
import logging
from pathlib import Path
from typing import Any, Optional
from embedchain.config import AddConfig
from embedchain.data_formatter.data_formatter import DataFormatter
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.loaders.text_file import TextFileLoader
from embedchain.utils.misc import detect_datatype
logger = logging.getLogger(__name__)
@register_deserializable
class DirectoryLoader(BaseLoader):
"""Load data from a directory."""
def __init__(self, config: Optional[dict[str, Any]] = None):
super().__init__()
config = config or {}
self.recursive = config.get("recursive", True)
self.extensions = config.get("extensions", None)
self.errors = []
def load_data(self, path: str):
directory_path = Path(path)
if not directory_path.is_dir():
raise ValueError(f"Invalid path: {path}")
logger.info(f"Loading data from directory: {path}")
data_list = self._process_directory(directory_path)
doc_id = hashlib.sha256((str(data_list) + str(directory_path)).encode()).hexdigest()
for error in self.errors:
logger.warning(error)
return {"doc_id": doc_id, "data": data_list}
def _process_directory(self, directory_path: Path):
data_list = []
for file_path in directory_path.rglob("*") if self.recursive else directory_path.glob("*"):
# don't include dotfiles
if file_path.name.startswith("."):
continue
if file_path.is_file() and (not self.extensions or any(file_path.suffix == ext for ext in self.extensions)):
loader = self._predict_loader(file_path)
data_list.extend(loader.load_data(str(file_path))["data"])
elif file_path.is_dir():
logger.info(f"Loading data from directory: {file_path}")
return data_list
def _predict_loader(self, file_path: Path) -> BaseLoader:
try:
data_type = detect_datatype(str(file_path))
config = AddConfig()
return DataFormatter(data_type=data_type, config=config)._get_loader(
data_type=data_type, config=config.loader, loader=None
)
except Exception as e:
self.errors.append(f"Error processing {file_path}: {e}")
return TextFileLoader()
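
A usage sketch showing the config keys the constructor reads (`recursive`, `extensions`); the module path and directory are illustrative assumptions:

from embedchain.loaders.directory_loader import DirectoryLoader  # assumed module path

loader = DirectoryLoader(config={"recursive": True, "extensions": [".md", ".txt"]})
result = loader.load_data("./docs")
print(len(result["data"]), "chunks loaded")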

View File

@@ -0,0 +1,152 @@
import hashlib
import logging
import os
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
logger = logging.getLogger(__name__)
@register_deserializable
class DiscordLoader(BaseLoader):
"""
Load data from a Discord Channel ID.
"""
def __init__(self):
if not os.environ.get("DISCORD_TOKEN"):
raise ValueError("DISCORD_TOKEN is not set")
self.token = os.environ.get("DISCORD_TOKEN")
@staticmethod
def _format_message(message):
return {
"message_id": message.id,
"content": message.content,
"author": {
"id": message.author.id,
"name": message.author.name,
"discriminator": message.author.discriminator,
},
"created_at": message.created_at.isoformat(),
"attachments": [
{
"id": attachment.id,
"filename": attachment.filename,
"size": attachment.size,
"url": attachment.url,
"proxy_url": attachment.proxy_url,
"height": attachment.height,
"width": attachment.width,
}
for attachment in message.attachments
],
"embeds": [
{
"title": embed.title,
"type": embed.type,
"description": embed.description,
"url": embed.url,
"timestamp": embed.timestamp.isoformat(),
"color": embed.color,
"footer": {
"text": embed.footer.text,
"icon_url": embed.footer.icon_url,
"proxy_icon_url": embed.footer.proxy_icon_url,
},
"image": {
"url": embed.image.url,
"proxy_url": embed.image.proxy_url,
"height": embed.image.height,
"width": embed.image.width,
},
"thumbnail": {
"url": embed.thumbnail.url,
"proxy_url": embed.thumbnail.proxy_url,
"height": embed.thumbnail.height,
"width": embed.thumbnail.width,
},
"video": {
"url": embed.video.url,
"height": embed.video.height,
"width": embed.video.width,
},
"provider": {
"name": embed.provider.name,
"url": embed.provider.url,
},
"author": {
"name": embed.author.name,
"url": embed.author.url,
"icon_url": embed.author.icon_url,
"proxy_icon_url": embed.author.proxy_icon_url,
},
"fields": [
{
"name": field.name,
"value": field.value,
"inline": field.inline,
}
for field in embed.fields
],
}
for embed in message.embeds
],
}
def load_data(self, channel_id: str):
"""Load data from a Discord Channel ID."""
import discord
messages = []
class DiscordClient(discord.Client):
async def on_ready(self) -> None:
logger.info("Logged on as {0}!".format(self.user))
try:
channel = self.get_channel(int(channel_id))
if not isinstance(channel, discord.TextChannel):
raise ValueError(
f"Channel {channel_id} is not a text channel. " "Only text channels are supported for now."
)
threads = {}
for thread in channel.threads:
threads[thread.id] = thread
async for message in channel.history(limit=None):
messages.append(DiscordLoader._format_message(message))
if message.id in threads:
async for thread_message in threads[message.id].history(limit=None):
messages.append(DiscordLoader._format_message(thread_message))
except Exception as e:
logger.error(e)
await self.close()
finally:
await self.close()
intents = discord.Intents.default()
intents.message_content = True
client = DiscordClient(intents=intents)
client.run(self.token)
metadata = {
"url": channel_id,
}
messages = str(messages)
doc_id = hashlib.sha256((messages + channel_id).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": [
{
"content": messages,
"meta_data": metadata,
}
],
}
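
A usage sketch for the Discord loader, assuming discord.py is installed and the bot token has the message-content intent enabled; the module path, token, and channel ID are placeholders:

import os

from embedchain.loaders.discord import DiscordLoader  # assumed module path

os.environ["DISCORD_TOKEN"] = "discord-bot-token"  # placeholder; must be set before constructing the loader
loader = DiscordLoader()
result = loader.load_data("123456789012345678")  # text channel ID as a string
print(result["doc_id"])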

View File

@@ -0,0 +1,79 @@
import hashlib
import logging
import time
from typing import Any, Optional
import requests
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string
logger = logging.getLogger(__name__)
class DiscourseLoader(BaseLoader):
def __init__(self, config: Optional[dict[str, Any]] = None):
super().__init__()
if not config:
raise ValueError(
"DiscourseLoader requires a config. Check the documentation for the correct format - `https://docs.embedchain.ai/components/data-sources/discourse`" # noqa: E501
)
self.domain = config.get("domain")
if not self.domain:
raise ValueError(
"DiscourseLoader requires a domain. Check the documentation for the correct format - `https://docs.embedchain.ai/components/data-sources/discourse`" # noqa: E501
)
def _check_query(self, query):
if not query or not isinstance(query, str):
raise ValueError(
"DiscourseLoader requires a query. Check the documentation for the correct format - `https://docs.embedchain.ai/components/data-sources/discourse`" # noqa: E501
)
def _load_post(self, post_id):
post_url = f"{self.domain}posts/{post_id}.json"
response = requests.get(post_url)
try:
response.raise_for_status()
except Exception as e:
logger.error(f"Failed to load post {post_id}: {e}")
return
response_data = response.json()
post_contents = clean_string(response_data.get("raw"))
metadata = {
"url": post_url,
"created_at": response_data.get("created_at", ""),
"username": response_data.get("username", ""),
"topic_slug": response_data.get("topic_slug", ""),
"score": response_data.get("score", ""),
}
data = {
"content": post_contents,
"meta_data": metadata,
}
return data
def load_data(self, query):
self._check_query(query)
data = []
data_contents = []
logger.info(f"Searching data on discourse url: {self.domain}, for query: {query}")
search_url = f"{self.domain}search.json?q={query}"
response = requests.get(search_url)
try:
response.raise_for_status()
except Exception as e:
raise ValueError(f"Failed to search query {query}: {e}")
response_data = response.json()
post_ids = response_data.get("grouped_search_result").get("post_ids")
for id in post_ids:
post_data = self._load_post(id)
if post_data:
data.append(post_data)
data_contents.append(post_data.get("content"))
# Sleep for 0.4 sec, to avoid rate limiting. Check `https://meta.discourse.org/t/api-rate-limits/208405/6`
time.sleep(0.4)
doc_id = hashlib.sha256((query + ", ".join(data_contents)).encode()).hexdigest()
response_data = {"doc_id": doc_id, "data": data}
return response_data
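
A usage sketch for the Discourse loader; the module path and forum domain are illustrative assumptions. Note the trailing slash on the domain, since the loader concatenates `search.json` and `posts/<id>.json` directly:

from embedchain.loaders.discourse import DiscourseLoader  # assumed module path

loader = DiscourseLoader(config={"domain": "https://meta.discourse.org/"})
result = loader.load_data("rate limits")
print(len(result["data"]), "posts loaded")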

View File

@@ -0,0 +1,119 @@
import hashlib
import logging
from urllib.parse import urljoin, urlparse
import requests
try:
from bs4 import BeautifulSoup
except ImportError:
raise ImportError(
"DocsSite requires extra dependencies. Install with `pip install beautifulsoup4==4.12.3`"
) from None
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
logger = logging.getLogger(__name__)
@register_deserializable
class DocsSiteLoader(BaseLoader):
def __init__(self):
self.visited_links = set()
def _get_child_links_recursive(self, url):
if url in self.visited_links:
return
parsed_url = urlparse(url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
current_path = parsed_url.path
response = requests.get(url)
if response.status_code != 200:
logger.info(f"Failed to fetch the website: {response.status_code}")
return
soup = BeautifulSoup(response.text, "html.parser")
all_links = (link.get("href") for link in soup.find_all("a", href=True))
child_links = (link for link in all_links if link.startswith(current_path) and link != current_path)
absolute_paths = set(urljoin(base_url, link) for link in child_links)
self.visited_links.update(absolute_paths)
[self._get_child_links_recursive(link) for link in absolute_paths if link not in self.visited_links]
def _get_all_urls(self, url):
self.visited_links = set()
self._get_child_links_recursive(url)
urls = [link for link in self.visited_links if urlparse(link).netloc == urlparse(url).netloc]
return urls
@staticmethod
def _load_data_from_url(url: str) -> list:
response = requests.get(url)
if response.status_code != 200:
logger.info(f"Failed to fetch the website: {response.status_code}")
return []
soup = BeautifulSoup(response.content, "html.parser")
selectors = [
"article.bd-article",
'article[role="main"]',
"div.md-content",
'div[role="main"]',
"div.container",
"div.section",
"article",
"main",
]
output = []
for selector in selectors:
element = soup.select_one(selector)
if element:
content = element.prettify()
break
else:
content = soup.get_text()
soup = BeautifulSoup(content, "html.parser")
ignored_tags = [
"nav",
"aside",
"form",
"header",
"noscript",
"svg",
"canvas",
"footer",
"script",
"style",
]
for tag in soup(ignored_tags):
tag.decompose()
content = " ".join(soup.stripped_strings)
output.append(
{
"content": content,
"meta_data": {"url": url},
}
)
return output
def load_data(self, url):
all_urls = self._get_all_urls(url)
output = []
for u in all_urls:
output.extend(self._load_data_from_url(u))
doc_id = hashlib.sha256((" ".join(all_urls) + url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": output,
}

View File

@@ -0,0 +1,26 @@
import hashlib
try:
from langchain_community.document_loaders import Docx2txtLoader
except ImportError:
raise ImportError("Docx file requires extra dependencies. Install with `pip install docx2txt==0.8`") from None
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
@register_deserializable
class DocxFileLoader(BaseLoader):
def load_data(self, url):
"""Load data from a .docx file."""
loader = Docx2txtLoader(url)
output = []
data = loader.load()
content = data[0].page_content
metadata = data[0].metadata
metadata["url"] = "local"
output.append({"content": content, "meta_data": metadata})
doc_id = hashlib.sha256((content + url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": output,
}

View File

@@ -0,0 +1,79 @@
import hashlib
import os
from dropbox.files import FileMetadata
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.loaders.directory_loader import DirectoryLoader
@register_deserializable
class DropboxLoader(BaseLoader):
def __init__(self):
access_token = os.environ.get("DROPBOX_ACCESS_TOKEN")
if not access_token:
raise ValueError("Please set the `DROPBOX_ACCESS_TOKEN` environment variable.")
try:
from dropbox import Dropbox, exceptions
except ImportError:
raise ImportError("Dropbox requires extra dependencies. Install with `pip install dropbox==11.36.2`")
try:
dbx = Dropbox(access_token)
dbx.users_get_current_account()
self.dbx = dbx
except exceptions.AuthError as ex:
raise ValueError("Invalid Dropbox access token. Please verify your token and try again.") from ex
def _download_folder(self, path: str, local_root: str) -> list[FileMetadata]:
"""Download a folder from Dropbox and save it preserving the directory structure."""
entries = self.dbx.files_list_folder(path).entries
for entry in entries:
local_path = os.path.join(local_root, entry.name)
if isinstance(entry, FileMetadata):
self.dbx.files_download_to_file(local_path, f"{path}/{entry.name}")
else:
os.makedirs(local_path, exist_ok=True)
self._download_folder(f"{path}/{entry.name}", local_path)
return entries
def _generate_dir_id_from_all_paths(self, path: str) -> str:
"""Generate a unique ID for a directory based on all of its paths."""
entries = self.dbx.files_list_folder(path).entries
paths = [f"{path}/{entry.name}" for entry in entries]
return hashlib.sha256("".join(paths).encode()).hexdigest()
def load_data(self, path: str):
"""Load data from a Dropbox URL, preserving the folder structure."""
root_dir = f"dropbox_{self._generate_dir_id_from_all_paths(path)}"
os.makedirs(root_dir, exist_ok=True)
for entry in self.dbx.files_list_folder(path).entries:
local_path = os.path.join(root_dir, entry.name)
if isinstance(entry, FileMetadata):
self.dbx.files_download_to_file(local_path, f"{path}/{entry.name}")
else:
os.makedirs(local_path, exist_ok=True)
self._download_folder(f"{path}/{entry.name}", local_path)
dir_loader = DirectoryLoader()
data = dir_loader.load_data(root_dir)["data"]
# Clean up
self._clean_directory(root_dir)
return {
"doc_id": hashlib.sha256(path.encode()).hexdigest(),
"data": data,
}
def _clean_directory(self, dir_path):
"""Recursively delete a directory and its contents."""
for item in os.listdir(dir_path):
item_path = os.path.join(dir_path, item)
if os.path.isdir(item_path):
self._clean_directory(item_path)
else:
os.remove(item_path)
os.rmdir(dir_path)

View File

@@ -0,0 +1,41 @@
import hashlib
import importlib.util
try:
import unstructured # noqa: F401
from langchain_community.document_loaders import UnstructuredExcelLoader
except ImportError:
raise ImportError(
'Excel file requires extra dependencies. Install with `pip install "unstructured[local-inference, all-docs]"`'
) from None
if importlib.util.find_spec("openpyxl") is None and importlib.util.find_spec("xlrd") is None:
raise ImportError("Excel file requires extra dependencies. Install with `pip install openpyxl xlrd`") from None
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string
@register_deserializable
class ExcelFileLoader(BaseLoader):
def load_data(self, excel_url):
"""Load data from a Excel file."""
loader = UnstructuredExcelLoader(excel_url)
pages = loader.load_and_split()
data = []
for page in pages:
content = page.page_content
content = clean_string(content)
metadata = page.metadata
metadata["url"] = excel_url
data.append({"content": content, "meta_data": metadata})
doc_id = hashlib.sha256((content + excel_url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": data,
}

View File

@@ -0,0 +1,312 @@
import concurrent.futures
import hashlib
import logging
import re
import shlex
from typing import Any, Optional
from tqdm import tqdm
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string
GITHUB_URL = "https://github.com"
GITHUB_API_URL = "https://api.github.com"
VALID_SEARCH_TYPES = set(["code", "repo", "pr", "issue", "discussion", "branch", "file"])
class GithubLoader(BaseLoader):
"""Load data from GitHub search query."""
def __init__(self, config: Optional[dict[str, Any]] = None):
super().__init__()
if not config:
raise ValueError(
"GithubLoader requires a personal access token to use github api. Check - `https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens#creating-a-personal-access-token-classic`" # noqa: E501
)
try:
from github import Github
except ImportError as e:
raise ValueError(
"GithubLoader requires extra dependencies. \
Install with `pip install gitpython==3.1.38 PyGithub==1.59.1`"
) from e
self.config = config
token = config.get("token")
if not token:
raise ValueError(
"GithubLoader requires a personal access token to use github api. Check - `https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens#creating-a-personal-access-token-classic`" # noqa: E501
)
try:
self.client = Github(token)
except Exception as e:
logging.error(f"GithubLoader failed to initialize client: {e}")
self.client = None
def _github_search_code(self, query: str):
"""Search GitHub code."""
data = []
results = self.client.search_code(query)
for result in tqdm(results, total=results.totalCount, desc="Loading code files from github"):
url = result.html_url
logging.info(f"Added data from url: {url}")
content = result.decoded_content.decode("utf-8")
metadata = {
"url": url,
}
data.append(
{
"content": clean_string(content),
"meta_data": metadata,
}
)
return data
def _get_github_repo_data(self, repo_name: str, branch_name: str = None, file_path: str = None) -> list[dict]:
"""Get file contents from Repo"""
data = []
repo = self.client.get_repo(repo_name)
repo_contents = repo.get_contents("")
if branch_name:
repo_contents = repo.get_contents("", ref=branch_name)
if file_path:
repo_contents = [repo.get_contents(file_path)]
with tqdm(desc="Loading files:", unit="item") as progress_bar:
while repo_contents:
file_content = repo_contents.pop(0)
if file_content.type == "dir":
try:
repo_contents.extend(repo.get_contents(file_content.path))
except Exception:
logging.warning(f"Failed to read directory: {file_content.path}")
progress_bar.update(1)
continue
else:
try:
file_text = file_content.decoded_content.decode()
except Exception:
logging.warning(f"Failed to read file: {file_content.path}")
progress_bar.update(1)
continue
file_path = file_content.path
data.append(
{
"content": clean_string(file_text),
"meta_data": {
"path": file_path,
},
}
)
progress_bar.update(1)
return data
def _github_search_repo(self, query: str) -> list[dict]:
"""Search GitHub repo."""
logging.info(f"Searching github repos with query: {query}")
updated_query = query.split(":")[-1]
data = self._get_github_repo_data(updated_query)
return data
def _github_search_issues_and_pr(self, query: str, type: str) -> list[dict]:
"""Search GitHub issues and PRs."""
data = []
query = f"{query} is:{type}"
logging.info(f"Searching github for query: {query}")
results = self.client.search_issues(query)
logging.info(f"Total results: {results.totalCount}")
for result in tqdm(results, total=results.totalCount, desc=f"Loading {type} from github"):
url = result.html_url
title = result.title
body = result.body
if not body:
logging.warning(f"Skipping issue because empty content for: {url}")
continue
labels = " ".join([label.name for label in result.labels])
issue_comments = result.get_comments()
comments = []
comments_created_at = []
for comment in issue_comments:
comments_created_at.append(str(comment.created_at))
comments.append(f"{comment.user.name}:{comment.body}")
content = "\n".join([title, labels, body, *comments])
metadata = {
"url": url,
"created_at": str(result.created_at),
"comments_created_at": " ".join(comments_created_at),
}
data.append(
{
"content": clean_string(content),
"meta_data": metadata,
}
)
return data
# need to test more for discussion
def _github_search_discussions(self, query: str):
"""Search GitHub discussions."""
data = []
query = f"{query} is:discussion"
logging.info(f"Searching github repo for query: {query}")
repos_results = self.client.search_repositories(query)
logging.info(f"Total repos found: {repos_results.totalCount}")
for repo_result in tqdm(repos_results, total=repos_results.totalCount, desc="Loading discussions from github"):
teams = repo_result.get_teams()
for team in teams:
team_discussions = team.get_discussions()
for discussion in team_discussions:
url = discussion.html_url
title = discussion.title
body = discussion.body
if not body:
logging.warning(f"Skipping discussion because empty content for: {url}")
continue
comments = []
comments_created_at = []
print("Discussion comments: ", discussion.comments_url)
content = "\n".join([title, body, *comments])
metadata = {
"url": url,
"created_at": str(discussion.created_at),
"comments_created_at": " ".join(comments_created_at),
}
data.append(
{
"content": clean_string(content),
"meta_data": metadata,
}
)
return data
def _get_github_repo_branch(self, query: str, type: str) -> list[dict]:
"""Get file contents for specific branch"""
logging.info(f"Searching github repo for query: {query} is:{type}")
pattern = r"repo:(\S+) name:(\S+)"
match = re.search(pattern, query)
if match:
repo_name = match.group(1)
branch_name = match.group(2)
else:
raise ValueError(
f"Repository and branch name not found in query: `{query}`. "
"Expected format: `repo:<owner>/<repo> name:<branch>`"
)
data = self._get_github_repo_data(repo_name=repo_name, branch_name=branch_name)
return data
def _get_github_repo_file(self, query: str, type: str) -> list[dict]:
"""Get specific file content"""
logging.info(f"Searching github repo for query: {query} is:{type}")
pattern = r"repo:(\S+) path:(\S+)"
match = re.search(pattern, query)
if match:
repo_name = match.group(1)
file_path = match.group(2)
else:
raise ValueError(
f"Repository and file path not found in query: `{query}`. "
"Expected format: `repo:<owner>/<repo> path:<file-path>`"
)
data = self._get_github_repo_data(repo_name=repo_name, file_path=file_path)
return data
def _search_github_data(self, search_type: str, query: str):
"""Search github data."""
if search_type == "code":
data = self._github_search_code(query)
elif search_type == "repo":
data = self._github_search_repo(query)
elif search_type == "issue":
data = self._github_search_issues_and_pr(query, search_type)
elif search_type == "pr":
data = self._github_search_issues_and_pr(query, search_type)
elif search_type == "branch":
data = self._get_github_repo_branch(query, search_type)
elif search_type == "file":
data = self._get_github_repo_file(query, search_type)
elif search_type == "discussion":
raise ValueError("GithubLoader does not support searching discussions yet.")
else:
raise NotImplementedError(f"{search_type} not supported")
return data
@staticmethod
def _get_valid_github_query(query: str):
"""Check if query is valid and return search types and valid GitHub query."""
query_terms = shlex.split(query)
# query must provide repo to load data from
if len(query_terms) < 1 or "repo:" not in query:
raise ValueError(
"GithubLoader requires a search query with `repo:` term. Refer docs - `https://docs.embedchain.ai/data-sources/github`" # noqa: E501
)
github_query = []
types = set()
type_pattern = r"type:([a-zA-Z,]+)"
for term in query_terms:
term_match = re.search(type_pattern, term)
if term_match:
search_types = term_match.group(1).split(",")
types.update(search_types)
else:
github_query.append(term)
# query must provide search type
if len(types) == 0:
raise ValueError(
"GithubLoader requires a search query with `type:` term. Refer docs - `https://docs.embedchain.ai/data-sources/github`" # noqa: E501
)
for search_type in types:
if search_type not in VALID_SEARCH_TYPES:
raise ValueError(
f"Invalid search type: {search_type}. Valid types are: {', '.join(VALID_SEARCH_TYPES)}"
)
query = " ".join(github_query)
return types, query
def load_data(self, search_query: str, max_results: int = 1000):
"""Load data from GitHub search query."""
if not self.client:
raise ValueError(
"GithubLoader client is not initialized, data will not be loaded. Refer docs - `https://docs.embedchain.ai/data-sources/github`" # noqa: E501
)
search_types, query = self._get_valid_github_query(search_query)
logging.info(f"Searching github for query: {query}, with types: {', '.join(search_types)}")
data = []
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
futures_map = executor.map(self._search_github_data, search_types, [query] * len(search_types))
for search_data in tqdm(futures_map, total=len(search_types), desc="Searching data from github"):
data.extend(search_data)
return {
"doc_id": hashlib.sha256(query.encode()).hexdigest(),
"data": data,
}
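
A usage sketch showing the query grammar the loader expects (a `repo:` term plus a `type:` term, with comma-separated types allowed); the module path, env var name, and repository are illustrative assumptions:

import os

from embedchain.loaders.github import GithubLoader  # assumed module path

loader = GithubLoader(config={"token": os.environ["GITHUB_TOKEN"]})  # personal access token, env var name assumed

# Load all files of a repository:
result = loader.load_data("repo:embedchain/embedchain type:repo")
print(len(result["data"]), "files loaded")

# Issues and PRs use the same search syntax:
result = loader.load_data("repo:embedchain/embedchain type:issue,pr label:bug")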

View File

@@ -0,0 +1,144 @@
import base64
import hashlib
import logging
import os
from email import message_from_bytes
from email.utils import parsedate_to_datetime
from textwrap import dedent
from typing import Optional
from bs4 import BeautifulSoup
try:
from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import InstalledAppFlow
from googleapiclient.discovery import build
except ImportError:
raise ImportError(
'Gmail requires extra dependencies. Install with `pip install --upgrade "embedchain[gmail]"`'
) from None
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string
logger = logging.getLogger(__name__)
class GmailReader:
SCOPES = ["https://www.googleapis.com/auth/gmail.readonly"]
def __init__(self, query: str, service=None, results_per_page: int = 10):
self.query = query
self.service = service or self._initialize_service()
self.results_per_page = results_per_page
@staticmethod
def _initialize_service():
credentials = GmailReader._get_credentials()
return build("gmail", "v1", credentials=credentials)
@staticmethod
def _get_credentials():
if not os.path.exists("credentials.json"):
raise FileNotFoundError("Missing 'credentials.json'. Download it from your Google Developer account.")
creds = (
Credentials.from_authorized_user_file("token.json", GmailReader.SCOPES)
if os.path.exists("token.json")
else None
)
if not creds or not creds.valid:
if creds and creds.expired and creds.refresh_token:
creds.refresh(Request())
else:
flow = InstalledAppFlow.from_client_secrets_file("credentials.json", GmailReader.SCOPES)
creds = flow.run_local_server(port=8080)
with open("token.json", "w") as token:
token.write(creds.to_json())
return creds
def load_emails(self) -> list[dict]:
response = self.service.users().messages().list(userId="me", q=self.query).execute()
messages = response.get("messages", [])
return [self._parse_email(self._get_email(message["id"])) for message in messages]
def _get_email(self, message_id: str):
raw_message = self.service.users().messages().get(userId="me", id=message_id, format="raw").execute()
return base64.urlsafe_b64decode(raw_message["raw"])
def _parse_email(self, raw_email) -> dict:
mime_msg = message_from_bytes(raw_email)
return {
"subject": self._get_header(mime_msg, "Subject"),
"from": self._get_header(mime_msg, "From"),
"to": self._get_header(mime_msg, "To"),
"date": self._format_date(mime_msg),
"body": self._get_body(mime_msg),
}
@staticmethod
def _get_header(mime_msg, header_name: str) -> str:
return mime_msg.get(header_name, "")
@staticmethod
def _format_date(mime_msg) -> Optional[str]:
date_header = GmailReader._get_header(mime_msg, "Date")
return parsedate_to_datetime(date_header).isoformat() if date_header else None
@staticmethod
def _get_body(mime_msg) -> str:
def decode_payload(part):
charset = part.get_content_charset() or "utf-8"
try:
return part.get_payload(decode=True).decode(charset)
except UnicodeDecodeError:
return part.get_payload(decode=True).decode(charset, errors="replace")
if mime_msg.is_multipart():
for part in mime_msg.walk():
ctype = part.get_content_type()
cdispo = str(part.get("Content-Disposition"))
if ctype == "text/plain" and "attachment" not in cdispo:
return decode_payload(part)
elif ctype == "text/html":
return decode_payload(part)
else:
return decode_payload(mime_msg)
return ""
class GmailLoader(BaseLoader):
def load_data(self, query: str):
reader = GmailReader(query=query)
emails = reader.load_emails()
logger.info(f"Gmail Loader: {len(emails)} emails found for query '{query}'")
data = []
for email in emails:
content = self._process_email(email)
data.append({"content": content, "meta_data": email})
return {"doc_id": self._generate_doc_id(query, data), "data": data}
@staticmethod
def _process_email(email: dict) -> str:
content = BeautifulSoup(email["body"], "html.parser").get_text()
content = clean_string(content)
return dedent(
f"""
Email from '{email['from']}' to '{email['to']}'
Subject: {email['subject']}
Date: {email['date']}
Content: {content}
"""
)
@staticmethod
def _generate_doc_id(query: str, data: list[dict]) -> str:
content_strings = [email["content"] for email in data]
return hashlib.sha256((query + ", ".join(content_strings)).encode()).hexdigest()
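
A usage sketch for the Gmail loader; the module path and query are illustrative, and `credentials.json` must be present as described above:

from embedchain.loaders.gmail import GmailLoader  # assumed module path

# Requires credentials.json in the working directory; the first run opens a
# browser for the OAuth consent flow and caches the token in token.json.
loader = GmailLoader()
result = loader.load_data("from:newsletter@example.com after:2024/01/01")  # Gmail search syntax
print(len(result["data"]), "emails loaded")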

View File

@@ -0,0 +1,62 @@
import hashlib
import re
try:
from googleapiclient.errors import HttpError
except ImportError:
raise ImportError(
"Google Drive requires extra dependencies. Install with `pip install embedchain[googledrive]`"
) from None
from langchain_community.document_loaders import GoogleDriveLoader as Loader
try:
import unstructured # noqa: F401
from langchain_community.document_loaders import UnstructuredFileIOLoader
except ImportError:
raise ImportError(
'Unstructured file requires extra dependencies. Install with `pip install "unstructured[local-inference, all-docs]"`' # noqa: E501
) from None
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
@register_deserializable
class GoogleDriveLoader(BaseLoader):
@staticmethod
def _get_drive_id_from_url(url: str):
regex = r"^https:\/\/drive\.google\.com\/drive\/(?:u\/\d+\/)folders\/([a-zA-Z0-9_-]+)$"
if re.match(regex, url):
return url.split("/")[-1]
raise ValueError(
f"The url provided {url} does not match a google drive folder url. Example drive url: "
f"https://drive.google.com/drive/u/0/folders/xxxx"
)
def load_data(self, url: str):
"""Load data from a Google drive folder."""
folder_id: str = self._get_drive_id_from_url(url)
try:
loader = Loader(
folder_id=folder_id,
recursive=True,
file_loader_cls=UnstructuredFileIOLoader,
)
data = []
all_content = []
docs = loader.load()
for doc in docs:
all_content.append(doc.page_content)
# renames source to url for later use.
doc.metadata["url"] = doc.metadata.pop("source")
data.append({"content": doc.page_content, "meta_data": doc.metadata})
doc_id = hashlib.sha256((" ".join(all_content) + url).encode()).hexdigest()
return {"doc_id": doc_id, "data": data}
except HttpError:
raise FileNotFoundError("Unable to locate folder or files, check provided drive URL and try again")

View File

@@ -0,0 +1,50 @@
import base64
import hashlib
import os
from pathlib import Path
from openai import OpenAI
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
DESCRIBE_IMAGE_PROMPT = "Describe the image:"
@register_deserializable
class ImageLoader(BaseLoader):
def __init__(self, max_tokens: int = 500, api_key: str = None, prompt: str = None):
super().__init__()
self.custom_prompt = prompt or DESCRIBE_IMAGE_PROMPT
self.max_tokens = max_tokens
self.api_key = api_key or os.environ["OPENAI_API_KEY"]
self.client = OpenAI(api_key=self.api_key)
@staticmethod
def _encode_image(image_path: str):
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
def _create_completion_request(self, content: str):
return self.client.chat.completions.create(
model="gpt-4o", messages=[{"role": "user", "content": content}], max_tokens=self.max_tokens
)
def _process_url(self, url: str):
if url.startswith("http"):
return [{"type": "text", "text": self.custom_prompt}, {"type": "image_url", "image_url": {"url": url}}]
elif Path(url).is_file():
extension = Path(url).suffix.lstrip(".")
encoded_image = self._encode_image(url)
image_data = f"data:image/{extension};base64,{encoded_image}"
return [{"type": "text", "text": self.custom_prompt}, {"type": "image", "image_url": {"url": image_data}}]
else:
raise ValueError(f"Invalid URL or file path: {url}")
def load_data(self, url: str):
content = self._process_url(url)
response = self._create_completion_request(content)
content = response.choices[0].message.content
doc_id = hashlib.sha256((content + url).encode()).hexdigest()
return {"doc_id": doc_id, "data": [{"content": content, "meta_data": {"url": url, "type": "image"}}]}

View File

@@ -0,0 +1,93 @@
import hashlib
import json
import os
import re
from typing import Union
import requests
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string, is_valid_json_string
class JSONReader:
def __init__(self) -> None:
"""Initialize the JSONReader."""
pass
@staticmethod
def load_data(json_data: Union[dict, str]) -> list[str]:
"""Load data from a JSON structure.
Args:
json_data (Union[dict, str]): The JSON data to load.
Returns:
list[str]: A list of strings representing the leaf nodes of the JSON.
"""
if isinstance(json_data, str):
json_data = json.loads(json_data)
json_output = json.dumps(json_data, indent=0)
lines = json_output.split("\n")
useful_lines = [line for line in lines if not re.match(r"^[{}\[\],]*$", line)]
return ["\n".join(useful_lines)]
VALID_URL_PATTERN = (
r"^https?://(?:www\.)?(?:\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|[a-zA-Z0-9.-]+)(?::\d+)?/(?:[^/\s]+/)*[^/\s]+\.json$"
)
class JSONLoader(BaseLoader):
@staticmethod
def _check_content(content):
if not isinstance(content, str):
raise ValueError(
"Invaid content input. \
If you want to upload (list, dict, etc.), do \
`json.dump(data, indent=0)` and add the stringified JSON. \
Check - `https://docs.embedchain.ai/data-sources/json`"
)
@staticmethod
def load_data(content):
"""Load a json file. Each data point is a key value pair."""
JSONLoader._check_content(content)
loader = JSONReader()
data = []
data_content = []
content_url_str = content
if os.path.isfile(content):
with open(content, "r", encoding="utf-8") as json_file:
json_data = json.load(json_file)
elif re.match(VALID_URL_PATTERN, content):
response = requests.get(content)
if response.status_code == 200:
json_data = response.json()
else:
raise ValueError(
f"Loading data from the given url: {content} failed. \
Make sure the url is working."
)
elif is_valid_json_string(content):
json_data = content
content_url_str = hashlib.sha256((content).encode("utf-8")).hexdigest()
else:
raise ValueError(f"Invalid content to load json data from: {content}")
docs = loader.load_data(json_data)
for doc in docs:
text = doc if isinstance(doc, str) else doc["text"]
doc_content = clean_string(text)
data.append({"content": doc_content, "meta_data": {"url": content_url_str}})
data_content.append(doc_content)
doc_id = hashlib.sha256((content_url_str + ", ".join(data_content)).encode()).hexdigest()
return {"doc_id": doc_id, "data": data}

View File

@@ -0,0 +1,24 @@
import hashlib
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
@register_deserializable
class LocalQnaPairLoader(BaseLoader):
def load_data(self, content):
"""Load data from a local QnA pair."""
question, answer = content
content = f"Q: {question}\nA: {answer}"
url = "local"
metadata = {"url": url, "question": question}
doc_id = hashlib.sha256((content + url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": [
{
"content": content,
"meta_data": metadata,
}
],
}

View File

@@ -0,0 +1,24 @@
import hashlib
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
@register_deserializable
class LocalTextLoader(BaseLoader):
def load_data(self, content):
"""Load data from a local text file."""
url = "local"
metadata = {
"url": url,
}
doc_id = hashlib.sha256((content + url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": [
{
"content": content,
"meta_data": metadata,
}
],
}

View File

@@ -0,0 +1,25 @@
import hashlib
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
@register_deserializable
class MdxLoader(BaseLoader):
def load_data(self, url):
"""Load data from a mdx file."""
with open(url, "r", encoding="utf-8") as infile:
content = infile.read()
metadata = {
"url": url,
}
doc_id = hashlib.sha256((content + url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": [
{
"content": content,
"meta_data": metadata,
}
],
}

View File

@@ -0,0 +1,67 @@
import hashlib
import logging
from typing import Any, Optional
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string
logger = logging.getLogger(__name__)
class MySQLLoader(BaseLoader):
def __init__(self, config: Optional[dict[str, Any]]):
super().__init__()
if not config:
raise ValueError(
f"Invalid sql config: {config}.",
"Provide the correct config, refer `https://docs.embedchain.ai/data-sources/mysql`.",
)
self.config = config
self.connection = None
self.cursor = None
self._setup_loader(config=config)
def _setup_loader(self, config: dict[str, Any]):
try:
import mysql.connector as sqlconnector
except ImportError as e:
raise ImportError(
"Unable to import required packages for MySQL loader. Run `pip install --upgrade 'embedchain[mysql]'`." # noqa: E501
) from e
try:
self.connection = sqlconnector.connection.MySQLConnection(**config)
self.cursor = self.connection.cursor()
except (sqlconnector.Error, IOError) as err:
logger.info(f"Connection failed: {err}")
raise ValueError(
f"Unable to connect with the given config: {config}.",
"Please provide the correct configuration to load data from you MySQL DB. \
Refer `https://docs.embedchain.ai/data-sources/mysql`.",
)
@staticmethod
def _check_query(query):
if not isinstance(query, str):
raise ValueError(
f"Invalid mysql query: {query}",
"Provide the valid query to add from mysql, \
make sure you are following `https://docs.embedchain.ai/data-sources/mysql`",
)
def load_data(self, query):
self._check_query(query=query)
data = []
data_content = []
self.cursor.execute(query)
rows = self.cursor.fetchall()
for row in rows:
doc_content = clean_string(str(row))
data.append({"content": doc_content, "meta_data": {"url": query}})
data_content.append(doc_content)
doc_id = hashlib.sha256((query + ", ".join(data_content)).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": data,
}
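
A usage sketch for the MySQL loader; the config dict is passed straight to mysql.connector, and the module path, credentials, and query are illustrative assumptions:

from embedchain.loaders.mysql import MySQLLoader  # assumed module path

config = {  # forwarded to mysql.connector.connection.MySQLConnection(**config)
    "host": "127.0.0.1",
    "port": 3306,
    "user": "reader",
    "password": "secret",
    "database": "app",
}
loader = MySQLLoader(config=config)
result = loader.load_data("SELECT id, title FROM articles LIMIT 10;")
print(len(result["data"]), "rows loaded")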

View File

@@ -0,0 +1,121 @@
import hashlib
import logging
import os
from typing import Any, Optional
import requests
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string
logger = logging.getLogger(__name__)
class NotionDocument:
"""
A simple Document class to hold the text and additional information of a page.
"""
def __init__(self, text: str, extra_info: dict[str, Any]):
self.text = text
self.extra_info = extra_info
class NotionPageLoader:
"""
Notion Page Loader.
Reads a set of Notion pages.
"""
BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
def __init__(self, integration_token: Optional[str] = None) -> None:
"""Initialize with Notion integration token."""
if integration_token is None:
integration_token = os.getenv("NOTION_INTEGRATION_TOKEN")
if integration_token is None:
raise ValueError(
"Must specify `integration_token` or set environment " "variable `NOTION_INTEGRATION_TOKEN`."
)
self.token = integration_token
self.headers = {
"Authorization": "Bearer " + self.token,
"Content-Type": "application/json",
"Notion-Version": "2022-06-28",
}
def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
"""Read a block from Notion."""
done = False
result_lines_arr = []
cur_block_id = block_id
while not done:
block_url = self.BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
res = requests.get(block_url, headers=self.headers)
data = res.json()
for result in data["results"]:
result_type = result["type"]
result_obj = result[result_type]
cur_result_text_arr = []
if "rich_text" in result_obj:
for rich_text in result_obj["rich_text"]:
if "text" in rich_text:
text = rich_text["text"]["content"]
prefix = "\t" * num_tabs
cur_result_text_arr.append(prefix + text)
result_block_id = result["id"]
has_children = result["has_children"]
if has_children:
children_text = self._read_block(result_block_id, num_tabs=num_tabs + 1)
cur_result_text_arr.append(children_text)
cur_result_text = "\n".join(cur_result_text_arr)
result_lines_arr.append(cur_result_text)
if data["next_cursor"] is None:
done = True
else:
cur_block_id = data["next_cursor"]
result_lines = "\n".join(result_lines_arr)
return result_lines
def load_data(self, page_ids: list[str]) -> list[NotionDocument]:
"""Load data from the given list of page IDs."""
docs = []
for page_id in page_ids:
page_text = self._read_block(page_id)
docs.append(NotionDocument(text=page_text, extra_info={"page_id": page_id}))
return docs
@register_deserializable
class NotionLoader(BaseLoader):
def load_data(self, source):
"""Load data from a Notion URL."""
id = source[-32:]
formatted_id = f"{id[:8]}-{id[8:12]}-{id[12:16]}-{id[16:20]}-{id[20:]}"
logger.debug(f"Extracted notion page id as: {formatted_id}")
integration_token = os.getenv("NOTION_INTEGRATION_TOKEN")
reader = NotionPageLoader(integration_token=integration_token)
documents = reader.load_data(page_ids=[formatted_id])
raw_text = documents[0].text
text = clean_string(raw_text)
doc_id = hashlib.sha256((text + source).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": [
{
"content": text,
"meta_data": {"url": f"notion-{formatted_id}"},
}
],
}

View File

@@ -0,0 +1,42 @@
import hashlib
from io import StringIO
from urllib.parse import urlparse
import requests
import yaml
from embedchain.loaders.base_loader import BaseLoader
class OpenAPILoader(BaseLoader):
@staticmethod
def _get_file_content(content):
url = urlparse(content)
if all([url.scheme, url.netloc]) and url.scheme not in ["file", "http", "https"]:
raise ValueError("Not a valid URL.")
if url.scheme in ["http", "https"]:
response = requests.get(content)
response.raise_for_status()
return StringIO(response.text)
elif url.scheme == "file":
path = url.path
return open(path)
else:
return open(content)
@staticmethod
def load_data(content):
"""Load yaml file of openapi. Each pair is a document."""
data = []
file_path = content
data_content = []
with OpenAPILoader._get_file_content(content=content) as file:
yaml_data = yaml.load(file, Loader=yaml.SafeLoader)
for i, (key, value) in enumerate(yaml_data.items()):
string_data = f"{key}: {value}"
metadata = {"url": file_path, "row": i + 1}
data.append({"content": string_data, "meta_data": metadata})
data_content.append(string_data)
doc_id = hashlib.sha256((content + ", ".join(data_content)).encode()).hexdigest()
return {"doc_id": doc_id, "data": data}

View File

@@ -0,0 +1,38 @@
import hashlib
from langchain_community.document_loaders import PyPDFLoader
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string
@register_deserializable
class PdfFileLoader(BaseLoader):
def load_data(self, url):
"""Load data from a PDF file."""
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/98.0.4758.102 Safari/537.36", # noqa:E501
}
loader = PyPDFLoader(url, headers=headers)
data = []
all_content = []
pages = loader.load_and_split()
if not len(pages):
raise ValueError("No data found")
for page in pages:
content = page.page_content
content = clean_string(content)
metadata = page.metadata
metadata["url"] = url
data.append(
{
"content": content,
"meta_data": metadata,
}
)
all_content.append(content)
doc_id = hashlib.sha256((" ".join(all_content) + url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": data,
}

View File

@@ -0,0 +1,73 @@
import hashlib
import logging
from typing import Any, Optional
from embedchain.loaders.base_loader import BaseLoader
logger = logging.getLogger(__name__)
class PostgresLoader(BaseLoader):
def __init__(self, config: Optional[dict[str, Any]] = None):
super().__init__()
if not config:
raise ValueError(f"Must provide the valid config. Received: {config}")
self.connection = None
self.cursor = None
self._setup_loader(config=config)
def _setup_loader(self, config: dict[str, Any]):
try:
import psycopg
except ImportError as e:
raise ImportError(
"Unable to import required packages. \
Run `pip install --upgrade 'embedchain[postgres]'`"
) from e
if "url" in config:
config_info = config.get("url")
else:
conn_params = []
for key, value in config.items():
conn_params.append(f"{key}={value}")
config_info = " ".join(conn_params)
logger.info(f"Connecting to postrgres sql: {config_info}")
self.connection = psycopg.connect(conninfo=config_info)
self.cursor = self.connection.cursor()
@staticmethod
def _check_query(query):
if not isinstance(query, str):
raise ValueError(
f"Invalid postgres query: {query}. Provide the valid source to add from postgres, make sure you are following `https://docs.embedchain.ai/data-sources/postgres`", # noqa:E501
)
def load_data(self, query):
self._check_query(query)
try:
data = []
data_content = []
self.cursor.execute(query)
results = self.cursor.fetchall()
for result in results:
doc_content = str(result)
data.append({"content": doc_content, "meta_data": {"url": query}})
data_content.append(doc_content)
doc_id = hashlib.sha256((query + ", ".join(data_content)).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": data,
}
except Exception as e:
raise ValueError(f"Failed to load data using query={query} with: {e}")
def close_connection(self):
if self.cursor:
self.cursor.close()
self.cursor = None
if self.connection:
self.connection.close()
self.connection = None
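
A usage sketch for the Postgres loader, showing the two config forms the constructor accepts; the module path, connection string, and query are illustrative assumptions:

from embedchain.loaders.postgres import PostgresLoader  # assumed module path

# Either a single libpq-style URL...
loader = PostgresLoader(config={"url": "postgresql://reader:secret@localhost:5432/app"})
# ...or individual key=value parameters:
# loader = PostgresLoader(config={"host": "localhost", "dbname": "app", "user": "reader"})

result = loader.load_data("SELECT id, title FROM articles LIMIT 10;")
print(len(result["data"]), "rows loaded")
loader.close_connection()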

View File

@@ -0,0 +1,52 @@
import hashlib
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
@register_deserializable
class RSSFeedLoader(BaseLoader):
"""Loader for RSS Feed."""
def load_data(self, url):
"""Load data from a rss feed."""
output = self.get_rss_content(url)
doc_id = hashlib.sha256((str(output) + url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": output,
}
@staticmethod
def serialize_metadata(metadata):
for key, value in metadata.items():
if not isinstance(value, (str, int, float, bool)):
metadata[key] = str(value)
return metadata
@staticmethod
def get_rss_content(url: str):
try:
from langchain_community.document_loaders import RSSFeedLoader as LangchainRSSFeedLoader
except ImportError:
raise ImportError(
"""RSSFeedLoader file requires extra dependencies.
Install with `pip install feedparser==6.0.10 newspaper3k==0.2.8 listparser==0.19`"""
) from None
output = []
loader = LangchainRSSFeedLoader(urls=[url])
data = loader.load()
for entry in data:
metadata = RSSFeedLoader.serialize_metadata(entry.metadata)
metadata.update({"url": url})
output.append(
{
"content": entry.page_content,
"meta_data": metadata,
}
)
return output

View File

@@ -0,0 +1,79 @@
import concurrent.futures
import hashlib
import logging
import os
from urllib.parse import urlparse
import requests
from tqdm import tqdm
try:
from bs4 import BeautifulSoup
from bs4.builder import ParserRejectedMarkup
except ImportError:
raise ImportError(
"Sitemap requires extra dependencies. Install with `pip install beautifulsoup4==4.12.3`"
) from None
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.loaders.web_page import WebPageLoader
logger = logging.getLogger(__name__)
@register_deserializable
class SitemapLoader(BaseLoader):
"""
This loader takes a sitemap URL or a local file path, extracts all page URLs,
and uses the WebPageLoader to load the content of each page.
"""
def load_data(self, sitemap_source):
output = []
web_page_loader = WebPageLoader()
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/98.0.4758.102 Safari/537.36", # noqa:E501
}
if urlparse(sitemap_source).scheme in ("http", "https"):
try:
response = requests.get(sitemap_source, headers=headers)
response.raise_for_status()
soup = BeautifulSoup(response.text, "xml")
except requests.RequestException as e:
logger.error(f"Error fetching sitemap from URL: {e}")
return
elif os.path.isfile(sitemap_source):
with open(sitemap_source, "r") as file:
soup = BeautifulSoup(file, "xml")
else:
raise ValueError("Invalid sitemap source. Please provide a valid URL or local file path.")
links = [link.text for link in soup.find_all("loc") if link.parent.name == "url"]
if len(links) == 0:
links = [link.text for link in soup.find_all("loc")]
doc_id = hashlib.sha256((" ".join(links) + sitemap_source).encode()).hexdigest()
def load_web_page(link):
try:
loader_data = web_page_loader.load_data(link)
return loader_data.get("data")
except ParserRejectedMarkup as e:
logger.error(f"Failed to parse {link}: {e}")
return None
with concurrent.futures.ThreadPoolExecutor() as executor:
future_to_link = {executor.submit(load_web_page, link): link for link in links}
for future in tqdm(concurrent.futures.as_completed(future_to_link), total=len(links), desc="Loading pages"):
link = future_to_link[future]
try:
data = future.result()
if data:
output.extend(data)
except Exception as e:
logger.error(f"Error loading page {link}: {e}")
return {"doc_id": doc_id, "data": output}

View File

@@ -0,0 +1,115 @@
import hashlib
import logging
import os
import ssl
from typing import Any, Optional
import certifi
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string
SLACK_API_BASE_URL = "https://www.slack.com/api/"
logger = logging.getLogger(__name__)
class SlackLoader(BaseLoader):
def __init__(self, config: Optional[dict[str, Any]] = None):
super().__init__()
self.config = config if config else {}
if "base_url" not in self.config:
self.config["base_url"] = SLACK_API_BASE_URL
self.client = None
self._setup_loader(self.config)
def _setup_loader(self, config: dict[str, Any]):
try:
from slack_sdk import WebClient
except ImportError as e:
raise ImportError(
"Slack loader requires extra dependencies. \
Install with `pip install --upgrade embedchain[slack]`"
) from e
if os.getenv("SLACK_USER_TOKEN") is None:
raise ValueError(
"SLACK_USER_TOKEN environment variables not provided. Check `https://docs.embedchain.ai/data-sources/slack` to learn more." # noqa:E501
)
logger.info(f"Creating Slack Loader with config: {config}")
# get slack client config params
slack_bot_token = os.getenv("SLACK_USER_TOKEN")
ssl_cert = ssl.create_default_context(cafile=certifi.where())
base_url = config.get("base_url", SLACK_API_BASE_URL)
headers = config.get("headers")
# for Org-Wide App
team_id = config.get("team_id")
self.client = WebClient(
token=slack_bot_token,
base_url=base_url,
ssl=ssl_cert,
headers=headers,
team_id=team_id,
)
logger.info("Slack Loader setup successful!")
@staticmethod
def _check_query(query):
if not isinstance(query, str):
raise ValueError(
f"Invalid query passed to Slack loader, found: {query}. Check `https://docs.embedchain.ai/data-sources/slack` to learn more." # noqa:E501
)
def load_data(self, query):
self._check_query(query)
try:
data = []
data_content = []
logger.info(f"Searching slack conversations for query: {query}")
results = self.client.search_messages(
query=query,
sort="timestamp",
sort_dir="desc",
count=self.config.get("count", 100),
)
            messages = results.get("messages", {})
            matches = messages.get("matches", [])
            num_messages = len(matches)
            logger.info(f"Found {num_messages} messages for query: {query}")
for message in matches:
url = message.get("permalink")
text = message.get("text")
content = clean_string(text)
message_meta_data_keys = ["iid", "team", "ts", "type", "user", "username"]
metadata = {}
for key in message.keys():
if key in message_meta_data_keys:
metadata[key] = message.get(key)
metadata.update({"url": url})
data.append(
{
"content": content,
"meta_data": metadata,
}
)
data_content.append(content)
doc_id = hashlib.md5((query + ", ".join(data_content)).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": data,
}
except Exception as e:
logger.warning(f"Error in loading slack data: {e}")
raise ValueError(
f"Error in loading slack data: {e}. Check `https://docs.embedchain.ai/data-sources/slack` to learn more." # noqa:E501
) from e
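A minimal usage sketch for the SlackLoader above, assuming an import path of embedchain.loaders.slack and a valid SLACK_USER_TOKEN; the token, query string, and count are placeholders.

import os

# hypothetical import path; placeholder token and query
from embedchain.loaders.slack import SlackLoader

os.environ["SLACK_USER_TOKEN"] = "xoxp-..."  # user token with search permissions
loader = SlackLoader(config={"count": 50})
result = loader.load_data("in:#general deployment")
for item in result["data"]:
    print(item["meta_data"]["url"])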

View File

@@ -0,0 +1,107 @@
import hashlib
import logging
import time
from xml.etree import ElementTree
import requests
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import is_readable
logger = logging.getLogger(__name__)
@register_deserializable
class SubstackLoader(BaseLoader):
"""
This loader is used to load data from Substack URLs.
"""
def load_data(self, url: str):
try:
from bs4 import BeautifulSoup
from bs4.builder import ParserRejectedMarkup
except ImportError:
raise ImportError(
"Substack requires extra dependencies. Install with `pip install beautifulsoup4==4.12.3`"
) from None
if not url.endswith("sitemap.xml"):
url = url + "/sitemap.xml"
output = []
response = requests.get(url)
try:
response.raise_for_status()
except requests.exceptions.HTTPError as e:
raise ValueError(
f"""
Failed to load {url}: {e}. Please use the root substack URL. For example, https://example.substack.com
"""
)
try:
ElementTree.fromstring(response.content)
except ElementTree.ParseError:
raise ValueError(
f"""
Failed to parse {url}. Please use the root substack URL. For example, https://example.substack.com
"""
)
soup = BeautifulSoup(response.text, "xml")
links = [link.text for link in soup.find_all("loc") if link.parent.name == "url" and "/p/" in link.text]
if len(links) == 0:
links = [link.text for link in soup.find_all("loc") if "/p/" in link.text]
doc_id = hashlib.sha256((" ".join(links) + url).encode()).hexdigest()
def serialize_response(soup: BeautifulSoup):
data = {}
h1_els = soup.find_all("h1")
            # the [1] index below needs at least two <h1> elements, so guard accordingly
            if h1_els is not None and len(h1_els) > 1:
                data["title"] = h1_els[1].text
description_el = soup.find("meta", {"name": "description"})
if description_el is not None:
data["description"] = description_el["content"]
content_el = soup.find("div", {"class": "available-content"})
if content_el is not None:
data["content"] = content_el.text
like_btn = soup.find("div", {"class": "like-button-container"})
if like_btn is not None:
no_of_likes_div = like_btn.find("div", {"class": "label"})
if no_of_likes_div is not None:
data["no_of_likes"] = no_of_likes_div.text
return data
def load_link(link: str):
try:
substack_data = requests.get(link)
substack_data.raise_for_status()
soup = BeautifulSoup(substack_data.text, "html.parser")
data = serialize_response(soup)
data = str(data)
if is_readable(data):
return data
else:
logger.warning(f"Page is not readable (too many invalid characters): {link}")
except ParserRejectedMarkup as e:
logger.error(f"Failed to parse {link}: {e}")
return None
for link in links:
data = load_link(link)
if data:
output.append({"content": data, "meta_data": {"url": link}})
# TODO: allow users to configure this
time.sleep(1.0) # added to avoid rate limiting
return {"doc_id": doc_id, "data": output}

View File

@@ -0,0 +1,30 @@
import hashlib
import os
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
@register_deserializable
class TextFileLoader(BaseLoader):
def load_data(self, url: str):
"""Load data from a text file located at a local path."""
if not os.path.exists(url):
raise FileNotFoundError(f"The file at {url} does not exist.")
with open(url, "r", encoding="utf-8") as file:
content = file.read()
doc_id = hashlib.sha256((content + url).encode()).hexdigest()
metadata = {"url": url, "file_size": os.path.getsize(url), "file_type": url.split(".")[-1]}
return {
"doc_id": doc_id,
"data": [
{
"content": content,
"meta_data": metadata,
}
],
}
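A minimal usage sketch for the TextFileLoader above, assuming an import path of embedchain.loaders.text_file; only local paths are accepted.

# hypothetical import path; placeholder file name
from embedchain.loaders.text_file import TextFileLoader

result = TextFileLoader().load_data("notes.txt")
print(result["data"][0]["meta_data"])  # includes url, file_size, and file_type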

View File

@@ -0,0 +1,43 @@
import hashlib
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string
@register_deserializable
class UnstructuredLoader(BaseLoader):
def load_data(self, url):
"""Load data from an Unstructured file."""
try:
import unstructured # noqa: F401
from langchain_community.document_loaders import \
UnstructuredFileLoader
except ImportError:
raise ImportError(
'Unstructured file requires extra dependencies. Install with `pip install "unstructured[local-inference, all-docs]"`' # noqa: E501
) from None
loader = UnstructuredFileLoader(url)
data = []
all_content = []
pages = loader.load_and_split()
if not len(pages):
raise ValueError("No data found")
for page in pages:
content = page.page_content
content = clean_string(content)
metadata = page.metadata
metadata["url"] = url
data.append(
{
"content": content,
"meta_data": metadata,
}
)
all_content.append(content)
doc_id = hashlib.sha256((" ".join(all_content) + url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": data,
}
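A minimal usage sketch for the UnstructuredLoader above, assuming an import path of embedchain.loaders.unstructured_file and that the unstructured extras are installed.

# hypothetical import path; placeholder file name
from embedchain.loaders.unstructured_file import UnstructuredLoader

result = UnstructuredLoader().load_data("report.pdf")
for chunk in result["data"]:
    print(chunk["meta_data"]["url"], len(chunk["content"]))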

View File

@@ -0,0 +1,100 @@
import hashlib
import logging
import requests
try:
from bs4 import BeautifulSoup
except ImportError:
raise ImportError(
"Webpage requires extra dependencies. Install with `pip install beautifulsoup4==4.12.3`"
) from None
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string
logger = logging.getLogger(__name__)
@register_deserializable
class WebPageLoader(BaseLoader):
# Shared session for all instances
_session = requests.Session()
def load_data(self, url):
"""Load data from a web page using a shared requests' session."""
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/98.0.4758.102 Safari/537.36", # noqa:E501
}
response = self._session.get(url, headers=headers, timeout=30)
response.raise_for_status()
data = response.content
content = self._get_clean_content(data, url)
metadata = {"url": url}
doc_id = hashlib.sha256((content + url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": [
{
"content": content,
"meta_data": metadata,
}
],
}
@staticmethod
def _get_clean_content(html, url) -> str:
soup = BeautifulSoup(html, "html.parser")
original_size = len(str(soup.get_text()))
tags_to_exclude = [
"nav",
"aside",
"form",
"header",
"noscript",
"svg",
"canvas",
"footer",
"script",
"style",
]
for tag in soup(tags_to_exclude):
tag.decompose()
ids_to_exclude = ["sidebar", "main-navigation", "menu-main-menu"]
for id_ in ids_to_exclude:
tags = soup.find_all(id=id_)
for tag in tags:
tag.decompose()
classes_to_exclude = [
"elementor-location-header",
"navbar-header",
"nav",
"header-sidebar-wrapper",
"blog-sidebar-wrapper",
"related-posts",
]
for class_name in classes_to_exclude:
tags = soup.find_all(class_=class_name)
for tag in tags:
tag.decompose()
content = soup.get_text()
content = clean_string(content)
cleaned_size = len(content)
if original_size != 0:
logger.info(
f"[{url}] Cleaned page size: {cleaned_size} characters, down from {original_size} (shrunk: {original_size-cleaned_size} chars, {round((1-(cleaned_size/original_size)) * 100, 2)}%)" # noqa:E501
)
return content
@classmethod
def close_session(cls):
cls._session.close()
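A minimal usage sketch for the WebPageLoader above; the import path matches the one used by SitemapLoader earlier in this commit. Because every instance shares one requests.Session, close_session() can be called once no more pages will be loaded.

from embedchain.loaders.web_page import WebPageLoader

loader = WebPageLoader()
result = loader.load_data("https://example.com/blog/post")  # placeholder URL
print(result["data"][0]["content"][:200])
WebPageLoader.close_session()  # releases the shared session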

View File

@@ -0,0 +1,31 @@
import hashlib
try:
import unstructured # noqa: F401
from langchain_community.document_loaders import UnstructuredXMLLoader
except ImportError:
raise ImportError(
'XML file requires extra dependencies. Install with `pip install "unstructured[local-inference, all-docs]"`'
) from None
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string
@register_deserializable
class XmlLoader(BaseLoader):
def load_data(self, xml_url):
"""Load data from a XML file."""
loader = UnstructuredXMLLoader(xml_url)
data = loader.load()
content = data[0].page_content
content = clean_string(content)
metadata = data[0].metadata
metadata["url"] = metadata["source"]
del metadata["source"]
output = [{"content": content, "meta_data": metadata}]
doc_id = hashlib.sha256((content + xml_url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": output,
}
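A minimal usage sketch for the XmlLoader above, assuming an import path of embedchain.loaders.xml; it needs the same unstructured extras as the loader before it.

# hypothetical import path; placeholder file name
from embedchain.loaders.xml import XmlLoader

result = XmlLoader().load_data("feed.xml")
print(result["data"][0]["meta_data"]["url"])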

View File

@@ -0,0 +1,79 @@
import concurrent.futures
import hashlib
import logging
from tqdm import tqdm
from embedchain.loaders.base_loader import BaseLoader
from embedchain.loaders.youtube_video import YoutubeVideoLoader
logger = logging.getLogger(__name__)
class YoutubeChannelLoader(BaseLoader):
"""Loader for youtube channel."""
def load_data(self, channel_name):
try:
import yt_dlp
except ImportError as e:
raise ValueError(
"YoutubeChannelLoader requires extra dependencies. Install with `pip install yt_dlp==2023.11.14 youtube-transcript-api==0.6.1`" # noqa: E501
) from e
data = []
data_urls = []
youtube_url = f"https://www.youtube.com/{channel_name}/videos"
youtube_video_loader = YoutubeVideoLoader()
def _get_yt_video_links():
try:
ydl_opts = {
"quiet": True,
"extract_flat": True,
}
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
info_dict = ydl.extract_info(youtube_url, download=False)
if "entries" in info_dict:
videos = [entry["url"] for entry in info_dict["entries"]]
return videos
            except Exception:
                logger.error(f"Failed to fetch youtube videos for channel: {channel_name}")
            # return an empty list when the fetch fails or the channel has no entries
            return []
def _load_yt_video(video_link):
try:
each_load_data = youtube_video_loader.load_data(video_link)
if each_load_data:
return each_load_data.get("data")
except Exception as e:
logger.error(f"Failed to load youtube video {video_link}: {e}")
return None
def _add_youtube_channel():
video_links = _get_yt_video_links()
logger.info("Loading videos from youtube channel...")
with concurrent.futures.ThreadPoolExecutor() as executor:
# Submitting all tasks and storing the future object with the video link
future_to_video = {
executor.submit(_load_yt_video, video_link): video_link for video_link in video_links
}
for future in tqdm(
concurrent.futures.as_completed(future_to_video), total=len(video_links), desc="Processing videos"
):
video = future_to_video[future]
try:
results = future.result()
if results:
data.extend(results)
data_urls.extend([result.get("meta_data").get("url") for result in results])
except Exception as e:
logger.error(f"Failed to process youtube video {video}: {e}")
_add_youtube_channel()
doc_id = hashlib.sha256((youtube_url + ", ".join(data_urls)).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": data,
}
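A minimal usage sketch for the YoutubeChannelLoader above, assuming an import path of embedchain.loaders.youtube_channel; the argument is interpolated into https://www.youtube.com/{channel_name}/videos, so a handle such as "@some-channel" is expected.

# hypothetical import path; placeholder channel handle
from embedchain.loaders.youtube_channel import YoutubeChannelLoader

result = YoutubeChannelLoader().load_data("@some-channel")
print(len(result["data"]), "video transcripts loaded")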

View File

@@ -0,0 +1,57 @@
import hashlib
import json
import logging
try:
from youtube_transcript_api import YouTubeTranscriptApi
except ImportError:
raise ImportError("YouTube video requires extra dependencies. Install with `pip install youtube-transcript-api`")
try:
from langchain_community.document_loaders import YoutubeLoader
from langchain_community.document_loaders.youtube import _parse_video_id
except ImportError:
raise ImportError("YouTube video requires extra dependencies. Install with `pip install pytube==15.0.0`") from None
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils.misc import clean_string
@register_deserializable
class YoutubeVideoLoader(BaseLoader):
def load_data(self, url):
"""Load data from a Youtube video."""
video_id = _parse_video_id(url)
languages = ["en"]
try:
# Fetching transcript data
languages = [transcript.language_code for transcript in YouTubeTranscriptApi.list_transcripts(video_id)]
transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=languages)
            # convert the transcript to JSON to avoid Unicode symbols
transcript = json.dumps(transcript, ensure_ascii=True)
except Exception:
logging.exception(f"Failed to fetch transcript for video {url}")
transcript = "Unavailable"
loader = YoutubeLoader.from_youtube_url(url, add_video_info=True, language=languages)
doc = loader.load()
output = []
if not len(doc):
raise ValueError(f"No data found for url: {url}")
content = doc[0].page_content
content = clean_string(content)
metadata = doc[0].metadata
metadata["url"] = url
metadata["transcript"] = transcript
output.append(
{
"content": content,
"meta_data": metadata,
}
)
doc_id = hashlib.sha256((content + url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": output,
}
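A minimal usage sketch for the YoutubeVideoLoader above; the import path matches the one used by YoutubeChannelLoader earlier in this commit. The transcript, when available, is stored as a JSON string under meta_data["transcript"].

from embedchain.loaders.youtube_video import YoutubeVideoLoader

result = YoutubeVideoLoader().load_data("https://www.youtube.com/watch?v=<video-id>")  # placeholder URL
entry = result["data"][0]
print(entry["meta_data"]["url"], entry["meta_data"]["transcript"][:100])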