From 512cfc946663f693771c271308f99bab32f61888 Mon Sep 17 00:00:00 2001 From: Deven Patel Date: Tue, 5 Dec 2023 00:55:33 -0800 Subject: [PATCH] [Improvement] customize add method (#988) --- docs/data-sources/custom.mdx | 41 ++++++++++++++++ docs/data-sources/overview.mdx | 1 + embedchain/chunkers/common_chunker.py | 2 +- embedchain/data_formatter/data_formatter.py | 52 ++++++++------------- embedchain/embedchain.py | 8 ++-- embedchain/models/data_type.py | 12 +---- tests/chunkers/test_chunkers.py | 2 +- 7 files changed, 70 insertions(+), 48 deletions(-) create mode 100644 docs/data-sources/custom.mdx diff --git a/docs/data-sources/custom.mdx b/docs/data-sources/custom.mdx new file mode 100644 index 00000000..5c8b9096 --- /dev/null +++ b/docs/data-sources/custom.mdx @@ -0,0 +1,41 @@ +--- +title: '⚙️ Custom' +--- + +When we say "custom", we mean that you can customize the loader and chunker to your needs. This is done by passing a custom loader and chunker to the `add` method. + +```python +from embedchain import Pipeline as App +import your_loader +import your_chunker + +app = App() +loader = your_loader() +chunker = your_chunker() + +app.add("source", data_type="custom", loader=loader, chunker=chunker) +``` + + + The custom loader and chunker must be classes that inherit from the [`BaseLoader`](https://github.com/embedchain/embedchain/blob/main/embedchain/loaders/base_loader.py) and [`BaseChunker`](https://github.com/embedchain/embedchain/blob/main/embedchain/chunkers/base_chunker.py) classes, respectively. + + + + If the `data_type` is not a valid data type, the `add` method will fall back to the `custom` data type and expect a custom loader and chunker to be passed by the user. 
+ + +Example: + +```python +from embedchain import Pipeline as App +from embedchain.loaders.github import GithubLoader + +app = App() + +loader = GithubLoader(config={"token": "ghp_xxx"}) + +app.add("repo:embedchain/embedchain type:repo", data_type="github", loader=loader) + +app.query("What is Embedchain?") +# Answer: Embedchain is a Data Platform for Large Language Models (LLMs). It allows users to seamlessly load, index, retrieve, and sync unstructured data in order to build dynamic, LLM-powered applications. There is also a JavaScript implementation called embedchain-js available on GitHub. +``` diff --git a/docs/data-sources/overview.mdx b/docs/data-sources/overview.mdx index b509e7d0..a35d76da 100644 --- a/docs/data-sources/overview.mdx +++ b/docs/data-sources/overview.mdx @@ -26,6 +26,7 @@ Embedchain comes with built-in support for various data sources. We handle the c +
diff --git a/embedchain/chunkers/common_chunker.py b/embedchain/chunkers/common_chunker.py index 1527e339..53676d40 100644 --- a/embedchain/chunkers/common_chunker.py +++ b/embedchain/chunkers/common_chunker.py @@ -13,7 +13,7 @@ class CommonChunker(BaseChunker): def __init__(self, config: Optional[ChunkerConfig] = None): if config is None: - config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len) + config = ChunkerConfig(chunk_size=2000, chunk_overlap=0, length_function=len) text_splitter = RecursiveCharacterTextSplitter( chunk_size=config.chunk_size, chunk_overlap=config.chunk_overlap, diff --git a/embedchain/data_formatter/data_formatter.py b/embedchain/data_formatter/data_formatter.py index 73f929f7..fc985efd 100644 --- a/embedchain/data_formatter/data_formatter.py +++ b/embedchain/data_formatter/data_formatter.py @@ -68,23 +68,13 @@ class DataFormatter(JSONSerializable): DataType.DISCORD: "embedchain.loaders.discord.DiscordLoader", } - custom_loaders = set( - [ - DataType.POSTGRES, - DataType.MYSQL, - DataType.SLACK, - DataType.DISCOURSE, - DataType.GITHUB, - ] - ) - - if data_type in loaders: + if data_type == DataType.CUSTOM or ("loader" in kwargs): + loader_class: type = kwargs.get("loader", None) + if loader_class: + return loader_class + elif data_type in loaders: loader_class: type = self._lazy_load(loaders[data_type]) return loader_class() - elif data_type in custom_loaders: - loader_class: type = kwargs.get("loader", None) - if loader_class is not None: - return loader_class raise ValueError( f"Cant find the loader for {data_type}.\ @@ -112,28 +102,26 @@ class DataFormatter(JSONSerializable): DataType.OPENAPI: "embedchain.chunkers.openapi.OpenAPIChunker", DataType.GMAIL: "embedchain.chunkers.gmail.GmailChunker", DataType.NOTION: "embedchain.chunkers.notion.NotionChunker", - DataType.POSTGRES: "embedchain.chunkers.postgres.PostgresChunker", - DataType.MYSQL: "embedchain.chunkers.mysql.MySQLChunker", - DataType.SLACK: 
"embedchain.chunkers.slack.SlackChunker", - DataType.DISCOURSE: "embedchain.chunkers.discourse.DiscourseChunker", DataType.SUBSTACK: "embedchain.chunkers.substack.SubstackChunker", - DataType.GITHUB: "embedchain.chunkers.common_chunker.CommonChunker", DataType.YOUTUBE_CHANNEL: "embedchain.chunkers.common_chunker.CommonChunker", DataType.DISCORD: "embedchain.chunkers.common_chunker.CommonChunker", + DataType.CUSTOM: "embedchain.chunkers.common_chunker.CommonChunker", } - if data_type in chunker_classes: - if "chunker" in kwargs: - chunker_class = kwargs.get("chunker") - else: - chunker_class = self._lazy_load(chunker_classes[data_type]) - + if "chunker" in kwargs: + chunker_class = kwargs.get("chunker", None) + if chunker_class: + chunker = chunker_class(config) + chunker.set_data_type(data_type) + return chunker + elif data_type in chunker_classes: + chunker_class = self._lazy_load(chunker_classes[data_type]) chunker = chunker_class(config) chunker.set_data_type(data_type) return chunker - else: - raise ValueError( - f"Cant find the chunker for {data_type}.\ - We recommend to pass the chunker to use data_type: {data_type},\ - check `https://docs.embedchain.ai/data-sources/overview`." - ) + + raise ValueError( + f"Cant find the chunker for {data_type}.\ + We recommend to pass the chunker to use data_type: {data_type},\ + check `https://docs.embedchain.ai/data-sources/overview`." 
+ ) diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index dfd1627e..6c42a13a 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -178,10 +178,10 @@ class EmbedChain(JSONSerializable): try: data_type = DataType(data_type) except ValueError: - raise ValueError( - f"Invalid data_type: '{data_type}'.", - f"Please use one of the following: {[data_type.value for data_type in DataType]}", - ) from None + logging.info( + f"Invalid data_type: '{data_type}', using `custom` instead.\n Check docs to pass the valid data type: `https://docs.embedchain.ai/data-sources/overview`" # noqa: E501 + ) + data_type = DataType.CUSTOM if not data_type: data_type = detect_datatype(source) diff --git a/embedchain/models/data_type.py b/embedchain/models/data_type.py index 647d8ac5..a56f7f8d 100644 --- a/embedchain/models/data_type.py +++ b/embedchain/models/data_type.py @@ -29,14 +29,10 @@ class IndirectDataType(Enum): JSON = "json" OPENAPI = "openapi" GMAIL = "gmail" - POSTGRES = "postgres" - MYSQL = "mysql" - SLACK = "slack" - DISCOURSE = "discourse" SUBSTACK = "substack" - GITHUB = "github" YOUTUBE_CHANNEL = "youtube_channel" DISCORD = "discord" + CUSTOM = "custom" class SpecialDataType(Enum): @@ -65,11 +61,7 @@ class DataType(Enum): JSON = IndirectDataType.JSON.value OPENAPI = IndirectDataType.OPENAPI.value GMAIL = IndirectDataType.GMAIL.value - POSTGRES = IndirectDataType.POSTGRES.value - MYSQL = IndirectDataType.MYSQL.value - SLACK = IndirectDataType.SLACK.value - DISCOURSE = IndirectDataType.DISCOURSE.value SUBSTACK = IndirectDataType.SUBSTACK.value - GITHUB = IndirectDataType.GITHUB.value YOUTUBE_CHANNEL = IndirectDataType.YOUTUBE_CHANNEL.value DISCORD = IndirectDataType.DISCORD.value + CUSTOM = IndirectDataType.CUSTOM.value diff --git a/tests/chunkers/test_chunkers.py b/tests/chunkers/test_chunkers.py index c70c8ff4..1e725a8a 100644 --- a/tests/chunkers/test_chunkers.py +++ b/tests/chunkers/test_chunkers.py @@ -40,7 +40,7 @@ chunker_common_config 
= { PostgresChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len}, SlackChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len}, DiscourseChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len}, - CommonChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len}, + CommonChunker: {"chunk_size": 2000, "chunk_overlap": 0, "length_function": len}, }