[feat]: Add support for XML file format (#757)

This commit is contained in:
Ojuswi Rastogi
2023-10-07 04:09:32 +05:30
committed by GitHub
parent d2fd3ce434
commit 540a0a3685
8 changed files with 135 additions and 2 deletions

View File

@@ -0,0 +1,22 @@
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable
@register_deserializable
class XmlChunker(BaseChunker):
"""Chunker for XML files."""
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = ChunkerConfig(chunk_size=500, chunk_overlap=50, length_function=len)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.chunk_size,
chunk_overlap=config.chunk_overlap,
length_function=config.length_function,
)
super().__init__(text_splitter)

View File

@@ -9,6 +9,7 @@ from embedchain.chunkers.qna_pair import QnaPairChunker
from embedchain.chunkers.table import TableChunker
from embedchain.chunkers.text import TextChunker
from embedchain.chunkers.web_page import WebPageChunker
from embedchain.chunkers.xml import XmlChunker
from embedchain.chunkers.youtube_video import YoutubeVideoChunker
from embedchain.config import AddConfig
from embedchain.config.add_config import ChunkerConfig, LoaderConfig
@@ -24,6 +25,7 @@ from embedchain.loaders.mdx import MdxLoader
from embedchain.loaders.pdf_file import PdfFileLoader
from embedchain.loaders.sitemap import SitemapLoader
from embedchain.loaders.web_page import WebPageLoader
from embedchain.loaders.xml import XmlLoader
from embedchain.loaders.youtube_video import YoutubeVideoLoader
from embedchain.models.data_type import DataType
@@ -67,6 +69,7 @@ class DataFormatter(JSONSerializable):
DataType.TEXT: LocalTextLoader,
DataType.DOCX: DocxFileLoader,
DataType.SITEMAP: SitemapLoader,
DataType.XML: XmlLoader,
DataType.DOCS_SITE: DocsSiteLoader,
DataType.CSV: CsvLoader,
DataType.MDX: MdxLoader,
@@ -110,6 +113,7 @@ class DataFormatter(JSONSerializable):
DataType.CSV: TableChunker,
DataType.MDX: MdxChunker,
DataType.IMAGES: ImagesChunker,
DataType.XML: XmlChunker,
}
if data_type in chunker_classes:
chunker_class: type = chunker_classes[data_type]

26
embedchain/loaders/xml.py Normal file
View File

@@ -0,0 +1,26 @@
import hashlib
from langchain.document_loaders import UnstructuredXMLLoader
from embedchain.helper.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils import clean_string
@register_deserializable
class XmlLoader(BaseLoader):
def load_data(self, xml_url):
"""Load data from a XML file."""
loader = UnstructuredXMLLoader(xml_url)
data = loader.load()
content = data[0].page_content
content = clean_string(content)
meta_data = data[0].metadata
meta_data["url"] = meta_data["source"]
del meta_data["source"]
output = [{"content": content, "meta_data": meta_data}]
doc_id = hashlib.sha256((content + xml_url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": output,
}

View File

@@ -18,6 +18,7 @@ class IndirectDataType(Enum):
PDF_FILE = "pdf_file"
WEB_PAGE = "web_page"
SITEMAP = "sitemap"
XML = "xml"
DOCX = "docx"
DOCS_SITE = "docs_site"
NOTION = "notion"
@@ -40,6 +41,7 @@ class DataType(Enum):
PDF_FILE = IndirectDataType.PDF_FILE.value
WEB_PAGE = IndirectDataType.WEB_PAGE.value
SITEMAP = IndirectDataType.SITEMAP.value
XML = IndirectDataType.XML.value
DOCX = IndirectDataType.DOCX.value
DOCS_SITE = IndirectDataType.DOCS_SITE.value
NOTION = IndirectDataType.NOTION.value

View File

@@ -128,8 +128,7 @@ def detect_datatype(source: Any) -> DataType:
formatted_source = format_source(str(source), 30)
if url:
from langchain.document_loaders.youtube import \
ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
@@ -190,6 +189,10 @@ def detect_datatype(source: Any) -> DataType:
logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
return DataType.CSV
if source.endswith(".xml"):
logging.debug(f"Source of `{formatted_source}` detected as `xml`.")
return DataType.XML
# If the source is a valid file, that's not detectable as a type, an error is raised.
# It does not fallback to text.
raise ValueError(