Improve tests (#780)

This commit is contained in:
Sidharth Mohanty
2023-10-09 23:56:21 +05:30
committed by GitHub
parent ed02aebf9a
commit b91d922600
12 changed files with 621 additions and 15 deletions

View File

@@ -0,0 +1,84 @@
import hashlib
import pytest
from unittest.mock import MagicMock
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.models.data_type import DataType
@pytest.fixture
def text_splitter_mock():
return MagicMock()
@pytest.fixture
def loader_mock():
return MagicMock()
@pytest.fixture
def app_id():
return "test_app"
@pytest.fixture
def data_type():
return DataType.TEXT
@pytest.fixture
def chunker(text_splitter_mock, data_type):
text_splitter = text_splitter_mock
chunker = BaseChunker(text_splitter)
chunker.set_data_type(data_type)
return chunker
def test_create_chunks(chunker, text_splitter_mock, loader_mock, app_id, data_type):
text_splitter_mock.split_text.return_value = ["Chunk 1", "Chunk 2"]
loader_mock.load_data.return_value = {
"data": [{"content": "Content 1", "meta_data": {"url": "URL 1"}}],
"doc_id": "DocID",
}
result = chunker.create_chunks(loader_mock, "test_src", app_id)
expected_ids = [
hashlib.sha256(("Chunk 1" + "URL 1").encode()).hexdigest(),
hashlib.sha256(("Chunk 2" + "URL 1").encode()).hexdigest(),
]
assert result["documents"] == ["Chunk 1", "Chunk 2"]
assert result["ids"] == expected_ids
assert result["metadatas"] == [
{
"url": "URL 1",
"data_type": data_type.value,
"doc_id": f"{app_id}--DocID",
},
{
"url": "URL 1",
"data_type": data_type.value,
"doc_id": f"{app_id}--DocID",
},
]
assert result["doc_id"] == f"{app_id}--DocID"
def test_get_chunks(chunker, text_splitter_mock):
text_splitter_mock.split_text.return_value = ["Chunk 1", "Chunk 2"]
content = "This is a test content."
result = chunker.get_chunks(content)
assert len(result) == 2
assert result == ["Chunk 1", "Chunk 2"]
def test_set_data_type(chunker):
chunker.set_data_type(DataType.MDX)
assert chunker.data_type == DataType.MDX
def test_get_word_count(chunker):
documents = ["This is a test.", "Another test."]
result = chunker.get_word_count(documents)
assert result == 6

View File

@@ -0,0 +1,46 @@
from embedchain.chunkers.docs_site import DocsSiteChunker
from embedchain.chunkers.docx_file import DocxFileChunker
from embedchain.chunkers.mdx import MdxChunker
from embedchain.chunkers.notion import NotionChunker
from embedchain.chunkers.pdf_file import PdfFileChunker
from embedchain.chunkers.qna_pair import QnaPairChunker
from embedchain.chunkers.sitemap import SitemapChunker
from embedchain.chunkers.table import TableChunker
from embedchain.chunkers.text import TextChunker
from embedchain.chunkers.web_page import WebPageChunker
from embedchain.chunkers.xml import XmlChunker
from embedchain.chunkers.youtube_video import YoutubeVideoChunker
from embedchain.config.add_config import ChunkerConfig
chunker_config = ChunkerConfig(chunk_size=500, chunk_overlap=0, length_function=len)
chunker_common_config = {
DocsSiteChunker: {"chunk_size": 500, "chunk_overlap": 50, "length_function": len},
DocxFileChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
PdfFileChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
TextChunker: {"chunk_size": 300, "chunk_overlap": 0, "length_function": len},
MdxChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
NotionChunker: {"chunk_size": 300, "chunk_overlap": 0, "length_function": len},
QnaPairChunker: {"chunk_size": 300, "chunk_overlap": 0, "length_function": len},
TableChunker: {"chunk_size": 300, "chunk_overlap": 0, "length_function": len},
SitemapChunker: {"chunk_size": 500, "chunk_overlap": 0, "length_function": len},
WebPageChunker: {"chunk_size": 500, "chunk_overlap": 0, "length_function": len},
XmlChunker: {"chunk_size": 500, "chunk_overlap": 50, "length_function": len},
YoutubeVideoChunker: {"chunk_size": 2000, "chunk_overlap": 0, "length_function": len},
}
def test_default_config_values():
for chunker_class, config in chunker_common_config.items():
chunker = chunker_class()
assert chunker.text_splitter._chunk_size == config["chunk_size"]
assert chunker.text_splitter._chunk_overlap == config["chunk_overlap"]
assert chunker.text_splitter._length_function == config["length_function"]
def test_custom_config_values():
for chunker_class, _ in chunker_common_config.items():
chunker = chunker_class(config=chunker_config)
assert chunker.text_splitter._chunk_size == 500
assert chunker.text_splitter._chunk_overlap == 0
assert chunker.text_splitter._length_function == len