From 1741d3bef68fadaae6c4bdb65396ca8d3e3bf339 Mon Sep 17 00:00:00 2001 From: Richard Awoyemi <35015261+richawo@users.noreply.github.com> Date: Sat, 7 Oct 2023 00:24:15 +0100 Subject: [PATCH] [fix]: Fix sitemap loader (#753) --- embedchain/chunkers/sitemap.py | 22 +++++++++++++++++++++ embedchain/data_formatter/data_formatter.py | 2 ++ embedchain/loaders/sitemap.py | 5 ++--- tests/embedchain/test_add.py | 8 ++++++++ 4 files changed, 34 insertions(+), 3 deletions(-) create mode 100644 embedchain/chunkers/sitemap.py diff --git a/embedchain/chunkers/sitemap.py b/embedchain/chunkers/sitemap.py new file mode 100644 index 00000000..64050011 --- /dev/null +++ b/embedchain/chunkers/sitemap.py @@ -0,0 +1,22 @@ +from typing import Optional + +from langchain.text_splitter import RecursiveCharacterTextSplitter + +from embedchain.chunkers.base_chunker import BaseChunker +from embedchain.config.add_config import ChunkerConfig +from embedchain.helper.json_serializable import register_deserializable + + +@register_deserializable +class SitemapChunker(BaseChunker): + """Chunker for sitemap.""" + + def __init__(self, config: Optional[ChunkerConfig] = None): + if config is None: + config = ChunkerConfig(chunk_size=500, chunk_overlap=0, length_function=len) + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=config.chunk_size, + chunk_overlap=config.chunk_overlap, + length_function=config.length_function, + ) + super().__init__(text_splitter) diff --git a/embedchain/data_formatter/data_formatter.py b/embedchain/data_formatter/data_formatter.py index 635f4ea7..0663541f 100644 --- a/embedchain/data_formatter/data_formatter.py +++ b/embedchain/data_formatter/data_formatter.py @@ -6,6 +6,7 @@ from embedchain.chunkers.mdx import MdxChunker from embedchain.chunkers.notion import NotionChunker from embedchain.chunkers.pdf_file import PdfFileChunker from embedchain.chunkers.qna_pair import QnaPairChunker +from embedchain.chunkers.sitemap import SitemapChunker from embedchain.chunkers.table import TableChunker from embedchain.chunkers.text import TextChunker from embedchain.chunkers.web_page import WebPageChunker @@ -109,6 +110,7 @@ class DataFormatter(JSONSerializable): DataType.TEXT: TextChunker, DataType.DOCX: DocxFileChunker, DataType.DOCS_SITE: DocsSiteChunker, + DataType.SITEMAP: SitemapChunker, DataType.NOTION: NotionChunker, DataType.CSV: TableChunker, DataType.MDX: MdxChunker, diff --git a/embedchain/loaders/sitemap.py b/embedchain/loaders/sitemap.py index d85e8829..fb3657fa 100644 --- a/embedchain/loaders/sitemap.py +++ b/embedchain/loaders/sitemap.py @@ -36,9 +36,8 @@ class SitemapLoader(BaseLoader): for link in links: try: each_load_data = web_page_loader.load_data(link) - - if is_readable(each_load_data[0].get("content")): - output.append(each_load_data) + if is_readable(each_load_data.get("data")[0].get("content")): + output.append(each_load_data.get("data")) else: logging.warning(f"Page is not readable (too many invalid characters): {link}") except ParserRejectedMarkup as e: diff --git a/tests/embedchain/test_add.py b/tests/embedchain/test_add.py index f8c20564..f63f60d4 100644 --- a/tests/embedchain/test_add.py +++ b/tests/embedchain/test_add.py @@ -27,6 +27,14 @@ class TestApp(unittest.TestCase): self.app.add("https://example.com", metadata={"meta": "meta-data"}) self.assertEqual(self.app.user_asks, [["https://example.com", "web_page", {"meta": "meta-data"}]]) + @patch("chromadb.api.models.Collection.Collection.add", MagicMock) + def test_add_sitemap(self): + """ + In addition to the test_add function, this test checks that sitemaps can be added with the correct data type. + """ + self.app.add("https://www.google.com/sitemap.xml", metadata={"meta": "meta-data"}) + self.assertEqual(self.app.user_asks, [["https://www.google.com/sitemap.xml", "sitemap", {"meta": "meta-data"}]]) + @patch("chromadb.api.models.Collection.Collection.add", MagicMock) def test_add_forced_type(self): """