[feat]: Add support for XML file format (#757)
This commit is contained in:
13
docs/data-sources/xml.mdx
Normal file
13
docs/data-sources/xml.mdx
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
---
|
||||||
|
title: 'XML File'
|
||||||
|
---
|
||||||
|
|
||||||
|
### XML file
|
||||||
|
|
||||||
|
To add any xml file, use the data_type as `xml`. Eg:
|
||||||
|
|
||||||
|
```python
|
||||||
|
app.add('content/data.xml')
|
||||||
|
```
|
||||||
|
|
||||||
|
Note: Only the text content of the xml file will be added to the app. The tags will be ignored.
|
||||||
22
embedchain/chunkers/xml.py
Normal file
22
embedchain/chunkers/xml.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||||
|
|
||||||
|
from embedchain.chunkers.base_chunker import BaseChunker
|
||||||
|
from embedchain.config.add_config import ChunkerConfig
|
||||||
|
from embedchain.helper.json_serializable import register_deserializable
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
|
class XmlChunker(BaseChunker):
|
||||||
|
"""Chunker for XML files."""
|
||||||
|
|
||||||
|
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||||
|
if config is None:
|
||||||
|
config = ChunkerConfig(chunk_size=500, chunk_overlap=50, length_function=len)
|
||||||
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=config.chunk_size,
|
||||||
|
chunk_overlap=config.chunk_overlap,
|
||||||
|
length_function=config.length_function,
|
||||||
|
)
|
||||||
|
super().__init__(text_splitter)
|
||||||
@@ -9,6 +9,7 @@ from embedchain.chunkers.qna_pair import QnaPairChunker
|
|||||||
from embedchain.chunkers.table import TableChunker
|
from embedchain.chunkers.table import TableChunker
|
||||||
from embedchain.chunkers.text import TextChunker
|
from embedchain.chunkers.text import TextChunker
|
||||||
from embedchain.chunkers.web_page import WebPageChunker
|
from embedchain.chunkers.web_page import WebPageChunker
|
||||||
|
from embedchain.chunkers.xml import XmlChunker
|
||||||
from embedchain.chunkers.youtube_video import YoutubeVideoChunker
|
from embedchain.chunkers.youtube_video import YoutubeVideoChunker
|
||||||
from embedchain.config import AddConfig
|
from embedchain.config import AddConfig
|
||||||
from embedchain.config.add_config import ChunkerConfig, LoaderConfig
|
from embedchain.config.add_config import ChunkerConfig, LoaderConfig
|
||||||
@@ -24,6 +25,7 @@ from embedchain.loaders.mdx import MdxLoader
|
|||||||
from embedchain.loaders.pdf_file import PdfFileLoader
|
from embedchain.loaders.pdf_file import PdfFileLoader
|
||||||
from embedchain.loaders.sitemap import SitemapLoader
|
from embedchain.loaders.sitemap import SitemapLoader
|
||||||
from embedchain.loaders.web_page import WebPageLoader
|
from embedchain.loaders.web_page import WebPageLoader
|
||||||
|
from embedchain.loaders.xml import XmlLoader
|
||||||
from embedchain.loaders.youtube_video import YoutubeVideoLoader
|
from embedchain.loaders.youtube_video import YoutubeVideoLoader
|
||||||
from embedchain.models.data_type import DataType
|
from embedchain.models.data_type import DataType
|
||||||
|
|
||||||
@@ -67,6 +69,7 @@ class DataFormatter(JSONSerializable):
|
|||||||
DataType.TEXT: LocalTextLoader,
|
DataType.TEXT: LocalTextLoader,
|
||||||
DataType.DOCX: DocxFileLoader,
|
DataType.DOCX: DocxFileLoader,
|
||||||
DataType.SITEMAP: SitemapLoader,
|
DataType.SITEMAP: SitemapLoader,
|
||||||
|
DataType.XML: XmlLoader,
|
||||||
DataType.DOCS_SITE: DocsSiteLoader,
|
DataType.DOCS_SITE: DocsSiteLoader,
|
||||||
DataType.CSV: CsvLoader,
|
DataType.CSV: CsvLoader,
|
||||||
DataType.MDX: MdxLoader,
|
DataType.MDX: MdxLoader,
|
||||||
@@ -110,6 +113,7 @@ class DataFormatter(JSONSerializable):
|
|||||||
DataType.CSV: TableChunker,
|
DataType.CSV: TableChunker,
|
||||||
DataType.MDX: MdxChunker,
|
DataType.MDX: MdxChunker,
|
||||||
DataType.IMAGES: ImagesChunker,
|
DataType.IMAGES: ImagesChunker,
|
||||||
|
DataType.XML: XmlChunker,
|
||||||
}
|
}
|
||||||
if data_type in chunker_classes:
|
if data_type in chunker_classes:
|
||||||
chunker_class: type = chunker_classes[data_type]
|
chunker_class: type = chunker_classes[data_type]
|
||||||
|
|||||||
26
embedchain/loaders/xml.py
Normal file
26
embedchain/loaders/xml.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
import hashlib
|
||||||
|
|
||||||
|
from langchain.document_loaders import UnstructuredXMLLoader
|
||||||
|
|
||||||
|
from embedchain.helper.json_serializable import register_deserializable
|
||||||
|
from embedchain.loaders.base_loader import BaseLoader
|
||||||
|
from embedchain.utils import clean_string
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
|
class XmlLoader(BaseLoader):
|
||||||
|
def load_data(self, xml_url):
|
||||||
|
"""Load data from a XML file."""
|
||||||
|
loader = UnstructuredXMLLoader(xml_url)
|
||||||
|
data = loader.load()
|
||||||
|
content = data[0].page_content
|
||||||
|
content = clean_string(content)
|
||||||
|
meta_data = data[0].metadata
|
||||||
|
meta_data["url"] = meta_data["source"]
|
||||||
|
del meta_data["source"]
|
||||||
|
output = [{"content": content, "meta_data": meta_data}]
|
||||||
|
doc_id = hashlib.sha256((content + xml_url).encode()).hexdigest()
|
||||||
|
return {
|
||||||
|
"doc_id": doc_id,
|
||||||
|
"data": output,
|
||||||
|
}
|
||||||
@@ -18,6 +18,7 @@ class IndirectDataType(Enum):
|
|||||||
PDF_FILE = "pdf_file"
|
PDF_FILE = "pdf_file"
|
||||||
WEB_PAGE = "web_page"
|
WEB_PAGE = "web_page"
|
||||||
SITEMAP = "sitemap"
|
SITEMAP = "sitemap"
|
||||||
|
XML = "xml"
|
||||||
DOCX = "docx"
|
DOCX = "docx"
|
||||||
DOCS_SITE = "docs_site"
|
DOCS_SITE = "docs_site"
|
||||||
NOTION = "notion"
|
NOTION = "notion"
|
||||||
@@ -40,6 +41,7 @@ class DataType(Enum):
|
|||||||
PDF_FILE = IndirectDataType.PDF_FILE.value
|
PDF_FILE = IndirectDataType.PDF_FILE.value
|
||||||
WEB_PAGE = IndirectDataType.WEB_PAGE.value
|
WEB_PAGE = IndirectDataType.WEB_PAGE.value
|
||||||
SITEMAP = IndirectDataType.SITEMAP.value
|
SITEMAP = IndirectDataType.SITEMAP.value
|
||||||
|
XML = IndirectDataType.XML.value
|
||||||
DOCX = IndirectDataType.DOCX.value
|
DOCX = IndirectDataType.DOCX.value
|
||||||
DOCS_SITE = IndirectDataType.DOCS_SITE.value
|
DOCS_SITE = IndirectDataType.DOCS_SITE.value
|
||||||
NOTION = IndirectDataType.NOTION.value
|
NOTION = IndirectDataType.NOTION.value
|
||||||
|
|||||||
@@ -128,8 +128,7 @@ def detect_datatype(source: Any) -> DataType:
|
|||||||
formatted_source = format_source(str(source), 30)
|
formatted_source = format_source(str(source), 30)
|
||||||
|
|
||||||
if url:
|
if url:
|
||||||
from langchain.document_loaders.youtube import \
|
from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
|
||||||
ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
|
|
||||||
|
|
||||||
if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
|
if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
|
||||||
logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
|
logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
|
||||||
@@ -190,6 +189,10 @@ def detect_datatype(source: Any) -> DataType:
|
|||||||
logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
|
logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
|
||||||
return DataType.CSV
|
return DataType.CSV
|
||||||
|
|
||||||
|
if source.endswith(".xml"):
|
||||||
|
logging.debug(f"Source of `{formatted_source}` detected as `xml`.")
|
||||||
|
return DataType.XML
|
||||||
|
|
||||||
# If the source is a valid file, that's not detectable as a type, an error is raised.
|
# If the source is a valid file, that's not detectable as a type, an error is raised.
|
||||||
# It does not fallback to text.
|
# It does not fallback to text.
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -106,6 +106,7 @@ fastapi-poe = { version = "0.0.16", optional = true }
|
|||||||
discord = { version = "^2.3.2", optional = true }
|
discord = { version = "^2.3.2", optional = true }
|
||||||
slack-sdk = { version = "3.21.3", optional = true }
|
slack-sdk = { version = "3.21.3", optional = true }
|
||||||
docx2txt = "^0.8"
|
docx2txt = "^0.8"
|
||||||
|
unstructured = {extras = ["local-inference"], version = "^0.10.18"}
|
||||||
pillow = { version = "10.0.1", optional = true }
|
pillow = { version = "10.0.1", optional = true }
|
||||||
torchvision = { version = ">=0.15.1, !=0.15.2", optional = true }
|
torchvision = { version = ">=0.15.1, !=0.15.2", optional = true }
|
||||||
ftfy = { version = "6.1.1", optional = true }
|
ftfy = { version = "6.1.1", optional = true }
|
||||||
|
|||||||
62
tests/loaders/test_xml.py
Normal file
62
tests/loaders/test_xml.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from embedchain.loaders.xml import XmlLoader
|
||||||
|
|
||||||
|
# Taken from https://github.com/langchain-ai/langchain/blob/master/libs/langchain/tests/integration_tests/examples/factbook.xml
|
||||||
|
SAMPLE_XML = """<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<factbook>
|
||||||
|
<country>
|
||||||
|
<name>United States</name>
|
||||||
|
<capital>Washington, DC</capital>
|
||||||
|
<leader>Joe Biden</leader>
|
||||||
|
<sport>Baseball</sport>
|
||||||
|
</country>
|
||||||
|
<country>
|
||||||
|
<name>Canada</name>
|
||||||
|
<capital>Ottawa</capital>
|
||||||
|
<leader>Justin Trudeau</leader>
|
||||||
|
<sport>Hockey</sport>
|
||||||
|
</country>
|
||||||
|
<country>
|
||||||
|
<name>France</name>
|
||||||
|
<capital>Paris</capital>
|
||||||
|
<leader>Emmanuel Macron</leader>
|
||||||
|
<sport>Soccer</sport>
|
||||||
|
</country>
|
||||||
|
<country>
|
||||||
|
<name>Trinidad & Tobado</name>
|
||||||
|
<capital>Port of Spain</capital>
|
||||||
|
<leader>Keith Rowley</leader>
|
||||||
|
<sport>Track & Field</sport>
|
||||||
|
</country>
|
||||||
|
</factbook>"""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("xml", [SAMPLE_XML])
|
||||||
|
def test_load_data(xml: str):
|
||||||
|
"""
|
||||||
|
Test XML loader
|
||||||
|
|
||||||
|
Tests that XML file is loaded, metadata is correct and content is correct
|
||||||
|
"""
|
||||||
|
# Creating temporary XML file
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w+") as tmpfile:
|
||||||
|
tmpfile.write(xml)
|
||||||
|
|
||||||
|
tmpfile.seek(0)
|
||||||
|
filename = tmpfile.name
|
||||||
|
|
||||||
|
# Loading CSV using XmlLoader
|
||||||
|
loader = XmlLoader()
|
||||||
|
result = loader.load_data(filename)
|
||||||
|
data = result["data"]
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert len(data) == 1
|
||||||
|
assert "United States Washington, DC Joe Biden" in data[0]["content"]
|
||||||
|
assert "Canada Ottawa Justin Trudeau" in data[0]["content"]
|
||||||
|
assert "France Paris Emmanuel Macron" in data[0]["content"]
|
||||||
|
assert "Trinidad & Tobado Port of Spain Keith Rowley" in data[0]["content"]
|
||||||
|
assert data[0]["meta_data"]["url"] == filename
|
||||||
Reference in New Issue
Block a user