From b91d92260032c57eed21b3502cab988726a1ca56 Mon Sep 17 00:00:00 2001 From: Sidharth Mohanty Date: Mon, 9 Oct 2023 23:56:21 +0530 Subject: [PATCH] Improve tests (#780) --- embedchain/loaders/web_page.py | 35 ++++---- tests/chunkers/test_base_chunker.py | 84 ++++++++++++++++++ tests/chunkers/test_chunkers.py | 46 ++++++++++ tests/loaders/test_csv.py | 27 ++++++ tests/loaders/test_docs_site.py | 128 +++++++++++++++++++++++++++ tests/loaders/test_docx_file.py | 37 ++++++++ tests/loaders/test_local_qna_pair.py | 30 +++++++ tests/loaders/test_local_text.py | 25 ++++++ tests/loaders/test_mdx.py | 28 ++++++ tests/loaders/test_notion.py | 34 +++++++ tests/loaders/test_web_page.py | 115 ++++++++++++++++++++++++ tests/loaders/test_youtube_video.py | 47 ++++++++++ 12 files changed, 621 insertions(+), 15 deletions(-) create mode 100644 tests/chunkers/test_base_chunker.py create mode 100644 tests/chunkers/test_chunkers.py create mode 100644 tests/loaders/test_docs_site.py create mode 100644 tests/loaders/test_docx_file.py create mode 100644 tests/loaders/test_local_qna_pair.py create mode 100644 tests/loaders/test_local_text.py create mode 100644 tests/loaders/test_mdx.py create mode 100644 tests/loaders/test_notion.py create mode 100644 tests/loaders/test_web_page.py create mode 100644 tests/loaders/test_youtube_video.py diff --git a/embedchain/loaders/web_page.py b/embedchain/loaders/web_page.py index 53d41df0..bf0d2416 100644 --- a/embedchain/loaders/web_page.py +++ b/embedchain/loaders/web_page.py @@ -15,7 +15,25 @@ class WebPageLoader(BaseLoader): """Load data from a web page.""" response = requests.get(url) data = response.content - soup = BeautifulSoup(data, "html.parser") + content = self._get_clean_content(data, url) + + meta_data = { + "url": url, + } + + doc_id = hashlib.sha256((content + url).encode()).hexdigest() + return { + "doc_id": doc_id, + "data": [ + { + "content": content, + "meta_data": meta_data, + } + ], + } + + def _get_clean_content(self, html, url) -> str: + soup = BeautifulSoup(html, "html.parser") original_size = len(str(soup.get_text())) tags_to_exclude = [ @@ -61,17 +79,4 @@ class WebPageLoader(BaseLoader): f"[{url}] Cleaned page size: {cleaned_size} characters, down from {original_size} (shrunk: {original_size-cleaned_size} chars, {round((1-(cleaned_size/original_size)) * 100, 2)}%)" # noqa:E501 ) - meta_data = { - "url": url, - } - content = content - doc_id = hashlib.sha256((content + url).encode()).hexdigest() - return { - "doc_id": doc_id, - "data": [ - { - "content": content, - "meta_data": meta_data, - } - ], - } + return content diff --git a/tests/chunkers/test_base_chunker.py b/tests/chunkers/test_base_chunker.py new file mode 100644 index 00000000..2a9deffb --- /dev/null +++ b/tests/chunkers/test_base_chunker.py @@ -0,0 +1,84 @@ +import hashlib +import pytest +from unittest.mock import MagicMock +from embedchain.chunkers.base_chunker import BaseChunker +from embedchain.models.data_type import DataType + + +@pytest.fixture +def text_splitter_mock(): + return MagicMock() + + +@pytest.fixture +def loader_mock(): + return MagicMock() + + +@pytest.fixture +def app_id(): + return "test_app" + + +@pytest.fixture +def data_type(): + return DataType.TEXT + + +@pytest.fixture +def chunker(text_splitter_mock, data_type): + text_splitter = text_splitter_mock + chunker = BaseChunker(text_splitter) + chunker.set_data_type(data_type) + return chunker + + +def test_create_chunks(chunker, text_splitter_mock, loader_mock, app_id, data_type): + text_splitter_mock.split_text.return_value = ["Chunk 1", "Chunk 2"] + loader_mock.load_data.return_value = { + "data": [{"content": "Content 1", "meta_data": {"url": "URL 1"}}], + "doc_id": "DocID", + } + + result = chunker.create_chunks(loader_mock, "test_src", app_id) + expected_ids = [ + hashlib.sha256(("Chunk 1" + "URL 1").encode()).hexdigest(), + hashlib.sha256(("Chunk 2" + "URL 1").encode()).hexdigest(), + ] + + assert result["documents"] == ["Chunk 1", "Chunk 2"] + assert result["ids"] == expected_ids + assert result["metadatas"] == [ + { + "url": "URL 1", + "data_type": data_type.value, + "doc_id": f"{app_id}--DocID", + }, + { + "url": "URL 1", + "data_type": data_type.value, + "doc_id": f"{app_id}--DocID", + }, + ] + assert result["doc_id"] == f"{app_id}--DocID" + + +def test_get_chunks(chunker, text_splitter_mock): + text_splitter_mock.split_text.return_value = ["Chunk 1", "Chunk 2"] + + content = "This is a test content." + result = chunker.get_chunks(content) + + assert len(result) == 2 + assert result == ["Chunk 1", "Chunk 2"] + + +def test_set_data_type(chunker): + chunker.set_data_type(DataType.MDX) + assert chunker.data_type == DataType.MDX + + +def test_get_word_count(chunker): + documents = ["This is a test.", "Another test."] + result = chunker.get_word_count(documents) + assert result == 6 diff --git a/tests/chunkers/test_chunkers.py b/tests/chunkers/test_chunkers.py new file mode 100644 index 00000000..b8c72adf --- /dev/null +++ b/tests/chunkers/test_chunkers.py @@ -0,0 +1,46 @@ +from embedchain.chunkers.docs_site import DocsSiteChunker +from embedchain.chunkers.docx_file import DocxFileChunker +from embedchain.chunkers.mdx import MdxChunker +from embedchain.chunkers.notion import NotionChunker +from embedchain.chunkers.pdf_file import PdfFileChunker +from embedchain.chunkers.qna_pair import QnaPairChunker +from embedchain.chunkers.sitemap import SitemapChunker +from embedchain.chunkers.table import TableChunker +from embedchain.chunkers.text import TextChunker +from embedchain.chunkers.web_page import WebPageChunker +from embedchain.chunkers.xml import XmlChunker +from embedchain.chunkers.youtube_video import YoutubeVideoChunker +from embedchain.config.add_config import ChunkerConfig + +chunker_config = ChunkerConfig(chunk_size=500, chunk_overlap=0, length_function=len) + +chunker_common_config = { + DocsSiteChunker: {"chunk_size": 500, "chunk_overlap": 50, "length_function": len}, + DocxFileChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len}, + PdfFileChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len}, + TextChunker: {"chunk_size": 300, "chunk_overlap": 0, "length_function": len}, + MdxChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len}, + NotionChunker: {"chunk_size": 300, "chunk_overlap": 0, "length_function": len}, + QnaPairChunker: {"chunk_size": 300, "chunk_overlap": 0, "length_function": len}, + TableChunker: {"chunk_size": 300, "chunk_overlap": 0, "length_function": len}, + SitemapChunker: {"chunk_size": 500, "chunk_overlap": 0, "length_function": len}, + WebPageChunker: {"chunk_size": 500, "chunk_overlap": 0, "length_function": len}, + XmlChunker: {"chunk_size": 500, "chunk_overlap": 50, "length_function": len}, + YoutubeVideoChunker: {"chunk_size": 2000, "chunk_overlap": 0, "length_function": len}, +} + + +def test_default_config_values(): + for chunker_class, config in chunker_common_config.items(): + chunker = chunker_class() + assert chunker.text_splitter._chunk_size == config["chunk_size"] + assert chunker.text_splitter._chunk_overlap == config["chunk_overlap"] + assert chunker.text_splitter._length_function == config["length_function"] + + +def test_custom_config_values(): + for chunker_class, _ in chunker_common_config.items(): + chunker = chunker_class(config=chunker_config) + assert chunker.text_splitter._chunk_size == 500 + assert chunker.text_splitter._chunk_overlap == 0 + assert chunker.text_splitter._length_function == len diff --git a/tests/loaders/test_csv.py b/tests/loaders/test_csv.py index 07f06ae8..9cdcff39 100644 --- a/tests/loaders/test_csv.py +++ b/tests/loaders/test_csv.py @@ -2,6 +2,7 @@ import csv import os import pathlib import tempfile +from unittest.mock import MagicMock, patch import pytest @@ -84,3 +85,29 @@ def test_load_data_with_file_uri(delimiter): # Cleaning up the temporary file os.unlink(tmpfile.name) + + +@pytest.mark.parametrize("content", ["ftp://example.com", "sftp://example.com", "mailto://example.com"]) +def test_get_file_content(content): + with pytest.raises(ValueError): + loader = CsvLoader() + loader._get_file_content(content) + + +@pytest.mark.parametrize("content", ["http://example.com", "https://example.com"]) +def test_get_file_content_http(content): + """ + Test _get_file_content method of CsvLoader for http and https URLs + """ + + with patch("requests.get") as mock_get: + mock_response = MagicMock() + mock_response.text = "Name,Age,Occupation\nAlice,28,Engineer\nBob,35,Doctor\nCharlie,22,Student" + mock_get.return_value = mock_response + + loader = CsvLoader() + file_content = loader._get_file_content(content) + + mock_get.assert_called_once_with(content) + mock_response.raise_for_status.assert_called_once() + assert file_content.read() == mock_response.text diff --git a/tests/loaders/test_docs_site.py b/tests/loaders/test_docs_site.py new file mode 100644 index 00000000..e27bd1bf --- /dev/null +++ b/tests/loaders/test_docs_site.py @@ -0,0 +1,128 @@ +import hashlib +import pytest +from unittest.mock import Mock, patch +from requests import Response +from embedchain.loaders.docs_site_loader import DocsSiteLoader + + +@pytest.fixture +def mock_requests_get(): + with patch("requests.get") as mock_get: + yield mock_get + + +@pytest.fixture +def docs_site_loader(): + return DocsSiteLoader() + + +def test_get_child_links_recursive(mock_requests_get, docs_site_loader): + mock_response = Mock() + mock_response.status_code = 200 + mock_response.text = """ + + Page 1 + Page 2 + + """ + mock_requests_get.return_value = mock_response + + docs_site_loader._get_child_links_recursive("https://example.com") + + assert len(docs_site_loader.visited_links) == 2 + assert "https://example.com/page1" in docs_site_loader.visited_links + assert "https://example.com/page2" in docs_site_loader.visited_links + + +def test_get_child_links_recursive_status_not_200(mock_requests_get, docs_site_loader): + mock_response = Mock() + mock_response.status_code = 404 + mock_requests_get.return_value = mock_response + + docs_site_loader._get_child_links_recursive("https://example.com") + + assert len(docs_site_loader.visited_links) == 0 + + +def test_get_all_urls(mock_requests_get, docs_site_loader): + mock_response = Mock() + mock_response.status_code = 200 + mock_response.text = """ + + Page 1 + Page 2 + External + + """ + mock_requests_get.return_value = mock_response + + all_urls = docs_site_loader._get_all_urls("https://example.com") + + assert len(all_urls) == 3 + assert "https://example.com/page1" in all_urls + assert "https://example.com/page2" in all_urls + assert "https://example.com/external" in all_urls + + +def test_load_data_from_url(mock_requests_get, docs_site_loader): + mock_response = Mock() + mock_response.status_code = 200 + mock_response.content = """ + + +
+

Article Content

+
+ + """.encode() + mock_requests_get.return_value = mock_response + + data = docs_site_loader._load_data_from_url("https://example.com/page1") + + assert len(data) == 1 + assert data[0]["content"] == "Article Content" + assert data[0]["meta_data"]["url"] == "https://example.com/page1" + + +def test_load_data_from_url_status_not_200(mock_requests_get, docs_site_loader): + mock_response = Mock() + mock_response.status_code = 404 + mock_requests_get.return_value = mock_response + + data = docs_site_loader._load_data_from_url("https://example.com/page1") + + assert data == [] + assert len(data) == 0 + + +def test_load_data(mock_requests_get, docs_site_loader): + mock_response = Response() + mock_response.status_code = 200 + mock_response._content = """ + + Page 1 + Page 2 + """.encode() + mock_requests_get.return_value = mock_response + + url = "https://example.com" + data = docs_site_loader.load_data(url) + expected_doc_id = hashlib.sha256((" ".join(docs_site_loader.visited_links) + url).encode()).hexdigest() + + assert len(data["data"]) == 2 + assert data["doc_id"] == expected_doc_id + + +def test_if_response_status_not_200(mock_requests_get, docs_site_loader): + mock_response = Response() + mock_response.status_code = 404 + mock_requests_get.return_value = mock_response + + url = "https://example.com" + data = docs_site_loader.load_data(url) + expected_doc_id = hashlib.sha256((" ".join(docs_site_loader.visited_links) + url).encode()).hexdigest() + + assert len(data["data"]) == 0 + assert data["doc_id"] == expected_doc_id diff --git a/tests/loaders/test_docx_file.py b/tests/loaders/test_docx_file.py new file mode 100644 index 00000000..6b3bb193 --- /dev/null +++ b/tests/loaders/test_docx_file.py @@ -0,0 +1,37 @@ +import hashlib +import pytest +from unittest.mock import MagicMock, patch +from embedchain.loaders.docx_file import DocxFileLoader + + +@pytest.fixture +def mock_docx2txt_loader(): + with patch("embedchain.loaders.docx_file.Docx2txtLoader") as mock_loader: + yield mock_loader + + +@pytest.fixture +def docx_file_loader(): + return DocxFileLoader() + + +def test_load_data(mock_docx2txt_loader, docx_file_loader): + mock_url = "mock_docx_file.docx" + + mock_loader = MagicMock() + mock_loader.load.return_value = [MagicMock(page_content="Sample Docx Content", metadata={"url": "local"})] + + mock_docx2txt_loader.return_value = mock_loader + + result = docx_file_loader.load_data(mock_url) + + assert "doc_id" in result + assert "data" in result + + expected_content = "Sample Docx Content" + assert result["data"][0]["content"] == expected_content + + assert result["data"][0]["meta_data"]["url"] == "local" + + expected_doc_id = hashlib.sha256((expected_content + mock_url).encode()).hexdigest() + assert result["doc_id"] == expected_doc_id diff --git a/tests/loaders/test_local_qna_pair.py b/tests/loaders/test_local_qna_pair.py new file mode 100644 index 00000000..29447d19 --- /dev/null +++ b/tests/loaders/test_local_qna_pair.py @@ -0,0 +1,30 @@ +import hashlib +import pytest +from embedchain.loaders.local_qna_pair import LocalQnaPairLoader + + +@pytest.fixture +def qna_pair_loader(): + return LocalQnaPairLoader() + + +def test_load_data(qna_pair_loader): + question = "What is the capital of France?" + answer = "The capital of France is Paris." + + content = (question, answer) + result = qna_pair_loader.load_data(content) + + assert "doc_id" in result + assert "data" in result + url = "local" + + expected_content = f"Q: {question}\nA: {answer}" + assert result["data"][0]["content"] == expected_content + + assert result["data"][0]["meta_data"]["url"] == url + + assert result["data"][0]["meta_data"]["question"] == question + + expected_doc_id = hashlib.sha256((expected_content + url).encode()).hexdigest() + assert result["doc_id"] == expected_doc_id diff --git a/tests/loaders/test_local_text.py b/tests/loaders/test_local_text.py new file mode 100644 index 00000000..7d350ea5 --- /dev/null +++ b/tests/loaders/test_local_text.py @@ -0,0 +1,25 @@ +import hashlib +import pytest +from embedchain.loaders.local_text import LocalTextLoader + + +@pytest.fixture +def text_loader(): + return LocalTextLoader() + + +def test_load_data(text_loader): + mock_content = "This is a sample text content." + + result = text_loader.load_data(mock_content) + + assert "doc_id" in result + assert "data" in result + + url = "local" + assert result["data"][0]["content"] == mock_content + + assert result["data"][0]["meta_data"]["url"] == url + + expected_doc_id = hashlib.sha256((mock_content + url).encode()).hexdigest() + assert result["doc_id"] == expected_doc_id diff --git a/tests/loaders/test_mdx.py b/tests/loaders/test_mdx.py new file mode 100644 index 00000000..960eb6e5 --- /dev/null +++ b/tests/loaders/test_mdx.py @@ -0,0 +1,28 @@ +import hashlib +import pytest +from unittest.mock import patch, mock_open +from embedchain.loaders.mdx import MdxLoader + + +@pytest.fixture +def mdx_loader(): + return MdxLoader() + + +def test_load_data(mdx_loader): + mock_content = "Sample MDX Content" + + # Mock open function to simulate file reading + with patch("builtins.open", mock_open(read_data=mock_content)): + url = "mock_file.mdx" + result = mdx_loader.load_data(url) + + assert "doc_id" in result + assert "data" in result + + assert result["data"][0]["content"] == mock_content + + assert result["data"][0]["meta_data"]["url"] == url + + expected_doc_id = hashlib.sha256((mock_content + url).encode()).hexdigest() + assert result["doc_id"] == expected_doc_id diff --git a/tests/loaders/test_notion.py b/tests/loaders/test_notion.py new file mode 100644 index 00000000..f4d30327 --- /dev/null +++ b/tests/loaders/test_notion.py @@ -0,0 +1,34 @@ +import hashlib +import os +import pytest +from unittest.mock import Mock, patch +from embedchain.loaders.notion import NotionLoader + + +@pytest.fixture +def notion_loader(): + with patch.dict(os.environ, {"NOTION_INTEGRATION_TOKEN": "test_notion_token"}): + yield NotionLoader() + + +def test_load_data(notion_loader): + source = "https://www.notion.so/Test-Page-1234567890abcdef1234567890abcdef" + mock_text = "This is a test page." + expected_doc_id = hashlib.sha256((mock_text + source).encode()).hexdigest() + expected_data = [ + { + "content": mock_text, + "meta_data": {"url": "notion-12345678-90ab-cdef-1234-567890abcdef"}, # formatted_id + } + ] + + mock_page = Mock() + mock_page.text = mock_text + mock_documents = [mock_page] + + with patch("embedchain.loaders.notion.NotionPageReader") as mock_reader: + mock_reader.return_value.load_data.return_value = mock_documents + result = notion_loader.load_data(source) + + assert result["doc_id"] == expected_doc_id + assert result["data"] == expected_data diff --git a/tests/loaders/test_web_page.py b/tests/loaders/test_web_page.py new file mode 100644 index 00000000..61b9031b --- /dev/null +++ b/tests/loaders/test_web_page.py @@ -0,0 +1,115 @@ +import hashlib +import pytest +from unittest.mock import Mock, patch +from embedchain.loaders.web_page import WebPageLoader + + +@pytest.fixture +def web_page_loader(): + return WebPageLoader() + + +def test_load_data(web_page_loader): + page_url = "https://example.com/page" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.content = """ + + + Test Page + + +
+

This is some test content.

+
+ + + """ + with patch("embedchain.loaders.web_page.requests.get", return_value=mock_response): + result = web_page_loader.load_data(page_url) + + content = web_page_loader._get_clean_content(mock_response.content, page_url) + expected_doc_id = hashlib.sha256((content + page_url).encode()).hexdigest() + assert result["doc_id"] == expected_doc_id + + expected_data = [ + { + "content": content, + "meta_data": { + "url": page_url, + }, + } + ] + + assert result["data"] == expected_data + + +def test_get_clean_content_excludes_unnecessary_info(web_page_loader): + mock_html = """ + + + Sample HTML + + + + + + +
Form Content
+
Main Content
+ + + + SVG Content + Canvas Content + + + + + +
Header Sidebar Wrapper Content
+
Blog Sidebar Wrapper Content
+ + + + """ + + tags_to_exclude = [ + "nav", + "aside", + "form", + "header", + "noscript", + "svg", + "canvas", + "footer", + "script", + "style", + ] + ids_to_exclude = ["sidebar", "main-navigation", "menu-main-menu"] + classes_to_exclude = [ + "elementor-location-header", + "navbar-header", + "nav", + "header-sidebar-wrapper", + "blog-sidebar-wrapper", + "related-posts", + ] + + content = web_page_loader._get_clean_content(mock_html, "https://example.com/page") + + for tag in tags_to_exclude: + assert tag not in content + + for id in ids_to_exclude: + assert id not in content + + for class_name in classes_to_exclude: + assert class_name not in content + + assert len(content) > 0 diff --git a/tests/loaders/test_youtube_video.py b/tests/loaders/test_youtube_video.py new file mode 100644 index 00000000..cc70d779 --- /dev/null +++ b/tests/loaders/test_youtube_video.py @@ -0,0 +1,47 @@ +import hashlib +import pytest +from unittest.mock import MagicMock, Mock, patch +from embedchain.loaders.youtube_video import YoutubeVideoLoader + + +@pytest.fixture +def youtube_video_loader(): + return YoutubeVideoLoader() + + +def test_load_data(youtube_video_loader): + video_url = "https://www.youtube.com/watch?v=VIDEO_ID" + mock_loader = Mock() + mock_page_content = "This is a YouTube video content." + mock_loader.load.return_value = [ + MagicMock( + page_content=mock_page_content, + metadata={"url": video_url, "title": "Test Video"}, + ) + ] + + with patch("embedchain.loaders.youtube_video.YoutubeLoader.from_youtube_url", return_value=mock_loader): + result = youtube_video_loader.load_data(video_url) + + expected_doc_id = hashlib.sha256((mock_page_content + video_url).encode()).hexdigest() + + assert result["doc_id"] == expected_doc_id + + expected_data = [ + { + "content": "This is a YouTube video content.", + "meta_data": {"url": video_url, "title": "Test Video"}, + } + ] + + assert result["data"] == expected_data + + +def test_load_data_with_empty_doc(youtube_video_loader): + video_url = "https://www.youtube.com/watch?v=VIDEO_ID" + mock_loader = Mock() + mock_loader.load.return_value = [] + + with patch("embedchain.loaders.youtube_video.YoutubeLoader.from_youtube_url", return_value=mock_loader): + with pytest.raises(ValueError): + youtube_video_loader.load_data(video_url)