[Feature] Google Drive Folder support as a data source (#1106)
This commit is contained in:
22
embedchain/chunkers/google_drive.py
Normal file
22
embedchain/chunkers/google_drive.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class GoogleDriveChunker(BaseChunker):
|
||||
"""Chunker for google drive folder."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
@@ -74,6 +74,7 @@ class DataFormatter(JSONSerializable):
|
||||
DataType.DISCORD: "embedchain.loaders.discord.DiscordLoader",
|
||||
DataType.RSSFEED: "embedchain.loaders.rss_feed.RSSFeedLoader",
|
||||
DataType.BEEHIIV: "embedchain.loaders.beehiiv.BeehiivLoader",
|
||||
DataType.GOOGLE_DRIVE: "embedchain.loaders.google_drive.GoogleDriveLoader",
|
||||
DataType.DIRECTORY: "embedchain.loaders.directory_loader.DirectoryLoader",
|
||||
DataType.SLACK: "embedchain.loaders.slack.SlackLoader",
|
||||
DataType.DROPBOX: "embedchain.loaders.dropbox.DropboxLoader",
|
||||
@@ -120,6 +121,7 @@ class DataFormatter(JSONSerializable):
|
||||
DataType.CUSTOM: "embedchain.chunkers.common_chunker.CommonChunker",
|
||||
DataType.RSSFEED: "embedchain.chunkers.rss_feed.RSSFeedChunker",
|
||||
DataType.BEEHIIV: "embedchain.chunkers.beehiiv.BeehiivChunker",
|
||||
DataType.GOOGLE_DRIVE: "embedchain.chunkers.google_drive.GoogleDriveChunker",
|
||||
DataType.DIRECTORY: "embedchain.chunkers.common_chunker.CommonChunker",
|
||||
DataType.SLACK: "embedchain.chunkers.common_chunker.CommonChunker",
|
||||
DataType.DROPBOX: "embedchain.chunkers.common_chunker.CommonChunker",
|
||||
|
||||
54
embedchain/loaders/google_drive.py
Normal file
54
embedchain/loaders/google_drive.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import hashlib
|
||||
import re
|
||||
|
||||
try:
|
||||
from googleapiclient.errors import HttpError
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Google Drive requires extra dependencies. Install with `pip install embedchain[googledrive]`"
|
||||
) from None
|
||||
|
||||
from langchain.document_loaders import GoogleDriveLoader as Loader, UnstructuredFileIOLoader
|
||||
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class GoogleDriveLoader(BaseLoader):
|
||||
@staticmethod
|
||||
def _get_drive_id_from_url(url: str):
|
||||
regex = r"^https:\/\/drive\.google\.com\/drive\/(?:u\/\d+\/)folders\/([a-zA-Z0-9_-]+)$"
|
||||
if re.match(regex, url):
|
||||
return url.split("/")[-1]
|
||||
raise ValueError(
|
||||
f"The url provided {url} does not match a google drive folder url. Example drive url: "
|
||||
f"https://drive.google.com/drive/u/0/folders/xxxx"
|
||||
)
|
||||
|
||||
def load_data(self, url: str):
|
||||
"""Load data from a Google drive folder."""
|
||||
folder_id: str = self._get_drive_id_from_url(url)
|
||||
|
||||
try:
|
||||
loader = Loader(
|
||||
folder_id=folder_id,
|
||||
recursive=True,
|
||||
file_loader_cls=UnstructuredFileIOLoader,
|
||||
)
|
||||
|
||||
data = []
|
||||
all_content = []
|
||||
|
||||
docs = loader.load()
|
||||
for doc in docs:
|
||||
all_content.append(doc.page_content)
|
||||
# renames source to url for later use.
|
||||
doc.metadata["url"] = doc.metadata.pop("source")
|
||||
data.append({"content": doc.page_content, "meta_data": doc.metadata})
|
||||
|
||||
doc_id = hashlib.sha256((" ".join(all_content) + url).encode()).hexdigest()
|
||||
return {"doc_id": doc_id, "data": data}
|
||||
|
||||
except HttpError:
|
||||
raise FileNotFoundError("Unable to locate folder or files, check provided drive URL and try again")
|
||||
@@ -35,6 +35,7 @@ class IndirectDataType(Enum):
|
||||
CUSTOM = "custom"
|
||||
RSSFEED = "rss_feed"
|
||||
BEEHIIV = "beehiiv"
|
||||
GOOGLE_DRIVE = "google_drive"
|
||||
DIRECTORY = "directory"
|
||||
SLACK = "slack"
|
||||
DROPBOX = "dropbox"
|
||||
@@ -73,6 +74,7 @@ class DataType(Enum):
|
||||
CUSTOM = IndirectDataType.CUSTOM.value
|
||||
RSSFEED = IndirectDataType.RSSFEED.value
|
||||
BEEHIIV = IndirectDataType.BEEHIIV.value
|
||||
GOOGLE_DRIVE = IndirectDataType.GOOGLE_DRIVE.value
|
||||
DIRECTORY = IndirectDataType.DIRECTORY.value
|
||||
SLACK = IndirectDataType.SLACK.value
|
||||
DROPBOX = IndirectDataType.DROPBOX.value
|
||||
|
||||
@@ -183,6 +183,11 @@ def detect_datatype(source: Any) -> DataType:
|
||||
# currently the following two fields are required in openapi spec yaml config
|
||||
return "openapi" in yaml_content and "info" in yaml_content
|
||||
|
||||
def is_google_drive_folder(url):
|
||||
# checks if url is a Google Drive folder url against a regex
|
||||
regex = r"^drive\.google\.com\/drive\/(?:u\/\d+\/)folders\/([a-zA-Z0-9_-]+)$"
|
||||
return re.match(regex, url)
|
||||
|
||||
try:
|
||||
if not isinstance(source, str):
|
||||
raise ValueError("Source is not a string and thus cannot be a URL.")
|
||||
@@ -196,8 +201,7 @@ def detect_datatype(source: Any) -> DataType:
|
||||
formatted_source = format_source(str(source), 30)
|
||||
|
||||
if url:
|
||||
from langchain.document_loaders.youtube import \
|
||||
ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
|
||||
from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
|
||||
|
||||
if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
|
||||
logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
|
||||
@@ -266,6 +270,10 @@ def detect_datatype(source: Any) -> DataType:
|
||||
logging.debug(f"Source of `{formatted_source}` detected as `github`.")
|
||||
return DataType.GITHUB
|
||||
|
||||
if is_google_drive_folder(url.netloc + url.path):
|
||||
logging.debug(f"Source of `{formatted_source}` detected as `google drive folder`.")
|
||||
return DataType.GOOGLE_DRIVE_FOLDER
|
||||
|
||||
# If none of the above conditions are met, it's a general web page
|
||||
logging.debug(f"Source of `{formatted_source}` detected as `web_page`.")
|
||||
return DataType.WEB_PAGE
|
||||
|
||||
Reference in New Issue
Block a user