diff --git a/docs/data-sources/xml.mdx b/docs/data-sources/xml.mdx new file mode 100644 index 00000000..0398afdc --- /dev/null +++ b/docs/data-sources/xml.mdx @@ -0,0 +1,13 @@ +--- +title: 'XML File' +--- + +### XML file + +To add any xml file, use the data_type as `xml`. Eg: + +```python +app.add('content/data.xml') +``` + +Note: Only the text content of the xml file will be added to the app. The tags will be ignored. \ No newline at end of file diff --git a/embedchain/chunkers/xml.py b/embedchain/chunkers/xml.py new file mode 100644 index 00000000..cae519ab --- /dev/null +++ b/embedchain/chunkers/xml.py @@ -0,0 +1,22 @@ +from typing import Optional + +from langchain.text_splitter import RecursiveCharacterTextSplitter + +from embedchain.chunkers.base_chunker import BaseChunker +from embedchain.config.add_config import ChunkerConfig +from embedchain.helper.json_serializable import register_deserializable + + +@register_deserializable +class XmlChunker(BaseChunker): + """Chunker for XML files.""" + + def __init__(self, config: Optional[ChunkerConfig] = None): + if config is None: + config = ChunkerConfig(chunk_size=500, chunk_overlap=50, length_function=len) + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=config.chunk_size, + chunk_overlap=config.chunk_overlap, + length_function=config.length_function, + ) + super().__init__(text_splitter) diff --git a/embedchain/data_formatter/data_formatter.py b/embedchain/data_formatter/data_formatter.py index 10829474..635f4ea7 100644 --- a/embedchain/data_formatter/data_formatter.py +++ b/embedchain/data_formatter/data_formatter.py @@ -9,6 +9,7 @@ from embedchain.chunkers.qna_pair import QnaPairChunker from embedchain.chunkers.table import TableChunker from embedchain.chunkers.text import TextChunker from embedchain.chunkers.web_page import WebPageChunker +from embedchain.chunkers.xml import XmlChunker from embedchain.chunkers.youtube_video import YoutubeVideoChunker from embedchain.config import AddConfig from embedchain.config.add_config import ChunkerConfig, LoaderConfig @@ -24,6 +25,7 @@ from embedchain.loaders.mdx import MdxLoader from embedchain.loaders.pdf_file import PdfFileLoader from embedchain.loaders.sitemap import SitemapLoader from embedchain.loaders.web_page import WebPageLoader +from embedchain.loaders.xml import XmlLoader from embedchain.loaders.youtube_video import YoutubeVideoLoader from embedchain.models.data_type import DataType @@ -67,6 +69,7 @@ class DataFormatter(JSONSerializable): DataType.TEXT: LocalTextLoader, DataType.DOCX: DocxFileLoader, DataType.SITEMAP: SitemapLoader, + DataType.XML: XmlLoader, DataType.DOCS_SITE: DocsSiteLoader, DataType.CSV: CsvLoader, DataType.MDX: MdxLoader, @@ -110,6 +113,7 @@ class DataFormatter(JSONSerializable): DataType.CSV: TableChunker, DataType.MDX: MdxChunker, DataType.IMAGES: ImagesChunker, + DataType.XML: XmlChunker, } if data_type in chunker_classes: chunker_class: type = chunker_classes[data_type] diff --git a/embedchain/loaders/xml.py b/embedchain/loaders/xml.py new file mode 100644 index 00000000..324a8c71 --- /dev/null +++ b/embedchain/loaders/xml.py @@ -0,0 +1,26 @@ +import hashlib + +from langchain.document_loaders import UnstructuredXMLLoader + +from embedchain.helper.json_serializable import register_deserializable +from embedchain.loaders.base_loader import BaseLoader +from embedchain.utils import clean_string + + +@register_deserializable +class XmlLoader(BaseLoader): + def load_data(self, xml_url): + """Load data from a XML file.""" + loader = UnstructuredXMLLoader(xml_url) + data = loader.load() + content = data[0].page_content + content = clean_string(content) + meta_data = data[0].metadata + meta_data["url"] = meta_data["source"] + del meta_data["source"] + output = [{"content": content, "meta_data": meta_data}] + doc_id = hashlib.sha256((content + xml_url).encode()).hexdigest() + return { + "doc_id": doc_id, + "data": output, + } diff --git a/embedchain/models/data_type.py b/embedchain/models/data_type.py index 90d7dd91..566fe657 100644 --- a/embedchain/models/data_type.py +++ b/embedchain/models/data_type.py @@ -18,6 +18,7 @@ class IndirectDataType(Enum): PDF_FILE = "pdf_file" WEB_PAGE = "web_page" SITEMAP = "sitemap" + XML = "xml" DOCX = "docx" DOCS_SITE = "docs_site" NOTION = "notion" @@ -40,6 +41,7 @@ class DataType(Enum): PDF_FILE = IndirectDataType.PDF_FILE.value WEB_PAGE = IndirectDataType.WEB_PAGE.value SITEMAP = IndirectDataType.SITEMAP.value + XML = IndirectDataType.XML.value DOCX = IndirectDataType.DOCX.value DOCS_SITE = IndirectDataType.DOCS_SITE.value NOTION = IndirectDataType.NOTION.value diff --git a/embedchain/utils.py b/embedchain/utils.py index 43c5ef69..9b5709c3 100644 --- a/embedchain/utils.py +++ b/embedchain/utils.py @@ -128,8 +128,7 @@ def detect_datatype(source: Any) -> DataType: formatted_source = format_source(str(source), 30) if url: - from langchain.document_loaders.youtube import \ - ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS + from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS if url.netloc in YOUTUBE_ALLOWED_NETLOCS: logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.") @@ -190,6 +189,10 @@ def detect_datatype(source: Any) -> DataType: logging.debug(f"Source of `{formatted_source}` detected as `csv`.") return DataType.CSV + if source.endswith(".xml"): + logging.debug(f"Source of `{formatted_source}` detected as `xml`.") + return DataType.XML + # If the source is a valid file, that's not detectable as a type, an error is raised. # It does not fallback to text. raise ValueError( diff --git a/pyproject.toml b/pyproject.toml index dbc960b0..758d3705 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,6 +106,7 @@ fastapi-poe = { version = "0.0.16", optional = true } discord = { version = "^2.3.2", optional = true } slack-sdk = { version = "3.21.3", optional = true } docx2txt = "^0.8" +unstructured = {extras = ["local-inference"], version = "^0.10.18"} pillow = { version = "10.0.1", optional = true } torchvision = { version = ">=0.15.1, !=0.15.2", optional = true } ftfy = { version = "6.1.1", optional = true } diff --git a/tests/loaders/test_xml.py b/tests/loaders/test_xml.py new file mode 100644 index 00000000..d1ff5daa --- /dev/null +++ b/tests/loaders/test_xml.py @@ -0,0 +1,62 @@ +import tempfile + +import pytest + +from embedchain.loaders.xml import XmlLoader + +# Taken from https://github.com/langchain-ai/langchain/blob/master/libs/langchain/tests/integration_tests/examples/factbook.xml +SAMPLE_XML = """ + + + United States + Washington, DC + Joe Biden + Baseball + + + Canada + Ottawa + Justin Trudeau + Hockey + + + France + Paris + Emmanuel Macron + Soccer + + + Trinidad & Tobado + Port of Spain + Keith Rowley + Track & Field + +""" + + +@pytest.mark.parametrize("xml", [SAMPLE_XML]) +def test_load_data(xml: str): + """ + Test XML loader + + Tests that XML file is loaded, metadata is correct and content is correct + """ + # Creating temporary XML file + with tempfile.NamedTemporaryFile(mode="w+") as tmpfile: + tmpfile.write(xml) + + tmpfile.seek(0) + filename = tmpfile.name + + # Loading CSV using XmlLoader + loader = XmlLoader() + result = loader.load_data(filename) + data = result["data"] + + # Assertions + assert len(data) == 1 + assert "United States Washington, DC Joe Biden" in data[0]["content"] + assert "Canada Ottawa Justin Trudeau" in data[0]["content"] + assert "France Paris Emmanuel Macron" in data[0]["content"] + assert "Trinidad & Tobado Port of Spain Keith Rowley" in data[0]["content"] + assert data[0]["meta_data"]["url"] == filename