[Feature] Google Drive Folder support as a data source (#1106)

Joe Sleiman
2024-01-05 08:16:01 +02:00
committed by GitHub
parent 38ad57a22c
commit b4ec14382b
10 changed files with 185 additions and 2 deletions

View File

@@ -0,0 +1,28 @@
---
title: 'Google Drive'
---
To use `GoogleDriveLoader`, install the extra dependencies with `pip install --upgrade embedchain[googledrive]`.
The `data_type` must be `google_drive`; otherwise, the source will be treated as a regular web page.
Google Drive requires credentials to be set up. This can be done by following the steps below:
1. Go to the [Google Cloud Console](https://console.cloud.google.com/apis/credentials).
2. Create a project if you don't have one already.
3. Enable the [Google Drive API](https://console.cloud.google.com/flows/enableapi?apiid=drive.googleapis.com).
4. [Authorize credentials for a desktop app](https://developers.google.com/drive/api/quickstart/python#authorize_credentials_for_a_desktop_application).
5. When done, download the credentials in `json` format, rename the file to `credentials.json`, and save it as `~/.credentials/credentials.json`.
6. Set the environment variable `GOOGLE_APPLICATION_CREDENTIALS=~/.credentials/credentials.json`.
The first time you use the loader, you will be prompted to enter your Google account credentials.
```python
from embedchain import Pipeline as App
app = App()
url = "https://drive.google.com/drive/u/0/folders/xxx-xxx"
app.add(url, data_type="google_drive")
```
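Once the folder has been added, the app can be queried like any other source. A minimal sketch, assuming the default app configuration and valid Google credentials (the folder id and question are placeholders):
```python
from embedchain import Pipeline as App

# Assumes GOOGLE_APPLICATION_CREDENTIALS is already set (see the steps above).
app = App()
app.add("https://drive.google.com/drive/u/0/folders/xxx-xxx", data_type="google_drive")

# Ask a question over the documents ingested from the Drive folder.
answer = app.query("Summarize the documents in this folder.")
print(answer)
```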

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable
class GoogleDriveChunker(BaseChunker):
"""Chunker for google drive folder."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)
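If the defaults above (1000-character chunks, no overlap) are not a good fit, a `ChunkerConfig` can be passed in explicitly. A minimal sketch using the constructor shown above (the chunk sizes are illustrative):
```python
from embedchain.chunkers.google_drive import GoogleDriveChunker
from embedchain.config.add_config import ChunkerConfig

# Illustrative values: smaller chunks with a 50-character overlap between neighbours.
config = ChunkerConfig(chunk_size=500, chunk_overlap=50, length_function=len)
chunker = GoogleDriveChunker(config=config)
```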

View File

@@ -74,6 +74,7 @@ class DataFormatter(JSONSerializable):
DataType.DISCORD: "embedchain.loaders.discord.DiscordLoader",
DataType.RSSFEED: "embedchain.loaders.rss_feed.RSSFeedLoader",
DataType.BEEHIIV: "embedchain.loaders.beehiiv.BeehiivLoader",
DataType.GOOGLE_DRIVE: "embedchain.loaders.google_drive.GoogleDriveLoader",
DataType.DIRECTORY: "embedchain.loaders.directory_loader.DirectoryLoader",
DataType.SLACK: "embedchain.loaders.slack.SlackLoader",
DataType.DROPBOX: "embedchain.loaders.dropbox.DropboxLoader",
@@ -120,6 +121,7 @@ class DataFormatter(JSONSerializable):
DataType.CUSTOM: "embedchain.chunkers.common_chunker.CommonChunker",
DataType.RSSFEED: "embedchain.chunkers.rss_feed.RSSFeedChunker",
DataType.BEEHIIV: "embedchain.chunkers.beehiiv.BeehiivChunker",
DataType.GOOGLE_DRIVE: "embedchain.chunkers.google_drive.GoogleDriveChunker",
DataType.DIRECTORY: "embedchain.chunkers.common_chunker.CommonChunker",
DataType.SLACK: "embedchain.chunkers.common_chunker.CommonChunker",
DataType.DROPBOX: "embedchain.chunkers.common_chunker.CommonChunker",
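Both maps pair a `DataType` with a dotted class path so the matching loader or chunker module is only imported when that data type is actually used. A generic sketch of how such a path can be resolved (illustrative only, not the exact `DataFormatter` internals):
```python
import importlib

def load_class(dotted_path: str):
    # Illustrative helper: "embedchain.loaders.google_drive.GoogleDriveLoader" -> class object.
    module_path, class_name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), class_name)

loader_cls = load_class("embedchain.loaders.google_drive.GoogleDriveLoader")
loader = loader_cls()
```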

View File

@@ -0,0 +1,54 @@
import hashlib
import re
try:
from googleapiclient.errors import HttpError
except ImportError:
raise ImportError(
"Google Drive requires extra dependencies. Install with `pip install embedchain[googledrive]`"
) from None
from langchain.document_loaders import GoogleDriveLoader as Loader, UnstructuredFileIOLoader
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
@register_deserializable
class GoogleDriveLoader(BaseLoader):
@staticmethod
def _get_drive_id_from_url(url: str):
regex = r"^https:\/\/drive\.google\.com\/drive\/(?:u\/\d+\/)folders\/([a-zA-Z0-9_-]+)$"
if re.match(regex, url):
return url.split("/")[-1]
raise ValueError(
f"The url provided {url} does not match a google drive folder url. Example drive url: "
f"https://drive.google.com/drive/u/0/folders/xxxx"
)
def load_data(self, url: str):
"""Load data from a Google drive folder."""
folder_id: str = self._get_drive_id_from_url(url)
try:
loader = Loader(
folder_id=folder_id,
recursive=True,
file_loader_cls=UnstructuredFileIOLoader,
)
data = []
all_content = []
docs = loader.load()
for doc in docs:
all_content.append(doc.page_content)
# renames source to url for later use.
doc.metadata["url"] = doc.metadata.pop("source")
data.append({"content": doc.page_content, "meta_data": doc.metadata})
doc_id = hashlib.sha256((" ".join(all_content) + url).encode()).hexdigest()
return {"doc_id": doc_id, "data": data}
except HttpError:
raise FileNotFoundError("Unable to locate folder or files, check provided drive URL and try again")
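The loader can also be used on its own. `load_data` returns a dict with a deterministic `doc_id` (a SHA-256 digest of the concatenated content plus the URL) and one entry per Drive file. A minimal sketch, assuming valid credentials and an accessible folder (the folder id is a placeholder):
```python
from embedchain.loaders.google_drive import GoogleDriveLoader

loader = GoogleDriveLoader()
# Placeholder folder id; replace with a real Drive folder URL.
result = loader.load_data("https://drive.google.com/drive/u/0/folders/xxx-xxx")

print(result["doc_id"])                  # SHA-256 hex digest over the folder's content
for entry in result["data"]:
    print(entry["meta_data"]["url"])     # original Drive file source
    print(entry["content"][:80])         # first characters of the file's text
```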

View File

@@ -35,6 +35,7 @@ class IndirectDataType(Enum):
CUSTOM = "custom"
RSSFEED = "rss_feed"
BEEHIIV = "beehiiv"
GOOGLE_DRIVE = "google_drive"
DIRECTORY = "directory"
SLACK = "slack"
DROPBOX = "dropbox"
@@ -73,6 +74,7 @@ class DataType(Enum):
CUSTOM = IndirectDataType.CUSTOM.value
RSSFEED = IndirectDataType.RSSFEED.value
BEEHIIV = IndirectDataType.BEEHIIV.value
GOOGLE_DRIVE = IndirectDataType.GOOGLE_DRIVE.value
DIRECTORY = IndirectDataType.DIRECTORY.value
SLACK = IndirectDataType.SLACK.value
DROPBOX = IndirectDataType.DROPBOX.value
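With the enum entries in place, the string passed to `app.add(..., data_type=...)` and the enum member resolve to the same value. A small check, assuming the enums live in `embedchain.models.data_type` as elsewhere in the codebase:
```python
# Module path assumed from the rest of the codebase.
from embedchain.models.data_type import DataType

assert DataType.GOOGLE_DRIVE.value == "google_drive"
assert DataType("google_drive") is DataType.GOOGLE_DRIVE
```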

View File

@@ -183,6 +183,11 @@ def detect_datatype(source: Any) -> DataType:
# currently the following two fields are required in openapi spec yaml config
return "openapi" in yaml_content and "info" in yaml_content
def is_google_drive_folder(url):
# checks if url is a Google Drive folder url against a regex
regex = r"^drive\.google\.com\/drive\/(?:u\/\d+\/)folders\/([a-zA-Z0-9_-]+)$"
return re.match(regex, url)
try:
if not isinstance(source, str):
raise ValueError("Source is not a string and thus cannot be a URL.")
@@ -196,8 +201,7 @@ def detect_datatype(source: Any) -> DataType:
formatted_source = format_source(str(source), 30)
if url:
from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
@@ -266,6 +270,10 @@ def detect_datatype(source: Any) -> DataType:
logging.debug(f"Source of `{formatted_source}` detected as `github`.")
return DataType.GITHUB
if is_google_drive_folder(url.netloc + url.path):
logging.debug(f"Source of `{formatted_source}` detected as `google_drive`.")
return DataType.GOOGLE_DRIVE
# If none of the above conditions are met, it's a general web page
logging.debug(f"Source of `{formatted_source}` detected as `web_page`.")
return DataType.WEB_PAGE
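With this branch in place, Drive folder URLs are auto-detected, so passing `data_type` explicitly becomes optional. A minimal sketch, assuming `detect_datatype` is importable from `embedchain.utils` (the folder id is a placeholder):
```python
from embedchain.models.data_type import DataType
from embedchain.utils import detect_datatype  # module path assumed

assert detect_datatype("https://drive.google.com/drive/u/0/folders/xxx-xxx") == DataType.GOOGLE_DRIVE
```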

poetry.lock generated
View File

@@ -4677,6 +4677,32 @@ files = [
{file = "protobuf-4.21.12.tar.gz", hash = "sha256:7cd532c4566d0e6feafecc1059d04c7915aec8e182d1cf7adee8b24ef1e2e6ab"},
]
[[package]]
name = "psutil"
version = "5.9.5"
description = "Cross-platform lib for process and system monitoring in Python."
optional = true
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
files = [
{file = "psutil-5.9.5-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:be8929ce4313f9f8146caad4272f6abb8bf99fc6cf59344a3167ecd74f4f203f"},
{file = "psutil-5.9.5-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:ab8ed1a1d77c95453db1ae00a3f9c50227ebd955437bcf2a574ba8adbf6a74d5"},
{file = "psutil-5.9.5-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:4aef137f3345082a3d3232187aeb4ac4ef959ba3d7c10c33dd73763fbc063da4"},
{file = "psutil-5.9.5-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:ea8518d152174e1249c4f2a1c89e3e6065941df2fa13a1ab45327716a23c2b48"},
{file = "psutil-5.9.5-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:acf2aef9391710afded549ff602b5887d7a2349831ae4c26be7c807c0a39fac4"},
{file = "psutil-5.9.5-cp27-none-win32.whl", hash = "sha256:5b9b8cb93f507e8dbaf22af6a2fd0ccbe8244bf30b1baad6b3954e935157ae3f"},
{file = "psutil-5.9.5-cp27-none-win_amd64.whl", hash = "sha256:8c5f7c5a052d1d567db4ddd231a9d27a74e8e4a9c3f44b1032762bd7b9fdcd42"},
{file = "psutil-5.9.5-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:3c6f686f4225553615612f6d9bc21f1c0e305f75d7d8454f9b46e901778e7217"},
{file = "psutil-5.9.5-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7a7dd9997128a0d928ed4fb2c2d57e5102bb6089027939f3b722f3a210f9a8da"},
{file = "psutil-5.9.5-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89518112647f1276b03ca97b65cc7f64ca587b1eb0278383017c2a0dcc26cbe4"},
{file = "psutil-5.9.5-cp36-abi3-win32.whl", hash = "sha256:104a5cc0e31baa2bcf67900be36acde157756b9c44017b86b2c049f11957887d"},
{file = "psutil-5.9.5-cp36-abi3-win_amd64.whl", hash = "sha256:b258c0c1c9d145a1d5ceffab1134441c4c5113b2417fafff7315a917a026c3c9"},
{file = "psutil-5.9.5-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:c607bb3b57dc779d55e1554846352b4e358c10fff3abf3514a7a6601beebdb30"},
{file = "psutil-5.9.5.tar.gz", hash = "sha256:5410638e4df39c54d957fc51ce03048acd8e6d60abc0f5107af51e5fb566eb3c"},
]
[package.extras]
test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"]
[[package]]
name = "psycopg"
version = "3.1.12"
@@ -8095,6 +8121,7 @@ elasticsearch = ["elasticsearch"]
github = ["PyGithub", "gitpython"]
gmail = ["google-api-core", "google-api-python-client", "google-auth", "google-auth-httplib2", "google-auth-oauthlib", "requests"]
google = ["google-generativeai"]
googledrive = ["google-api-python-client", "google-auth-httplib2", "google-auth-oauthlib"]
huggingface-hub = ["huggingface_hub"]
llama2 = ["replicate"]
milvus = ["pymilvus"]

View File

@@ -197,6 +197,7 @@ gmail = [
"google-auth-httplib2",
"google-api-core",
]
googledrive = ["google-api-python-client", "google-auth-oauthlib", "google-auth-httplib2"]
postgres = ["psycopg", "psycopg-binary", "psycopg-pool"]
mysql = ["mysql-connector-python"]
github = ["PyGithub", "gitpython"]

View File

@@ -3,6 +3,7 @@ from embedchain.chunkers.discourse import DiscourseChunker
from embedchain.chunkers.docs_site import DocsSiteChunker
from embedchain.chunkers.docx_file import DocxFileChunker
from embedchain.chunkers.gmail import GmailChunker
from embedchain.chunkers.google_drive import GoogleDriveChunker
from embedchain.chunkers.json import JSONChunker
from embedchain.chunkers.mdx import MdxChunker
from embedchain.chunkers.notion import NotionChunker
@@ -41,6 +42,7 @@ chunker_common_config = {
SlackChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
DiscourseChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
CommonChunker: {"chunk_size": 2000, "chunk_overlap": 0, "length_function": len},
GoogleDriveChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
}

View File

@@ -0,0 +1,37 @@
import pytest
from embedchain.loaders.google_drive import GoogleDriveLoader
@pytest.fixture
def google_drive_folder_loader():
return GoogleDriveLoader()
def test_load_data_invalid_drive_url(google_drive_folder_loader):
mock_invalid_drive_url = "https://example.com"
with pytest.raises(
ValueError,
match="The url provided https://example.com does not match a google drive folder url. Example "
"drive url: https://drive.google.com/drive/u/0/folders/xxxx",
):
google_drive_folder_loader.load_data(mock_invalid_drive_url)
@pytest.mark.skip(reason="This test won't work unless google api credentials are properly setup.")
def test_load_data_incorrect_drive_url(google_drive_folder_loader):
mock_invalid_drive_url = "https://drive.google.com/drive/u/0/folders/xxxx"
with pytest.raises(
FileNotFoundError, match="Unable to locate folder or files, check provided drive URL and try again"
):
google_drive_folder_loader.load_data(mock_invalid_drive_url)
@pytest.mark.skip(reason="This test won't work unless google api credentials are properly setup.")
def test_load_data(google_drive_folder_loader):
mock_valid_url = "YOUR_VALID_URL"
result = google_drive_folder_loader.load_data(mock_valid_url)
assert "doc_id" in result
assert "data" in result
assert "content" in result["data"][0]
assert "meta_data" in result["data"][0]