Improve tests (#780)

This commit is contained in:
Sidharth Mohanty
2023-10-09 23:56:21 +05:30
committed by GitHub
parent ed02aebf9a
commit b91d922600
12 changed files with 621 additions and 15 deletions

View File

@@ -15,7 +15,25 @@ class WebPageLoader(BaseLoader):
"""Load data from a web page."""
response = requests.get(url)
data = response.content
soup = BeautifulSoup(data, "html.parser")
content = self._get_clean_content(data, url)
meta_data = {
"url": url,
}
doc_id = hashlib.sha256((content + url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": [
{
"content": content,
"meta_data": meta_data,
}
],
}
def _get_clean_content(self, html, url) -> str:
soup = BeautifulSoup(html, "html.parser")
original_size = len(str(soup.get_text()))
tags_to_exclude = [
@@ -61,17 +79,4 @@ class WebPageLoader(BaseLoader):
f"[{url}] Cleaned page size: {cleaned_size} characters, down from {original_size} (shrunk: {original_size-cleaned_size} chars, {round((1-(cleaned_size/original_size)) * 100, 2)}%)" # noqa:E501
)
meta_data = {
"url": url,
}
content = content
doc_id = hashlib.sha256((content + url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": [
{
"content": content,
"meta_data": meta_data,
}
],
}
return content

View File

@@ -0,0 +1,84 @@
import hashlib
import pytest
from unittest.mock import MagicMock
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.models.data_type import DataType
@pytest.fixture
def text_splitter_mock():
return MagicMock()
@pytest.fixture
def loader_mock():
return MagicMock()
@pytest.fixture
def app_id():
return "test_app"
@pytest.fixture
def data_type():
return DataType.TEXT
@pytest.fixture
def chunker(text_splitter_mock, data_type):
text_splitter = text_splitter_mock
chunker = BaseChunker(text_splitter)
chunker.set_data_type(data_type)
return chunker
def test_create_chunks(chunker, text_splitter_mock, loader_mock, app_id, data_type):
text_splitter_mock.split_text.return_value = ["Chunk 1", "Chunk 2"]
loader_mock.load_data.return_value = {
"data": [{"content": "Content 1", "meta_data": {"url": "URL 1"}}],
"doc_id": "DocID",
}
result = chunker.create_chunks(loader_mock, "test_src", app_id)
expected_ids = [
hashlib.sha256(("Chunk 1" + "URL 1").encode()).hexdigest(),
hashlib.sha256(("Chunk 2" + "URL 1").encode()).hexdigest(),
]
assert result["documents"] == ["Chunk 1", "Chunk 2"]
assert result["ids"] == expected_ids
assert result["metadatas"] == [
{
"url": "URL 1",
"data_type": data_type.value,
"doc_id": f"{app_id}--DocID",
},
{
"url": "URL 1",
"data_type": data_type.value,
"doc_id": f"{app_id}--DocID",
},
]
assert result["doc_id"] == f"{app_id}--DocID"
def test_get_chunks(chunker, text_splitter_mock):
text_splitter_mock.split_text.return_value = ["Chunk 1", "Chunk 2"]
content = "This is a test content."
result = chunker.get_chunks(content)
assert len(result) == 2
assert result == ["Chunk 1", "Chunk 2"]
def test_set_data_type(chunker):
chunker.set_data_type(DataType.MDX)
assert chunker.data_type == DataType.MDX
def test_get_word_count(chunker):
documents = ["This is a test.", "Another test."]
result = chunker.get_word_count(documents)
assert result == 6

View File

@@ -0,0 +1,46 @@
from embedchain.chunkers.docs_site import DocsSiteChunker
from embedchain.chunkers.docx_file import DocxFileChunker
from embedchain.chunkers.mdx import MdxChunker
from embedchain.chunkers.notion import NotionChunker
from embedchain.chunkers.pdf_file import PdfFileChunker
from embedchain.chunkers.qna_pair import QnaPairChunker
from embedchain.chunkers.sitemap import SitemapChunker
from embedchain.chunkers.table import TableChunker
from embedchain.chunkers.text import TextChunker
from embedchain.chunkers.web_page import WebPageChunker
from embedchain.chunkers.xml import XmlChunker
from embedchain.chunkers.youtube_video import YoutubeVideoChunker
from embedchain.config.add_config import ChunkerConfig
chunker_config = ChunkerConfig(chunk_size=500, chunk_overlap=0, length_function=len)
chunker_common_config = {
DocsSiteChunker: {"chunk_size": 500, "chunk_overlap": 50, "length_function": len},
DocxFileChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
PdfFileChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
TextChunker: {"chunk_size": 300, "chunk_overlap": 0, "length_function": len},
MdxChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
NotionChunker: {"chunk_size": 300, "chunk_overlap": 0, "length_function": len},
QnaPairChunker: {"chunk_size": 300, "chunk_overlap": 0, "length_function": len},
TableChunker: {"chunk_size": 300, "chunk_overlap": 0, "length_function": len},
SitemapChunker: {"chunk_size": 500, "chunk_overlap": 0, "length_function": len},
WebPageChunker: {"chunk_size": 500, "chunk_overlap": 0, "length_function": len},
XmlChunker: {"chunk_size": 500, "chunk_overlap": 50, "length_function": len},
YoutubeVideoChunker: {"chunk_size": 2000, "chunk_overlap": 0, "length_function": len},
}
def test_default_config_values():
for chunker_class, config in chunker_common_config.items():
chunker = chunker_class()
assert chunker.text_splitter._chunk_size == config["chunk_size"]
assert chunker.text_splitter._chunk_overlap == config["chunk_overlap"]
assert chunker.text_splitter._length_function == config["length_function"]
def test_custom_config_values():
for chunker_class, _ in chunker_common_config.items():
chunker = chunker_class(config=chunker_config)
assert chunker.text_splitter._chunk_size == 500
assert chunker.text_splitter._chunk_overlap == 0
assert chunker.text_splitter._length_function == len

View File

@@ -2,6 +2,7 @@ import csv
import os
import pathlib
import tempfile
from unittest.mock import MagicMock, patch
import pytest
@@ -84,3 +85,29 @@ def test_load_data_with_file_uri(delimiter):
# Cleaning up the temporary file
os.unlink(tmpfile.name)
@pytest.mark.parametrize("content", ["ftp://example.com", "sftp://example.com", "mailto://example.com"])
def test_get_file_content(content):
with pytest.raises(ValueError):
loader = CsvLoader()
loader._get_file_content(content)
@pytest.mark.parametrize("content", ["http://example.com", "https://example.com"])
def test_get_file_content_http(content):
"""
Test _get_file_content method of CsvLoader for http and https URLs
"""
with patch("requests.get") as mock_get:
mock_response = MagicMock()
mock_response.text = "Name,Age,Occupation\nAlice,28,Engineer\nBob,35,Doctor\nCharlie,22,Student"
mock_get.return_value = mock_response
loader = CsvLoader()
file_content = loader._get_file_content(content)
mock_get.assert_called_once_with(content)
mock_response.raise_for_status.assert_called_once()
assert file_content.read() == mock_response.text

View File

@@ -0,0 +1,128 @@
import hashlib
import pytest
from unittest.mock import Mock, patch
from requests import Response
from embedchain.loaders.docs_site_loader import DocsSiteLoader
@pytest.fixture
def mock_requests_get():
with patch("requests.get") as mock_get:
yield mock_get
@pytest.fixture
def docs_site_loader():
return DocsSiteLoader()
def test_get_child_links_recursive(mock_requests_get, docs_site_loader):
mock_response = Mock()
mock_response.status_code = 200
mock_response.text = """
<html>
<a href="/page1">Page 1</a>
<a href="/page2">Page 2</a>
</html>
"""
mock_requests_get.return_value = mock_response
docs_site_loader._get_child_links_recursive("https://example.com")
assert len(docs_site_loader.visited_links) == 2
assert "https://example.com/page1" in docs_site_loader.visited_links
assert "https://example.com/page2" in docs_site_loader.visited_links
def test_get_child_links_recursive_status_not_200(mock_requests_get, docs_site_loader):
mock_response = Mock()
mock_response.status_code = 404
mock_requests_get.return_value = mock_response
docs_site_loader._get_child_links_recursive("https://example.com")
assert len(docs_site_loader.visited_links) == 0
def test_get_all_urls(mock_requests_get, docs_site_loader):
mock_response = Mock()
mock_response.status_code = 200
mock_response.text = """
<html>
<a href="/page1">Page 1</a>
<a href="/page2">Page 2</a>
<a href="https://example.com/external">External</a>
</html>
"""
mock_requests_get.return_value = mock_response
all_urls = docs_site_loader._get_all_urls("https://example.com")
assert len(all_urls) == 3
assert "https://example.com/page1" in all_urls
assert "https://example.com/page2" in all_urls
assert "https://example.com/external" in all_urls
def test_load_data_from_url(mock_requests_get, docs_site_loader):
mock_response = Mock()
mock_response.status_code = 200
mock_response.content = """
<html>
<nav>
<h1>Navigation</h1>
</nav>
<article class="bd-article">
<p>Article Content</p>
</article>
</html>
""".encode()
mock_requests_get.return_value = mock_response
data = docs_site_loader._load_data_from_url("https://example.com/page1")
assert len(data) == 1
assert data[0]["content"] == "Article Content"
assert data[0]["meta_data"]["url"] == "https://example.com/page1"
def test_load_data_from_url_status_not_200(mock_requests_get, docs_site_loader):
mock_response = Mock()
mock_response.status_code = 404
mock_requests_get.return_value = mock_response
data = docs_site_loader._load_data_from_url("https://example.com/page1")
assert data == []
assert len(data) == 0
def test_load_data(mock_requests_get, docs_site_loader):
mock_response = Response()
mock_response.status_code = 200
mock_response._content = """
<html>
<a href="/page1">Page 1</a>
<a href="/page2">Page 2</a>
""".encode()
mock_requests_get.return_value = mock_response
url = "https://example.com"
data = docs_site_loader.load_data(url)
expected_doc_id = hashlib.sha256((" ".join(docs_site_loader.visited_links) + url).encode()).hexdigest()
assert len(data["data"]) == 2
assert data["doc_id"] == expected_doc_id
def test_if_response_status_not_200(mock_requests_get, docs_site_loader):
mock_response = Response()
mock_response.status_code = 404
mock_requests_get.return_value = mock_response
url = "https://example.com"
data = docs_site_loader.load_data(url)
expected_doc_id = hashlib.sha256((" ".join(docs_site_loader.visited_links) + url).encode()).hexdigest()
assert len(data["data"]) == 0
assert data["doc_id"] == expected_doc_id

View File

@@ -0,0 +1,37 @@
import hashlib
import pytest
from unittest.mock import MagicMock, patch
from embedchain.loaders.docx_file import DocxFileLoader
@pytest.fixture
def mock_docx2txt_loader():
with patch("embedchain.loaders.docx_file.Docx2txtLoader") as mock_loader:
yield mock_loader
@pytest.fixture
def docx_file_loader():
return DocxFileLoader()
def test_load_data(mock_docx2txt_loader, docx_file_loader):
mock_url = "mock_docx_file.docx"
mock_loader = MagicMock()
mock_loader.load.return_value = [MagicMock(page_content="Sample Docx Content", metadata={"url": "local"})]
mock_docx2txt_loader.return_value = mock_loader
result = docx_file_loader.load_data(mock_url)
assert "doc_id" in result
assert "data" in result
expected_content = "Sample Docx Content"
assert result["data"][0]["content"] == expected_content
assert result["data"][0]["meta_data"]["url"] == "local"
expected_doc_id = hashlib.sha256((expected_content + mock_url).encode()).hexdigest()
assert result["doc_id"] == expected_doc_id

View File

@@ -0,0 +1,30 @@
import hashlib
import pytest
from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
@pytest.fixture
def qna_pair_loader():
return LocalQnaPairLoader()
def test_load_data(qna_pair_loader):
question = "What is the capital of France?"
answer = "The capital of France is Paris."
content = (question, answer)
result = qna_pair_loader.load_data(content)
assert "doc_id" in result
assert "data" in result
url = "local"
expected_content = f"Q: {question}\nA: {answer}"
assert result["data"][0]["content"] == expected_content
assert result["data"][0]["meta_data"]["url"] == url
assert result["data"][0]["meta_data"]["question"] == question
expected_doc_id = hashlib.sha256((expected_content + url).encode()).hexdigest()
assert result["doc_id"] == expected_doc_id

View File

@@ -0,0 +1,25 @@
import hashlib
import pytest
from embedchain.loaders.local_text import LocalTextLoader
@pytest.fixture
def text_loader():
return LocalTextLoader()
def test_load_data(text_loader):
mock_content = "This is a sample text content."
result = text_loader.load_data(mock_content)
assert "doc_id" in result
assert "data" in result
url = "local"
assert result["data"][0]["content"] == mock_content
assert result["data"][0]["meta_data"]["url"] == url
expected_doc_id = hashlib.sha256((mock_content + url).encode()).hexdigest()
assert result["doc_id"] == expected_doc_id

28
tests/loaders/test_mdx.py Normal file
View File

@@ -0,0 +1,28 @@
import hashlib
import pytest
from unittest.mock import patch, mock_open
from embedchain.loaders.mdx import MdxLoader
@pytest.fixture
def mdx_loader():
return MdxLoader()
def test_load_data(mdx_loader):
mock_content = "Sample MDX Content"
# Mock open function to simulate file reading
with patch("builtins.open", mock_open(read_data=mock_content)):
url = "mock_file.mdx"
result = mdx_loader.load_data(url)
assert "doc_id" in result
assert "data" in result
assert result["data"][0]["content"] == mock_content
assert result["data"][0]["meta_data"]["url"] == url
expected_doc_id = hashlib.sha256((mock_content + url).encode()).hexdigest()
assert result["doc_id"] == expected_doc_id

View File

@@ -0,0 +1,34 @@
import hashlib
import os
import pytest
from unittest.mock import Mock, patch
from embedchain.loaders.notion import NotionLoader
@pytest.fixture
def notion_loader():
with patch.dict(os.environ, {"NOTION_INTEGRATION_TOKEN": "test_notion_token"}):
yield NotionLoader()
def test_load_data(notion_loader):
source = "https://www.notion.so/Test-Page-1234567890abcdef1234567890abcdef"
mock_text = "This is a test page."
expected_doc_id = hashlib.sha256((mock_text + source).encode()).hexdigest()
expected_data = [
{
"content": mock_text,
"meta_data": {"url": "notion-12345678-90ab-cdef-1234-567890abcdef"}, # formatted_id
}
]
mock_page = Mock()
mock_page.text = mock_text
mock_documents = [mock_page]
with patch("embedchain.loaders.notion.NotionPageReader") as mock_reader:
mock_reader.return_value.load_data.return_value = mock_documents
result = notion_loader.load_data(source)
assert result["doc_id"] == expected_doc_id
assert result["data"] == expected_data

View File

@@ -0,0 +1,115 @@
import hashlib
import pytest
from unittest.mock import Mock, patch
from embedchain.loaders.web_page import WebPageLoader
@pytest.fixture
def web_page_loader():
return WebPageLoader()
def test_load_data(web_page_loader):
page_url = "https://example.com/page"
mock_response = Mock()
mock_response.status_code = 200
mock_response.content = """
<html>
<head>
<title>Test Page</title>
</head>
<body>
<div id="content">
<p>This is some test content.</p>
</div>
</body>
</html>
"""
with patch("embedchain.loaders.web_page.requests.get", return_value=mock_response):
result = web_page_loader.load_data(page_url)
content = web_page_loader._get_clean_content(mock_response.content, page_url)
expected_doc_id = hashlib.sha256((content + page_url).encode()).hexdigest()
assert result["doc_id"] == expected_doc_id
expected_data = [
{
"content": content,
"meta_data": {
"url": page_url,
},
}
]
assert result["data"] == expected_data
def test_get_clean_content_excludes_unnecessary_info(web_page_loader):
mock_html = """
<html>
<head>
<title>Sample HTML</title>
<style>
/* Stylesheet to be excluded */
.elementor-location-header {
background-color: #f0f0f0;
}
</style>
</head>
<body>
<header id="header">Header Content</header>
<nav class="nav">Nav Content</nav>
<aside>Aside Content</aside>
<form>Form Content</form>
<main>Main Content</main>
<footer class="footer">Footer Content</footer>
<script>Some Script</script>
<noscript>NoScript Content</noscript>
<svg>SVG Content</svg>
<canvas>Canvas Content</canvas>
<div id="sidebar">Sidebar Content</div>
<div id="main-navigation">Main Navigation Content</div>
<div id="menu-main-menu">Menu Main Menu Content</div>
<div class="header-sidebar-wrapper">Header Sidebar Wrapper Content</div>
<div class="blog-sidebar-wrapper">Blog Sidebar Wrapper Content</div>
<div class="related-posts">Related Posts Content</div>
</body>
</html>
"""
tags_to_exclude = [
"nav",
"aside",
"form",
"header",
"noscript",
"svg",
"canvas",
"footer",
"script",
"style",
]
ids_to_exclude = ["sidebar", "main-navigation", "menu-main-menu"]
classes_to_exclude = [
"elementor-location-header",
"navbar-header",
"nav",
"header-sidebar-wrapper",
"blog-sidebar-wrapper",
"related-posts",
]
content = web_page_loader._get_clean_content(mock_html, "https://example.com/page")
for tag in tags_to_exclude:
assert tag not in content
for id in ids_to_exclude:
assert id not in content
for class_name in classes_to_exclude:
assert class_name not in content
assert len(content) > 0

View File

@@ -0,0 +1,47 @@
import hashlib
import pytest
from unittest.mock import MagicMock, Mock, patch
from embedchain.loaders.youtube_video import YoutubeVideoLoader
@pytest.fixture
def youtube_video_loader():
return YoutubeVideoLoader()
def test_load_data(youtube_video_loader):
video_url = "https://www.youtube.com/watch?v=VIDEO_ID"
mock_loader = Mock()
mock_page_content = "This is a YouTube video content."
mock_loader.load.return_value = [
MagicMock(
page_content=mock_page_content,
metadata={"url": video_url, "title": "Test Video"},
)
]
with patch("embedchain.loaders.youtube_video.YoutubeLoader.from_youtube_url", return_value=mock_loader):
result = youtube_video_loader.load_data(video_url)
expected_doc_id = hashlib.sha256((mock_page_content + video_url).encode()).hexdigest()
assert result["doc_id"] == expected_doc_id
expected_data = [
{
"content": "This is a YouTube video content.",
"meta_data": {"url": video_url, "title": "Test Video"},
}
]
assert result["data"] == expected_data
def test_load_data_with_empty_doc(youtube_video_loader):
video_url = "https://www.youtube.com/watch?v=VIDEO_ID"
mock_loader = Mock()
mock_loader.load.return_value = []
with patch("embedchain.loaders.youtube_video.YoutubeLoader.from_youtube_url", return_value=mock_loader):
with pytest.raises(ValueError):
youtube_video_loader.load_data(video_url)