From b4ec14382bec0e6c2998ad2968083fb867e79930 Mon Sep 17 00:00:00 2001 From: Joe Sleiman <31307130+JoeSL@users.noreply.github.com> Date: Fri, 5 Jan 2024 08:16:01 +0200 Subject: [PATCH] [Feature] Google Drive Folder support as a data source (#1106) --- docs/components/data-sources/google-drive.mdx | 28 ++++++++++ embedchain/chunkers/google_drive.py | 22 ++++++++ embedchain/data_formatter/data_formatter.py | 2 + embedchain/loaders/google_drive.py | 54 +++++++++++++++++++ embedchain/models/data_type.py | 2 + embedchain/utils.py | 12 ++++- poetry.lock | 27 ++++++++++ pyproject.toml | 1 + tests/chunkers/test_chunkers.py | 2 + tests/loaders/test_google_drive.py | 37 +++++++++++++ 10 files changed, 185 insertions(+), 2 deletions(-) create mode 100644 docs/components/data-sources/google-drive.mdx create mode 100644 embedchain/chunkers/google_drive.py create mode 100644 embedchain/loaders/google_drive.py create mode 100644 tests/loaders/test_google_drive.py diff --git a/docs/components/data-sources/google-drive.mdx b/docs/components/data-sources/google-drive.mdx new file mode 100644 index 00000000..ed0482ed --- /dev/null +++ b/docs/components/data-sources/google-drive.mdx @@ -0,0 +1,28 @@ +--- +title: 'Google Drive' +--- + +To use GoogleDriveLoader you must install the extra dependencies with `pip install --upgrade embedchain[googledrive]`. + +The data_type must be `google_drive`. Otherwise, it will be considered a regular web page. + +Google Drive requires the setup of credentials. This can be done by following the steps below: + +1. Go to the [Google Cloud Console](https://console.cloud.google.com/apis/credentials). +2. Create a project if you don't have one already. +3. Enable the [Google Drive API](https://console.cloud.google.com/flows/enableapi?apiid=drive.googleapis.com) +4. [Authorize credentials for desktop app](https://developers.google.com/drive/api/quickstart/python#authorize_credentials_for_a_desktop_application) +5. 
When done, you will be able to download the credentials in `json` format. Rename the downloaded file to `credentials.json` and save it in `~/.credentials/credentials.json` +6. Set the environment variable `GOOGLE_APPLICATION_CREDENTIALS=~/.credentials/credentials.json` + +The first time you use the loader, you will be prompted to enter your Google account credentials. + + +```python +from embedchain import Pipeline as App + +app = App() + +url = "https://drive.google.com/drive/u/0/folders/xxx-xxx" +app.add(url, data_type="google_drive") +``` \ No newline at end of file diff --git a/embedchain/chunkers/google_drive.py b/embedchain/chunkers/google_drive.py new file mode 100644 index 00000000..8440325b --- /dev/null +++ b/embedchain/chunkers/google_drive.py @@ -0,0 +1,22 @@ +from typing import Optional + +from langchain.text_splitter import RecursiveCharacterTextSplitter + +from embedchain.chunkers.base_chunker import BaseChunker +from embedchain.config.add_config import ChunkerConfig +from embedchain.helpers.json_serializable import register_deserializable + + +@register_deserializable +class GoogleDriveChunker(BaseChunker): + """Chunker for google drive folder.""" + + def __init__(self, config: Optional[ChunkerConfig] = None): + if config is None: + config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len) + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=config.chunk_size, + chunk_overlap=config.chunk_overlap, + length_function=config.length_function, + ) + super().__init__(text_splitter) diff --git a/embedchain/data_formatter/data_formatter.py b/embedchain/data_formatter/data_formatter.py index 9ec7c258..afa0e286 100644 --- a/embedchain/data_formatter/data_formatter.py +++ b/embedchain/data_formatter/data_formatter.py @@ -74,6 +74,7 @@ class DataFormatter(JSONSerializable): DataType.DISCORD: "embedchain.loaders.discord.DiscordLoader", DataType.RSSFEED: "embedchain.loaders.rss_feed.RSSFeedLoader", DataType.BEEHIIV: 
"embedchain.loaders.beehiiv.BeehiivLoader", + DataType.GOOGLE_DRIVE: "embedchain.loaders.google_drive.GoogleDriveLoader", DataType.DIRECTORY: "embedchain.loaders.directory_loader.DirectoryLoader", DataType.SLACK: "embedchain.loaders.slack.SlackLoader", DataType.DROPBOX: "embedchain.loaders.dropbox.DropboxLoader", @@ -120,6 +121,7 @@ class DataFormatter(JSONSerializable): DataType.CUSTOM: "embedchain.chunkers.common_chunker.CommonChunker", DataType.RSSFEED: "embedchain.chunkers.rss_feed.RSSFeedChunker", DataType.BEEHIIV: "embedchain.chunkers.beehiiv.BeehiivChunker", + DataType.GOOGLE_DRIVE: "embedchain.chunkers.google_drive.GoogleDriveChunker", DataType.DIRECTORY: "embedchain.chunkers.common_chunker.CommonChunker", DataType.SLACK: "embedchain.chunkers.common_chunker.CommonChunker", DataType.DROPBOX: "embedchain.chunkers.common_chunker.CommonChunker", diff --git a/embedchain/loaders/google_drive.py b/embedchain/loaders/google_drive.py new file mode 100644 index 00000000..dde405aa --- /dev/null +++ b/embedchain/loaders/google_drive.py @@ -0,0 +1,54 @@ +import hashlib +import re + +try: + from googleapiclient.errors import HttpError +except ImportError: + raise ImportError( + "Google Drive requires extra dependencies. Install with `pip install embedchain[googledrive]`" + ) from None + +from langchain.document_loaders import GoogleDriveLoader as Loader, UnstructuredFileIOLoader + +from embedchain.helpers.json_serializable import register_deserializable +from embedchain.loaders.base_loader import BaseLoader + + +@register_deserializable +class GoogleDriveLoader(BaseLoader): + @staticmethod + def _get_drive_id_from_url(url: str): + regex = r"^https:\/\/drive\.google\.com\/drive\/(?:u\/\d+\/)?folders\/([a-zA-Z0-9_-]+)$" + if re.match(regex, url): + return url.split("/")[-1] + raise ValueError( + f"The url provided {url} does not match a google drive folder url. 
Example drive url: " + f"https://drive.google.com/drive/u/0/folders/xxxx" + ) + + def load_data(self, url: str): + """Load data from a Google drive folder.""" + folder_id: str = self._get_drive_id_from_url(url) + + try: + loader = Loader( + folder_id=folder_id, + recursive=True, + file_loader_cls=UnstructuredFileIOLoader, + ) + + data = [] + all_content = [] + + docs = loader.load() + for doc in docs: + all_content.append(doc.page_content) + # renames source to url for later use. + doc.metadata["url"] = doc.metadata.pop("source") + data.append({"content": doc.page_content, "meta_data": doc.metadata}) + + doc_id = hashlib.sha256((" ".join(all_content) + url).encode()).hexdigest() + return {"doc_id": doc_id, "data": data} + + except HttpError: + raise FileNotFoundError("Unable to locate folder or files, check provided drive URL and try again") diff --git a/embedchain/models/data_type.py b/embedchain/models/data_type.py index 23a1fffc..df2e655a 100644 --- a/embedchain/models/data_type.py +++ b/embedchain/models/data_type.py @@ -35,6 +35,7 @@ class IndirectDataType(Enum): CUSTOM = "custom" RSSFEED = "rss_feed" BEEHIIV = "beehiiv" + GOOGLE_DRIVE = "google_drive" DIRECTORY = "directory" SLACK = "slack" DROPBOX = "dropbox" @@ -73,6 +74,7 @@ class DataType(Enum): CUSTOM = IndirectDataType.CUSTOM.value RSSFEED = IndirectDataType.RSSFEED.value BEEHIIV = IndirectDataType.BEEHIIV.value + GOOGLE_DRIVE = IndirectDataType.GOOGLE_DRIVE.value DIRECTORY = IndirectDataType.DIRECTORY.value SLACK = IndirectDataType.SLACK.value DROPBOX = IndirectDataType.DROPBOX.value diff --git a/embedchain/utils.py b/embedchain/utils.py index 4169558e..43665d54 100644 --- a/embedchain/utils.py +++ b/embedchain/utils.py @@ -183,6 +183,11 @@ def detect_datatype(source: Any) -> DataType: # currently the following two fields are required in openapi spec yaml config return "openapi" in yaml_content and "info" in yaml_content + def is_google_drive_folder(url): + # checks if url is a Google Drive folder url 
against a regex + regex = r"^drive\.google\.com\/drive\/(?:u\/\d+\/)folders\/([a-zA-Z0-9_-]+)$" + return re.match(regex, url) + try: if not isinstance(source, str): raise ValueError("Source is not a string and thus cannot be a URL.") @@ -196,8 +201,7 @@ def detect_datatype(source: Any) -> DataType: formatted_source = format_source(str(source), 30) if url: - from langchain.document_loaders.youtube import \ - ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS + from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS if url.netloc in YOUTUBE_ALLOWED_NETLOCS: logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.") @@ -266,6 +270,10 @@ def detect_datatype(source: Any) -> DataType: logging.debug(f"Source of `{formatted_source}` detected as `github`.") return DataType.GITHUB + if is_google_drive_folder(url.netloc + url.path): + logging.debug(f"Source of `{formatted_source}` detected as `google drive folder`.") + return DataType.GOOGLE_DRIVE + # If none of the above conditions are met, it's a general web page logging.debug(f"Source of `{formatted_source}` detected as `web_page`.") return DataType.WEB_PAGE diff --git a/poetry.lock b/poetry.lock index 99c28892..f6eb3bef 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4677,6 +4677,32 @@ files = [ {file = "protobuf-4.21.12.tar.gz", hash = "sha256:7cd532c4566d0e6feafecc1059d04c7915aec8e182d1cf7adee8b24ef1e2e6ab"}, ] +[[package]] +name = "psutil" +version = "5.9.5" +description = "Cross-platform lib for process and system monitoring in Python." 
+optional = true +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "psutil-5.9.5-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:be8929ce4313f9f8146caad4272f6abb8bf99fc6cf59344a3167ecd74f4f203f"}, + {file = "psutil-5.9.5-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:ab8ed1a1d77c95453db1ae00a3f9c50227ebd955437bcf2a574ba8adbf6a74d5"}, + {file = "psutil-5.9.5-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:4aef137f3345082a3d3232187aeb4ac4ef959ba3d7c10c33dd73763fbc063da4"}, + {file = "psutil-5.9.5-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:ea8518d152174e1249c4f2a1c89e3e6065941df2fa13a1ab45327716a23c2b48"}, + {file = "psutil-5.9.5-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:acf2aef9391710afded549ff602b5887d7a2349831ae4c26be7c807c0a39fac4"}, + {file = "psutil-5.9.5-cp27-none-win32.whl", hash = "sha256:5b9b8cb93f507e8dbaf22af6a2fd0ccbe8244bf30b1baad6b3954e935157ae3f"}, + {file = "psutil-5.9.5-cp27-none-win_amd64.whl", hash = "sha256:8c5f7c5a052d1d567db4ddd231a9d27a74e8e4a9c3f44b1032762bd7b9fdcd42"}, + {file = "psutil-5.9.5-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:3c6f686f4225553615612f6d9bc21f1c0e305f75d7d8454f9b46e901778e7217"}, + {file = "psutil-5.9.5-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7a7dd9997128a0d928ed4fb2c2d57e5102bb6089027939f3b722f3a210f9a8da"}, + {file = "psutil-5.9.5-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89518112647f1276b03ca97b65cc7f64ca587b1eb0278383017c2a0dcc26cbe4"}, + {file = "psutil-5.9.5-cp36-abi3-win32.whl", hash = "sha256:104a5cc0e31baa2bcf67900be36acde157756b9c44017b86b2c049f11957887d"}, + {file = "psutil-5.9.5-cp36-abi3-win_amd64.whl", hash = "sha256:b258c0c1c9d145a1d5ceffab1134441c4c5113b2417fafff7315a917a026c3c9"}, + {file = "psutil-5.9.5-cp38-abi3-macosx_11_0_arm64.whl", hash = 
"sha256:c607bb3b57dc779d55e1554846352b4e358c10fff3abf3514a7a6601beebdb30"}, + {file = "psutil-5.9.5.tar.gz", hash = "sha256:5410638e4df39c54d957fc51ce03048acd8e6d60abc0f5107af51e5fb566eb3c"}, +] + +[package.extras] +test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] + [[package]] name = "psycopg" version = "3.1.12" @@ -8095,6 +8121,7 @@ elasticsearch = ["elasticsearch"] github = ["PyGithub", "gitpython"] gmail = ["google-api-core", "google-api-python-client", "google-auth", "google-auth-httplib2", "google-auth-oauthlib", "requests"] google = ["google-generativeai"] +googledrive = ["google-api-python-client", "google-auth-httplib2", "google-auth-oauthlib"] huggingface-hub = ["huggingface_hub"] llama2 = ["replicate"] milvus = ["pymilvus"] diff --git a/pyproject.toml b/pyproject.toml index b4b1fdd0..e642799e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -197,6 +197,7 @@ gmail = [ "google-auth-httplib2", "google-api-core", ] +googledrive = ["google-api-python-client", "google-auth-oauthlib", "google-auth-httplib2"] postgres = ["psycopg", "psycopg-binary", "psycopg-pool"] mysql = ["mysql-connector-python"] github = ["PyGithub", "gitpython"] diff --git a/tests/chunkers/test_chunkers.py b/tests/chunkers/test_chunkers.py index e0de9958..cdd06fcc 100644 --- a/tests/chunkers/test_chunkers.py +++ b/tests/chunkers/test_chunkers.py @@ -3,6 +3,7 @@ from embedchain.chunkers.discourse import DiscourseChunker from embedchain.chunkers.docs_site import DocsSiteChunker from embedchain.chunkers.docx_file import DocxFileChunker from embedchain.chunkers.gmail import GmailChunker +from embedchain.chunkers.google_drive import GoogleDriveChunker from embedchain.chunkers.json import JSONChunker from embedchain.chunkers.mdx import MdxChunker from embedchain.chunkers.notion import NotionChunker @@ -41,6 +42,7 @@ chunker_common_config = { SlackChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len}, DiscourseChunker: {"chunk_size": 1000, "chunk_overlap": 0, 
"length_function": len}, CommonChunker: {"chunk_size": 2000, "chunk_overlap": 0, "length_function": len}, + GoogleDriveChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len}, } diff --git a/tests/loaders/test_google_drive.py b/tests/loaders/test_google_drive.py new file mode 100644 index 00000000..00d8bb1c --- /dev/null +++ b/tests/loaders/test_google_drive.py @@ -0,0 +1,37 @@ +import pytest + +from embedchain.loaders.google_drive import GoogleDriveLoader + + +@pytest.fixture +def google_drive_folder_loader(): + return GoogleDriveLoader() + + +def test_load_data_invalid_drive_url(google_drive_folder_loader): + mock_invalid_drive_url = "https://example.com" + with pytest.raises( + ValueError, + match="The url provided https://example.com does not match a google drive folder url. Example " + "drive url: https://drive.google.com/drive/u/0/folders/xxxx", + ): + google_drive_folder_loader.load_data(mock_invalid_drive_url) + + +@pytest.mark.skip(reason="This test won't work unless google api credentials are properly setup.") +def test_load_data_incorrect_drive_url(google_drive_folder_loader): + mock_invalid_drive_url = "https://drive.google.com/drive/u/0/folders/xxxx" + with pytest.raises( + FileNotFoundError, match="Unable to locate folder or files, check provided drive URL and try again" + ): + google_drive_folder_loader.load_data(mock_invalid_drive_url) + + +@pytest.mark.skip(reason="This test won't work unless google api credentials are properly setup.") +def test_load_data(google_drive_folder_loader): + mock_valid_url = "YOUR_VALID_URL" + result = google_drive_folder_loader.load_data(mock_valid_url) + assert "doc_id" in result + assert "data" in result + assert "content" in result["data"][0] + assert "meta_data" in result["data"][0]