From 36b26e08c33778ba2e578583eb9c4638a05bd685 Mon Sep 17 00:00:00 2001 From: Taranjeet Singh Date: Tue, 12 Sep 2023 16:43:18 -0700 Subject: [PATCH] feat: add support for mdx file (#604) --- docs/advanced/data_types.mdx | 8 ++++++ embedchain/chunkers/mdx.py | 22 ++++++++++++++++ embedchain/data_formatter/data_formatter.py | 4 +++ embedchain/loaders/mdx.py | 28 +++++++++++++++++++++ embedchain/models/data_type.py | 1 + 5 files changed, 63 insertions(+) create mode 100644 embedchain/chunkers/mdx.py create mode 100644 embedchain/loaders/mdx.py diff --git a/docs/advanced/data_types.mdx b/docs/advanced/data_types.mdx index 7ce176a7..800e4922 100644 --- a/docs/advanced/data_types.mdx +++ b/docs/advanced/data_types.mdx @@ -102,6 +102,14 @@ app.add("my-page-cfbc134ca6464fc980d0391613959196", "notion") app.add("https://www.notion.so/my-page-cfbc134ca6464fc980d0391613959196", "notion") ``` +### Mdx file + +To add any mdx file to your app, use the data_type (first argument to `.add()` method) as `mdx`. Note that this supports support mdx file present on machine, so this should be a file path. Eg: + +```python +app.add('path/to/file.mdx', data_type='mdx') +``` + ## Local Data Types ### Text diff --git a/embedchain/chunkers/mdx.py b/embedchain/chunkers/mdx.py new file mode 100644 index 00000000..c4c01521 --- /dev/null +++ b/embedchain/chunkers/mdx.py @@ -0,0 +1,22 @@ +from typing import Optional + +from langchain.text_splitter import RecursiveCharacterTextSplitter + +from embedchain.chunkers.base_chunker import BaseChunker +from embedchain.config.AddConfig import ChunkerConfig +from embedchain.helper.json_serializable import register_deserializable + + +@register_deserializable +class MdxChunker(BaseChunker): + """Chunker for mdx files.""" + + def __init__(self, config: Optional[ChunkerConfig] = None): + if config is None: + config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len) + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=config.chunk_size, + chunk_overlap=config.chunk_overlap, + length_function=config.length_function, + ) + super().__init__(text_splitter) diff --git a/embedchain/data_formatter/data_formatter.py b/embedchain/data_formatter/data_formatter.py index e8db93d1..a4551676 100644 --- a/embedchain/data_formatter/data_formatter.py +++ b/embedchain/data_formatter/data_formatter.py @@ -1,6 +1,7 @@ from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.docs_site import DocsSiteChunker from embedchain.chunkers.docx_file import DocxFileChunker +from embedchain.chunkers.mdx import MdxChunker from embedchain.chunkers.notion import NotionChunker from embedchain.chunkers.pdf_file import PdfFileChunker from embedchain.chunkers.qna_pair import QnaPairChunker @@ -17,6 +18,7 @@ from embedchain.loaders.docs_site_loader import DocsSiteLoader from embedchain.loaders.docx_file import DocxFileLoader from embedchain.loaders.local_qna_pair import LocalQnaPairLoader from embedchain.loaders.local_text import LocalTextLoader +from embedchain.loaders.mdx import MdxLoader from embedchain.loaders.pdf_file import PdfFileLoader from embedchain.loaders.sitemap import SitemapLoader from embedchain.loaders.web_page import WebPageLoader @@ -65,6 +67,7 @@ class DataFormatter(JSONSerializable): DataType.SITEMAP: SitemapLoader, DataType.DOCS_SITE: DocsSiteLoader, DataType.CSV: CsvLoader, + DataType.MDX: MdxLoader, } lazy_loaders = {DataType.NOTION} if data_type in loaders: @@ -103,6 +106,7 @@ class DataFormatter(JSONSerializable): DataType.DOCS_SITE: DocsSiteChunker, DataType.NOTION: NotionChunker, DataType.CSV: TableChunker, + DataType.MDX: MdxChunker, } if data_type in chunker_classes: chunker_class: type = chunker_classes[data_type] diff --git a/embedchain/loaders/mdx.py b/embedchain/loaders/mdx.py new file mode 100644 index 00000000..df1bd26c --- /dev/null +++ b/embedchain/loaders/mdx.py @@ -0,0 +1,28 @@ +import hashlib + +from langchain.document_loaders import PyPDFLoader + +from embedchain.helper.json_serializable import register_deserializable +from embedchain.loaders.base_loader import BaseLoader +from embedchain.utils import clean_string + + +@register_deserializable +class MdxLoader(BaseLoader): + def load_data(self, url): + """Load data from a mdx file.""" + with open(url, 'r', encoding="utf-8") as infile: + content = infile.read() + meta_data = { + "url": url, + } + doc_id = hashlib.sha256((content + url).encode()).hexdigest() + return { + "doc_id": doc_id, + "data": [ + { + "content": content, + "meta_data": meta_data, + } + ], + } diff --git a/embedchain/models/data_type.py b/embedchain/models/data_type.py index 7d8d5232..8c13eb06 100644 --- a/embedchain/models/data_type.py +++ b/embedchain/models/data_type.py @@ -12,3 +12,4 @@ class DataType(Enum): QNA_PAIR = "qna_pair" NOTION = "notion" CSV = "csv" + MDX = "mdx"