[Feature] Google Drive Folder support as a data source (#1106)
28 docs/components/data-sources/google-drive.mdx (Normal file)
@@ -0,0 +1,28 @@
---
title: 'Google Drive'
---

To use the GoogleDriveLoader, install the extra dependencies with `pip install --upgrade embedchain[googledrive]`.

The `data_type` must be `google_drive`; otherwise, the source will be treated as a regular web page.

Google Drive requires credentials to be set up. Follow the steps below:

1. Go to the [Google Cloud Console](https://console.cloud.google.com/apis/credentials).
2. Create a project if you don't have one already.
3. Enable the [Google Drive API](https://console.cloud.google.com/flows/enableapi?apiid=drive.googleapis.com).
4. [Authorize credentials for a desktop app](https://developers.google.com/drive/api/quickstart/python#authorize_credentials_for_a_desktop_application).
5. When done, you will be able to download the credentials in `json` format. Rename the downloaded file to `credentials.json` and save it as `~/.credentials/credentials.json`.
6. Set the environment variable `GOOGLE_APPLICATION_CREDENTIALS=~/.credentials/credentials.json`.

The first time you use the loader, you will be prompted to enter your Google account credentials.

```python
from embedchain import Pipeline as App

app = App()

url = "https://drive.google.com/drive/u/0/folders/xxx-xxx"
app.add(url, data_type="google_drive")
```
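If you prefer to point the client libraries at the credentials file from Python rather than the shell, a minimal sketch (assuming the file was saved as `~/.credentials/credentials.json` per step 5 above) looks like:

```python
import os

# Assumed location from step 5 above; adjust if you stored the file elsewhere.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.expanduser("~/.credentials/credentials.json")

from embedchain import Pipeline as App

app = App()
app.add("https://drive.google.com/drive/u/0/folders/xxx-xxx", data_type="google_drive")
```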
22 embedchain/chunkers/google_drive.py (Normal file)
@@ -0,0 +1,22 @@
from typing import Optional

from langchain.text_splitter import RecursiveCharacterTextSplitter

from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable


@register_deserializable
class GoogleDriveChunker(BaseChunker):
    """Chunker for google drive folder."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        if config is None:
            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config.chunk_size,
            chunk_overlap=config.chunk_overlap,
            length_function=config.length_function,
        )
        super().__init__(text_splitter)
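For reference, a minimal sketch of constructing the chunker with non-default parameters (the values below are illustrative, not recommendations; the defaults above are `chunk_size=1000`, `chunk_overlap=0`):

```python
from embedchain.chunkers.google_drive import GoogleDriveChunker
from embedchain.config.add_config import ChunkerConfig

# Override the defaults set in GoogleDriveChunker.__init__.
config = ChunkerConfig(chunk_size=500, chunk_overlap=50, length_function=len)
chunker = GoogleDriveChunker(config=config)
```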
@@ -74,6 +74,7 @@ class DataFormatter(JSONSerializable):
            DataType.DISCORD: "embedchain.loaders.discord.DiscordLoader",
            DataType.RSSFEED: "embedchain.loaders.rss_feed.RSSFeedLoader",
            DataType.BEEHIIV: "embedchain.loaders.beehiiv.BeehiivLoader",
            DataType.GOOGLE_DRIVE: "embedchain.loaders.google_drive.GoogleDriveLoader",
            DataType.DIRECTORY: "embedchain.loaders.directory_loader.DirectoryLoader",
            DataType.SLACK: "embedchain.loaders.slack.SlackLoader",
            DataType.DROPBOX: "embedchain.loaders.dropbox.DropboxLoader",
@@ -120,6 +121,7 @@ class DataFormatter(JSONSerializable):
            DataType.CUSTOM: "embedchain.chunkers.common_chunker.CommonChunker",
            DataType.RSSFEED: "embedchain.chunkers.rss_feed.RSSFeedChunker",
            DataType.BEEHIIV: "embedchain.chunkers.beehiiv.BeehiivChunker",
            DataType.GOOGLE_DRIVE: "embedchain.chunkers.google_drive.GoogleDriveChunker",
            DataType.DIRECTORY: "embedchain.chunkers.common_chunker.CommonChunker",
            DataType.SLACK: "embedchain.chunkers.common_chunker.CommonChunker",
            DataType.DROPBOX: "embedchain.chunkers.common_chunker.CommonChunker",
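As an illustration of how these string entries map a data type to a class (a sketch only; the actual resolution logic lives inside `DataFormatter` and may differ), the dotted paths can be imported dynamically:

```python
import importlib

def resolve(dotted_path: str):
    # Split "package.module.ClassName" into a module path and a class name,
    # import the module, and return the class object.
    module_path, class_name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), class_name)

LoaderCls = resolve("embedchain.loaders.google_drive.GoogleDriveLoader")
ChunkerCls = resolve("embedchain.chunkers.google_drive.GoogleDriveChunker")
```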
54 embedchain/loaders/google_drive.py (Normal file)
@@ -0,0 +1,54 @@
import hashlib
import re

try:
    from googleapiclient.errors import HttpError
except ImportError:
    raise ImportError(
        "Google Drive requires extra dependencies. Install with `pip install embedchain[googledrive]`"
    ) from None

from langchain.document_loaders import GoogleDriveLoader as Loader, UnstructuredFileIOLoader

from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader


@register_deserializable
class GoogleDriveLoader(BaseLoader):
    @staticmethod
    def _get_drive_id_from_url(url: str):
        regex = r"^https:\/\/drive\.google\.com\/drive\/(?:u\/\d+\/)folders\/([a-zA-Z0-9_-]+)$"
        if re.match(regex, url):
            return url.split("/")[-1]
        raise ValueError(
            f"The url provided {url} does not match a google drive folder url. Example drive url: "
            f"https://drive.google.com/drive/u/0/folders/xxxx"
        )

    def load_data(self, url: str):
        """Load data from a Google drive folder."""
        folder_id: str = self._get_drive_id_from_url(url)

        try:
            loader = Loader(
                folder_id=folder_id,
                recursive=True,
                file_loader_cls=UnstructuredFileIOLoader,
            )

            data = []
            all_content = []

            docs = loader.load()
            for doc in docs:
                all_content.append(doc.page_content)
                # renames source to url for later use.
                doc.metadata["url"] = doc.metadata.pop("source")
                data.append({"content": doc.page_content, "meta_data": doc.metadata})

            doc_id = hashlib.sha256((" ".join(all_content) + url).encode()).hexdigest()
            return {"doc_id": doc_id, "data": data}

        except HttpError:
            raise FileNotFoundError("Unable to locate folder or files, check provided drive URL and try again")
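A minimal sketch of calling the loader directly (assuming Google credentials are configured as described in the docs above); the return shape follows `load_data` as defined in this file:

```python
from embedchain.loaders.google_drive import GoogleDriveLoader

loader = GoogleDriveLoader()
result = loader.load_data("https://drive.google.com/drive/u/0/folders/xxx-xxx")

# result["doc_id"] is a sha256 over the concatenated document contents plus the URL;
# result["data"] is a list of {"content": ..., "meta_data": ...} entries, where
# meta_data["url"] holds the original document source.
print(result["doc_id"])
for entry in result["data"]:
    print(entry["meta_data"]["url"], len(entry["content"]))
```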
@@ -35,6 +35,7 @@ class IndirectDataType(Enum):
    CUSTOM = "custom"
    RSSFEED = "rss_feed"
    BEEHIIV = "beehiiv"
    GOOGLE_DRIVE = "google_drive"
    DIRECTORY = "directory"
    SLACK = "slack"
    DROPBOX = "dropbox"
@@ -73,6 +74,7 @@ class DataType(Enum):
    CUSTOM = IndirectDataType.CUSTOM.value
    RSSFEED = IndirectDataType.RSSFEED.value
    BEEHIIV = IndirectDataType.BEEHIIV.value
    GOOGLE_DRIVE = IndirectDataType.GOOGLE_DRIVE.value
    DIRECTORY = IndirectDataType.DIRECTORY.value
    SLACK = IndirectDataType.SLACK.value
    DROPBOX = IndirectDataType.DROPBOX.value
@@ -183,6 +183,11 @@ def detect_datatype(source: Any) -> DataType:
        # currently the following two fields are required in openapi spec yaml config
        return "openapi" in yaml_content and "info" in yaml_content

    def is_google_drive_folder(url):
        # checks if url is a Google Drive folder url against a regex
        regex = r"^drive\.google\.com\/drive\/(?:u\/\d+\/)folders\/([a-zA-Z0-9_-]+)$"
        return re.match(regex, url)

    try:
        if not isinstance(source, str):
            raise ValueError("Source is not a string and thus cannot be a URL.")
@@ -196,8 +201,7 @@ def detect_datatype(source: Any) -> DataType:
        formatted_source = format_source(str(source), 30)

        if url:
            from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS

            if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
                logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
@@ -266,6 +270,10 @@ def detect_datatype(source: Any) -> DataType:
                logging.debug(f"Source of `{formatted_source}` detected as `github`.")
                return DataType.GITHUB

            if is_google_drive_folder(url.netloc + url.path):
                logging.debug(f"Source of `{formatted_source}` detected as `google_drive`.")
                return DataType.GOOGLE_DRIVE

            # If none of the above conditions are met, it's a general web page
            logging.debug(f"Source of `{formatted_source}` detected as `web_page`.")
            return DataType.WEB_PAGE
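For reference, a standalone sketch of what the folder-detection regex above accepts and rejects (the URLs are illustrative; `detect_datatype` applies the pattern to `url.netloc + url.path`, i.e. without the scheme):

```python
import re

GOOGLE_DRIVE_FOLDER_RE = r"^drive\.google\.com\/drive\/(?:u\/\d+\/)folders\/([a-zA-Z0-9_-]+)$"

samples = [
    "drive.google.com/drive/u/0/folders/1AbCdEfGhIjK",  # matches: has a /u/<n>/ segment and a folder id
    "drive.google.com/drive/folders/1AbCdEfGhIjK",       # no /u/<n>/ segment, so this pattern rejects it
    "drive.google.com/file/d/1AbCdEfGhIjK/view",         # a file link, not a folder
]
for sample in samples:
    print(sample, "->", bool(re.match(GOOGLE_DRIVE_FOLDER_RE, sample)))
```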
27 poetry.lock (generated)
@@ -4677,6 +4677,32 @@ files = [
    {file = "protobuf-4.21.12.tar.gz", hash = "sha256:7cd532c4566d0e6feafecc1059d04c7915aec8e182d1cf7adee8b24ef1e2e6ab"},
]

[[package]]
name = "psutil"
version = "5.9.5"
description = "Cross-platform lib for process and system monitoring in Python."
optional = true
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
files = [
    {file = "psutil-5.9.5-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:be8929ce4313f9f8146caad4272f6abb8bf99fc6cf59344a3167ecd74f4f203f"},
    {file = "psutil-5.9.5-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:ab8ed1a1d77c95453db1ae00a3f9c50227ebd955437bcf2a574ba8adbf6a74d5"},
    {file = "psutil-5.9.5-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:4aef137f3345082a3d3232187aeb4ac4ef959ba3d7c10c33dd73763fbc063da4"},
    {file = "psutil-5.9.5-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:ea8518d152174e1249c4f2a1c89e3e6065941df2fa13a1ab45327716a23c2b48"},
    {file = "psutil-5.9.5-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:acf2aef9391710afded549ff602b5887d7a2349831ae4c26be7c807c0a39fac4"},
    {file = "psutil-5.9.5-cp27-none-win32.whl", hash = "sha256:5b9b8cb93f507e8dbaf22af6a2fd0ccbe8244bf30b1baad6b3954e935157ae3f"},
    {file = "psutil-5.9.5-cp27-none-win_amd64.whl", hash = "sha256:8c5f7c5a052d1d567db4ddd231a9d27a74e8e4a9c3f44b1032762bd7b9fdcd42"},
    {file = "psutil-5.9.5-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:3c6f686f4225553615612f6d9bc21f1c0e305f75d7d8454f9b46e901778e7217"},
    {file = "psutil-5.9.5-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7a7dd9997128a0d928ed4fb2c2d57e5102bb6089027939f3b722f3a210f9a8da"},
    {file = "psutil-5.9.5-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89518112647f1276b03ca97b65cc7f64ca587b1eb0278383017c2a0dcc26cbe4"},
    {file = "psutil-5.9.5-cp36-abi3-win32.whl", hash = "sha256:104a5cc0e31baa2bcf67900be36acde157756b9c44017b86b2c049f11957887d"},
    {file = "psutil-5.9.5-cp36-abi3-win_amd64.whl", hash = "sha256:b258c0c1c9d145a1d5ceffab1134441c4c5113b2417fafff7315a917a026c3c9"},
    {file = "psutil-5.9.5-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:c607bb3b57dc779d55e1554846352b4e358c10fff3abf3514a7a6601beebdb30"},
    {file = "psutil-5.9.5.tar.gz", hash = "sha256:5410638e4df39c54d957fc51ce03048acd8e6d60abc0f5107af51e5fb566eb3c"},
]

[package.extras]
test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"]

[[package]]
name = "psycopg"
version = "3.1.12"
@@ -8095,6 +8121,7 @@ elasticsearch = ["elasticsearch"]
github = ["PyGithub", "gitpython"]
gmail = ["google-api-core", "google-api-python-client", "google-auth", "google-auth-httplib2", "google-auth-oauthlib", "requests"]
google = ["google-generativeai"]
googledrive = ["google-api-python-client", "google-auth-httplib2", "google-auth-oauthlib"]
huggingface-hub = ["huggingface_hub"]
llama2 = ["replicate"]
milvus = ["pymilvus"]
@@ -197,6 +197,7 @@ gmail = [
    "google-auth-httplib2",
    "google-api-core",
]
googledrive = ["google-api-python-client", "google-auth-oauthlib", "google-auth-httplib2"]
postgres = ["psycopg", "psycopg-binary", "psycopg-pool"]
mysql = ["mysql-connector-python"]
github = ["PyGithub", "gitpython"]
@@ -3,6 +3,7 @@ from embedchain.chunkers.discourse import DiscourseChunker
from embedchain.chunkers.docs_site import DocsSiteChunker
from embedchain.chunkers.docx_file import DocxFileChunker
from embedchain.chunkers.gmail import GmailChunker
from embedchain.chunkers.google_drive import GoogleDriveChunker
from embedchain.chunkers.json import JSONChunker
from embedchain.chunkers.mdx import MdxChunker
from embedchain.chunkers.notion import NotionChunker
@@ -41,6 +42,7 @@ chunker_common_config = {
    SlackChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
    DiscourseChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
    CommonChunker: {"chunk_size": 2000, "chunk_overlap": 0, "length_function": len},
    GoogleDriveChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
}
37 tests/loaders/test_google_drive.py (Normal file)
@@ -0,0 +1,37 @@
import pytest

from embedchain.loaders.google_drive import GoogleDriveLoader


@pytest.fixture
def google_drive_folder_loader():
    return GoogleDriveLoader()


def test_load_data_invalid_drive_url(google_drive_folder_loader):
    mock_invalid_drive_url = "https://example.com"
    with pytest.raises(
        ValueError,
        match="The url provided https://example.com does not match a google drive folder url. Example "
        "drive url: https://drive.google.com/drive/u/0/folders/xxxx",
    ):
        google_drive_folder_loader.load_data(mock_invalid_drive_url)


@pytest.mark.skip(reason="This test won't work unless google api credentials are properly setup.")
def test_load_data_incorrect_drive_url(google_drive_folder_loader):
    mock_invalid_drive_url = "https://drive.google.com/drive/u/0/folders/xxxx"
    with pytest.raises(
        FileNotFoundError, match="Unable to locate folder or files, check provided drive URL and try again"
    ):
        google_drive_folder_loader.load_data(mock_invalid_drive_url)


@pytest.mark.skip(reason="This test won't work unless google api credentials are properly setup.")
def test_load_data(google_drive_folder_loader):
    mock_valid_url = "YOUR_VALID_URL"
    result = google_drive_folder_loader.load_data(mock_valid_url)
    assert "doc_id" in result
    assert "data" in result
    assert "content" in result["data"][0]
    assert "meta_data" in result["data"][0]
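A further test one could add (a sketch, not part of this PR) that exercises folder-id extraction without needing Google credentials, reusing the `GoogleDriveLoader` import at the top of this test module:

```python
def test_get_drive_id_from_valid_url():
    # _get_drive_id_from_url is a staticmethod on GoogleDriveLoader and needs no API access.
    url = "https://drive.google.com/drive/u/0/folders/1AbCdEfGhIjK"
    assert GoogleDriveLoader._get_drive_id_from_url(url) == "1AbCdEfGhIjK"
```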