From c0ee68054655aa4c7852bbd6e454e4e3c62723d4 Mon Sep 17 00:00:00 2001 From: Deven Patel Date: Fri, 15 Dec 2023 05:59:15 +0530 Subject: [PATCH] [Improvement] Add support for min chunk size (#1007) --- README.md | 1 - docs/api-reference/advanced/configuration.mdx | 2 ++ embedchain/chunkers/base_chunker.py | 9 +++++++-- embedchain/chunkers/images.py | 5 ++++- embedchain/config/add_config.py | 20 ++++++++++++++----- embedchain/embedchain.py | 8 +++++--- pyproject.toml | 2 +- tests/chunkers/test_base_chunker.py | 13 ++++++++++++ tests/chunkers/test_image_chunker.py | 4 ++-- tests/chunkers/test_text.py | 18 ++++++++--------- tests/embedchain/test_add.py | 2 +- 11 files changed, 59 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index c2d6e660..5bc146ca 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,6 @@ elon_bot = App() # Embed online resources elon_bot.add("https://en.wikipedia.org/wiki/Elon_Musk") elon_bot.add("https://www.forbes.com/profile/elon-musk") -elon_bot.add("https://www.youtube.com/watch?v=RcYjXbSJBN8") # Query the bot elon_bot.query("How many companies does Elon Musk run and name those?") diff --git a/docs/api-reference/advanced/configuration.mdx b/docs/api-reference/advanced/configuration.mdx index 1fb3e2ff..fa847185 100644 --- a/docs/api-reference/advanced/configuration.mdx +++ b/docs/api-reference/advanced/configuration.mdx @@ -180,6 +180,8 @@ Alright, let's dive into what each key means in the yaml config above: - `chunk_size` (Integer): The size of each chunk of text that is sent to the language model. - `chunk_overlap` (Integer): The amount of overlap between each chunk of text. - `length_function` (String): The function used to calculate the length of each chunk of text. In this case, it's set to 'len'. You can also use any function import directly as a string here. + - `min_chunk_size` (Integer): The minimum size of each chunk of text that is sent to the language model. 
Must be less than `chunk_size`, and should be greater than `chunk_overlap`. + If you have questions about the configuration above, please feel free to reach out to us using one of the following methods: \ No newline at end of file diff --git a/embedchain/chunkers/base_chunker.py b/embedchain/chunkers/base_chunker.py index 76b38f38..933f97b4 100644 --- a/embedchain/chunkers/base_chunker.py +++ b/embedchain/chunkers/base_chunker.py @@ -1,5 +1,8 @@ import hashlib +import logging +from typing import Optional +from embedchain.config.add_config import ChunkerConfig from embedchain.helpers.json_serializable import JSONSerializable from embedchain.models.data_type import DataType @@ -10,7 +13,7 @@ class BaseChunker(JSONSerializable): self.text_splitter = text_splitter self.data_type = None - def create_chunks(self, loader, src, app_id=None): + def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig] = None): """ Loads data and chunks it. @@ -23,6 +26,8 @@ class BaseChunker(JSONSerializable): documents = [] chunk_ids = [] idMap = {} + min_chunk_size = config.min_chunk_size if config is not None else 1 + logging.info(f"[INFO] Skipping chunks smaller than {min_chunk_size} characters") data_result = loader.load_data(src) data_records = data_result["data"] doc_id = data_result["doc_id"] @@ -44,7 +49,7 @@ class BaseChunker(JSONSerializable): for chunk in chunks: chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest() chunk_id = f"{app_id}--{chunk_id}" if app_id is not None else chunk_id - if idMap.get(chunk_id) is None: + if idMap.get(chunk_id) is None and len(chunk) >= min_chunk_size: idMap[chunk_id] = True chunk_ids.append(chunk_id) documents.append(chunk) diff --git a/embedchain/chunkers/images.py b/embedchain/chunkers/images.py index 853e027a..8e0ac03d 100644 --- a/embedchain/chunkers/images.py +++ b/embedchain/chunkers/images.py @@ -1,4 +1,5 @@ import hashlib +import logging from typing import Optional from langchain.text_splitter import 
RecursiveCharacterTextSplitter @@ -20,7 +21,7 @@ class ImagesChunker(BaseChunker): ) super().__init__(image_splitter) - def create_chunks(self, loader, src, app_id=None): + def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig] = None): """ Loads the image(s), and creates their corresponding embedding. This creates one chunk for each image @@ -32,6 +33,8 @@ class ImagesChunker(BaseChunker): documents = [] embeddings = [] ids = [] + min_chunk_size = config.min_chunk_size if config is not None else 0 + logging.info(f"[INFO] Skipping chunks smaller than {min_chunk_size} characters") data_result = loader.load_data(src) data_records = data_result["data"] doc_id = data_result["doc_id"] diff --git a/embedchain/config/add_config.py b/embedchain/config/add_config.py index 66955118..c22960ca 100644 --- a/embedchain/config/add_config.py +++ b/embedchain/config/add_config.py @@ -1,4 +1,5 @@ import builtins +import logging from importlib import import_module from typing import Callable, Optional @@ -14,12 +15,21 @@ class ChunkerConfig(BaseConfig): def __init__( self, - chunk_size: Optional[int] = None, - chunk_overlap: Optional[int] = None, + chunk_size: Optional[int] = 2000, + chunk_overlap: Optional[int] = 0, length_function: Optional[Callable[[str], int]] = None, + min_chunk_size: Optional[int] = 0, ): - self.chunk_size = chunk_size if chunk_size else 2000 - self.chunk_overlap = chunk_overlap if chunk_overlap else 0 + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + self.min_chunk_size = min_chunk_size + if self.min_chunk_size >= self.chunk_size: + raise ValueError(f"min_chunk_size {min_chunk_size} should be less than chunk_size {chunk_size}") + if self.min_chunk_size <= self.chunk_overlap: + logging.warn( + f"min_chunk_size {min_chunk_size} should be greater than chunk_overlap {chunk_overlap}, otherwise it is redundant." 
# noqa:E501 + ) + if isinstance(length_function, str): self.length_function = self.load_func(length_function) else: @@ -37,7 +47,7 @@ class ChunkerConfig(BaseConfig): @register_deserializable class LoaderConfig(BaseConfig): """ - Config for the chunker used in `add` method + Config for the loader used in `add` method """ def __init__(self): diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index cc2982c5..0fa4c84b 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -196,7 +196,7 @@ class EmbedChain(JSONSerializable): data_formatter = DataFormatter(data_type, config, loader, chunker) documents, metadatas, _ids, new_chunks = self._load_and_embed( - data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run, **kwargs + data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, config, dry_run, **kwargs ) if data_type in {DataType.DOCS_SITE}: self.is_docs_site_instance = True @@ -339,6 +339,7 @@ class EmbedChain(JSONSerializable): src: Any, metadata: Optional[Dict[str, Any]] = None, source_hash: Optional[str] = None, + add_config: Optional[AddConfig] = None, dry_run=False, **kwargs: Optional[Dict[str, Any]], ): @@ -359,12 +360,13 @@ class EmbedChain(JSONSerializable): app_id = self.config.id if self.config is not None else None # Create chunks - embeddings_data = chunker.create_chunks(loader, src, app_id=app_id) + embeddings_data = chunker.create_chunks(loader, src, app_id=app_id, config=add_config.chunker) # spread chunking results documents = embeddings_data["documents"] metadatas = embeddings_data["metadatas"] ids = embeddings_data["ids"] new_doc_id = embeddings_data["doc_id"] + embeddings = embeddings_data.get("embeddings") if existing_doc_id and existing_doc_id == new_doc_id: print("Doc content has not changed. 
Skipping creating chunks and embeddings") return [], [], [], 0 @@ -429,7 +431,7 @@ class EmbedChain(JSONSerializable): chunks_before_addition = self.db.count() self.db.add( - embeddings=embeddings_data.get("embeddings", None), + embeddings=embeddings, documents=documents, metadatas=metadatas, ids=ids, diff --git a/pyproject.toml b/pyproject.toml index 91c68462..b3f850af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "embedchain" -version = "0.1.32" +version = "0.1.33" description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data" authors = [ "Taranjeet Singh ", diff --git a/tests/chunkers/test_base_chunker.py b/tests/chunkers/test_base_chunker.py index 343653ee..23cf1e8c 100644 --- a/tests/chunkers/test_base_chunker.py +++ b/tests/chunkers/test_base_chunker.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock import pytest from embedchain.chunkers.base_chunker import BaseChunker +from embedchain.config.add_config import ChunkerConfig from embedchain.models.data_type import DataType @@ -35,6 +36,18 @@ def chunker(text_splitter_mock, data_type): return chunker +def test_create_chunks_with_config(chunker, text_splitter_mock, loader_mock, app_id, data_type): + text_splitter_mock.split_text.return_value = ["Chunk 1", "long chunk"] + loader_mock.load_data.return_value = { + "data": [{"content": "Content 1", "meta_data": {"url": "URL 1"}}], + "doc_id": "DocID", + } + config = ChunkerConfig(chunk_size=50, chunk_overlap=0, length_function=len, min_chunk_size=10) + result = chunker.create_chunks(loader_mock, "test_src", app_id, config) + + assert result["documents"] == ["long chunk"] + + def test_create_chunks(chunker, text_splitter_mock, loader_mock, app_id, data_type): text_splitter_mock.split_text.return_value = ["Chunk 1", "Chunk 2"] loader_mock.load_data.return_value = { diff --git a/tests/chunkers/test_image_chunker.py b/tests/chunkers/test_image_chunker.py index eead2862..67f5e563 100644 --- 
a/tests/chunkers/test_image_chunker.py +++ b/tests/chunkers/test_image_chunker.py @@ -11,7 +11,7 @@ class TestImageChunker(unittest.TestCase): Test the chunks generated by TextChunker. # TODO: Not a very precise test. """ - chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len) + chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0) chunker = ImagesChunker(config=chunker_config) # Data type must be set manually in the test chunker.set_data_type(DataType.IMAGES) @@ -51,7 +51,7 @@ class TestImageChunker(unittest.TestCase): self.assertEqual(expected_chunks, result) def test_word_count(self): - chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len) + chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0) chunker = ImagesChunker(config=chunker_config) chunker.set_data_type(DataType.IMAGES) diff --git a/tests/chunkers/test_text.py b/tests/chunkers/test_text.py index 9eb73133..e4016cb4 100644 --- a/tests/chunkers/test_text.py +++ b/tests/chunkers/test_text.py @@ -10,12 +10,12 @@ class TestTextChunker: """ Test the chunks generated by TextChunker. """ - chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len) + chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len, min_chunk_size=0) chunker = TextChunker(config=chunker_config) text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit." 
# Data type must be set manually in the test chunker.set_data_type(DataType.TEXT) - result = chunker.create_chunks(MockLoader(), text) + result = chunker.create_chunks(MockLoader(), text, chunker_config) documents = result["documents"] assert len(documents) > 5 @@ -23,11 +23,11 @@ class TestTextChunker: """ Test the chunks generated by TextChunker with app_id """ - chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len) + chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len, min_chunk_size=0) chunker = TextChunker(config=chunker_config) text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit." chunker.set_data_type(DataType.TEXT) - result = chunker.create_chunks(MockLoader(), text) + result = chunker.create_chunks(MockLoader(), text, chunker_config) documents = result["documents"] assert len(documents) > 5 @@ -35,12 +35,12 @@ class TestTextChunker: """ Test that if an infinitely high chunk size is used, only one chunk is returned. """ - chunker_config = ChunkerConfig(chunk_size=9999999999, chunk_overlap=0, length_function=len) + chunker_config = ChunkerConfig(chunk_size=9999999999, chunk_overlap=0, length_function=len, min_chunk_size=0) chunker = TextChunker(config=chunker_config) text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit." # Data type must be set manually in the test chunker.set_data_type(DataType.TEXT) - result = chunker.create_chunks(MockLoader(), text) + result = chunker.create_chunks(MockLoader(), text, chunker_config) documents = result["documents"] assert len(documents) == 1 @@ -48,18 +48,18 @@ class TestTextChunker: """ Test that if a chunk size of one is used, every character is a chunk. 
""" - chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len) + chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0) chunker = TextChunker(config=chunker_config) # We can't test with lorem ipsum because chunks are deduped, so would be recurring characters. text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c""" # Data type must be set manually in the test chunker.set_data_type(DataType.TEXT) - result = chunker.create_chunks(MockLoader(), text) + result = chunker.create_chunks(MockLoader(), text, chunker_config) documents = result["documents"] assert len(documents) == len(text) def test_word_count(self): - chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len) + chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0) chunker = TextChunker(config=chunker_config) chunker.set_data_type(DataType.TEXT) diff --git a/tests/embedchain/test_add.py b/tests/embedchain/test_add.py index 7152f728..b9d8437a 100644 --- a/tests/embedchain/test_add.py +++ b/tests/embedchain/test_add.py @@ -33,7 +33,7 @@ def test_add_forced_type(app): def test_dry_run(app): - chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0) + chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, min_chunk_size=0) text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ""" result = app.add(source=text, config=AddConfig(chunker=chunker_config), dry_run=True)