[Improvement] Add support for min chunk size (#1007)

Author: Deven Patel
Date: 2023-12-15 05:59:15 +05:30
Committed by: GitHub
Parent: 9303a1bf81
Commit: c0ee680546
11 changed files with 59 additions and 25 deletions

View File

@@ -71,7 +71,6 @@ elon_bot = App()
 # Embed online resources
 elon_bot.add("https://en.wikipedia.org/wiki/Elon_Musk")
 elon_bot.add("https://www.forbes.com/profile/elon-musk")
-elon_bot.add("https://www.youtube.com/watch?v=RcYjXbSJBN8")
 
 # Query the bot
 elon_bot.query("How many companies does Elon Musk run and name those?")

View File

@@ -180,6 +180,8 @@ Alright, let's dive into what each key means in the yaml config above:
 - `chunk_size` (Integer): The size of each chunk of text that is sent to the language model.
 - `chunk_overlap` (Integer): The amount of overlap between each chunk of text.
 - `length_function` (String): The function used to calculate the length of each chunk of text. In this case, it's set to 'len'. You can also pass any importable function as a string here.
+- `min_chunk_size` (Integer): The minimum size of each chunk of text that is sent to the language model. Must be less than `chunk_size` and greater than `chunk_overlap`.
+
 If you have questions about the configuration above, please feel free to reach out to us using one of the following methods:
 <Snippet file="get-help.mdx" />
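
For reference, the same chunker keys can be set directly from Python. A minimal sketch using the `ChunkerConfig` class this commit extends (values are illustrative, not defaults):

from embedchain.config.add_config import ChunkerConfig

chunker_config = ChunkerConfig(
    chunk_size=2000,      # maximum characters per chunk
    chunk_overlap=20,     # characters shared between consecutive chunks
    length_function=len,  # how chunk length is measured
    min_chunk_size=30,    # chunks shorter than this are skipped at add time
)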

View File

@@ -1,5 +1,8 @@
 import hashlib
+import logging
+from typing import Optional
 
+from embedchain.config.add_config import ChunkerConfig
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.models.data_type import DataType
@@ -10,7+13,7 @@ class BaseChunker(JSONSerializable):
         self.text_splitter = text_splitter
         self.data_type = None
 
-    def create_chunks(self, loader, src, app_id=None):
+    def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig] = None):
         """
         Loads data and chunks it.
@@ -23,6 +26,8 @@ class BaseChunker(JSONSerializable):
         documents = []
         chunk_ids = []
         idMap = {}
+        min_chunk_size = config.min_chunk_size if config is not None else 1
+        logging.info(f"[INFO] Skipping chunks smaller than {min_chunk_size} characters")
         data_result = loader.load_data(src)
         data_records = data_result["data"]
         doc_id = data_result["doc_id"]
@@ -44,7 +49,7 @@ class BaseChunker(JSONSerializable):
             for chunk in chunks:
                 chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
                 chunk_id = f"{app_id}--{chunk_id}" if app_id is not None else chunk_id
-                if idMap.get(chunk_id) is None:
+                if idMap.get(chunk_id) is None and len(chunk) >= min_chunk_size:
                     idMap[chunk_id] = True
                     chunk_ids.append(chunk_id)
                     documents.append(chunk)
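
In isolation, the dedupe-plus-minimum-size rule added above behaves like this standalone sketch (names like filter_chunks are illustrative, not the library API):

import hashlib

def filter_chunks(chunks, url, app_id=None, min_chunk_size=1):
    # Keep a chunk only if its content hash is unseen and it meets the minimum size.
    seen, documents, chunk_ids = {}, [], []
    for chunk in chunks:
        chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
        chunk_id = f"{app_id}--{chunk_id}" if app_id is not None else chunk_id
        if seen.get(chunk_id) is None and len(chunk) >= min_chunk_size:
            seen[chunk_id] = True
            chunk_ids.append(chunk_id)
            documents.append(chunk)
    return documents, chunk_ids

# filter_chunks(["ab", "abcdef", "abcdef"], url="src", min_chunk_size=3)
# keeps only one "abcdef": the short chunk and the duplicate are both dropped.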

View File

@@ -1,4 +1,5 @@
 import hashlib
+import logging
 from typing import Optional
 
 from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -20,7 +21,7 @@ class ImagesChunker(BaseChunker):
         )
         super().__init__(image_splitter)
 
-    def create_chunks(self, loader, src, app_id=None):
+    def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig] = None):
         """
         Loads the image(s), and creates their corresponding embedding. This creates one chunk for each image
@@ -32,6 +33,8 @@ class ImagesChunker(BaseChunker):
         documents = []
         embeddings = []
         ids = []
+        min_chunk_size = config.min_chunk_size if config is not None else 0
+        logging.info(f"[INFO] Skipping chunks smaller than {min_chunk_size} characters")
         data_result = loader.load_data(src)
         data_records = data_result["data"]
         doc_id = data_result["doc_id"]

View File

@@ -1,4 +1,5 @@
 import builtins
+import logging
 from importlib import import_module
 from typing import Callable, Optional
@@ -14,12 +15,21 @@ class ChunkerConfig(BaseConfig):
     def __init__(
         self,
-        chunk_size: Optional[int] = None,
-        chunk_overlap: Optional[int] = None,
+        chunk_size: Optional[int] = 2000,
+        chunk_overlap: Optional[int] = 0,
         length_function: Optional[Callable[[str], int]] = None,
+        min_chunk_size: Optional[int] = 0,
     ):
-        self.chunk_size = chunk_size if chunk_size else 2000
-        self.chunk_overlap = chunk_overlap if chunk_overlap else 0
+        self.chunk_size = chunk_size
+        self.chunk_overlap = chunk_overlap
+        self.min_chunk_size = min_chunk_size
+        if self.min_chunk_size >= self.chunk_size:
+            raise ValueError(f"min_chunk_size {min_chunk_size} should be less than chunk_size {chunk_size}")
+        if self.min_chunk_size <= self.chunk_overlap:
+            logging.warn(
+                f"min_chunk_size {min_chunk_size} should be greater than chunk_overlap {chunk_overlap}, otherwise it is redundant."  # noqa:E501
+            )
         if isinstance(length_function, str):
             self.length_function = self.load_func(length_function)
         else:
@@ -37,7 +47,7 @@ class ChunkerConfig(BaseConfig):
 @register_deserializable
 class LoaderConfig(BaseConfig):
     """
-    Config for the chunker used in `add` method
+    Config for the loader used in `add` method
     """
 
     def __init__(self):
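
The new validation can be exercised on its own; a quick sketch of the three cases, given the defaults shown in the diff:

from embedchain.config.add_config import ChunkerConfig

# Valid: chunk_overlap < min_chunk_size < chunk_size.
ChunkerConfig(chunk_size=500, chunk_overlap=50, length_function=len, min_chunk_size=100)

# Raises ValueError: min_chunk_size must be smaller than chunk_size.
# ChunkerConfig(chunk_size=100, length_function=len, min_chunk_size=100)

# Only logs a warning: a min_chunk_size at or below chunk_overlap is redundant.
# ChunkerConfig(chunk_size=500, chunk_overlap=50, length_function=len, min_chunk_size=50)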

View File

@@ -196,7 +196,7 @@ class EmbedChain(JSONSerializable):
         data_formatter = DataFormatter(data_type, config, loader, chunker)
         documents, metadatas, _ids, new_chunks = self._load_and_embed(
-            data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run, **kwargs
+            data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, config, dry_run, **kwargs
         )
 
         if data_type in {DataType.DOCS_SITE}:
             self.is_docs_site_instance = True
@@ -339,6 +339,7 @@ class EmbedChain(JSONSerializable):
         src: Any,
         metadata: Optional[Dict[str, Any]] = None,
         source_hash: Optional[str] = None,
+        add_config: Optional[AddConfig] = None,
         dry_run=False,
         **kwargs: Optional[Dict[str, Any]],
     ):
@@ -359,12 +360,13 @@ class EmbedChain(JSONSerializable):
         app_id = self.config.id if self.config is not None else None
 
         # Create chunks
-        embeddings_data = chunker.create_chunks(loader, src, app_id=app_id)
+        embeddings_data = chunker.create_chunks(loader, src, app_id=app_id, config=add_config.chunker)
         # spread chunking results
         documents = embeddings_data["documents"]
         metadatas = embeddings_data["metadatas"]
         ids = embeddings_data["ids"]
         new_doc_id = embeddings_data["doc_id"]
+        embeddings = embeddings_data.get("embeddings")
         if existing_doc_id and existing_doc_id == new_doc_id:
             print("Doc content has not changed. Skipping creating chunks and embeddings")
             return [], [], [], 0
@@ -429,7 +431,7 @@ class EmbedChain(JSONSerializable):
         chunks_before_addition = self.db.count()
 
         self.db.add(
-            embeddings=embeddings_data.get("embeddings", None),
+            embeddings=embeddings,
             documents=documents,
             metadatas=metadatas,
             ids=ids,
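
End to end, the new parameter threads from `add` through `_load_and_embed` into the chunker. A usage sketch based on the signatures above and the repo's tests (the import path for AddConfig is assumed; the URL is just an example):

from embedchain import App
from embedchain.config import AddConfig
from embedchain.config.add_config import ChunkerConfig

app = App()
chunker_config = ChunkerConfig(chunk_size=500, chunk_overlap=50, length_function=len, min_chunk_size=100)

# Chunks shorter than 100 characters are dropped before embedding.
app.add("https://en.wikipedia.org/wiki/Elon_Musk", config=AddConfig(chunker=chunker_config))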

View File

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.32"
+version = "0.1.33"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",

View File

@@ -4,6 +4,7 @@ from unittest.mock import MagicMock
 import pytest
 
 from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.add_config import ChunkerConfig
 from embedchain.models.data_type import DataType
@@ -35,6 +36,18 @@ def chunker(text_splitter_mock, data_type):
     return chunker
 
 
+def test_create_chunks_with_config(chunker, text_splitter_mock, loader_mock, app_id, data_type):
+    text_splitter_mock.split_text.return_value = ["Chunk 1", "long chunk"]
+    loader_mock.load_data.return_value = {
+        "data": [{"content": "Content 1", "meta_data": {"url": "URL 1"}}],
+        "doc_id": "DocID",
+    }
+    config = ChunkerConfig(chunk_size=50, chunk_overlap=0, length_function=len, min_chunk_size=10)
+    result = chunker.create_chunks(loader_mock, "test_src", app_id, config)
+    assert result["documents"] == ["long chunk"]
+
+
 def test_create_chunks(chunker, text_splitter_mock, loader_mock, app_id, data_type):
     text_splitter_mock.split_text.return_value = ["Chunk 1", "Chunk 2"]
     loader_mock.load_data.return_value = {

View File

@@ -11,7 +11,7 @@ class TestImageChunker(unittest.TestCase):
         Test the chunks generated by TextChunker.
         # TODO: Not a very precise test.
         """
-        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
+        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0)
         chunker = ImagesChunker(config=chunker_config)
 
         # Data type must be set manually in the test
         chunker.set_data_type(DataType.IMAGES)
@@ -51,7 +51,7 @@ class TestImageChunker(unittest.TestCase):
         self.assertEqual(expected_chunks, result)
 
     def test_word_count(self):
-        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
+        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0)
         chunker = ImagesChunker(config=chunker_config)
         chunker.set_data_type(DataType.IMAGES)

View File

@@ -10,12 +10,12 @@ class TestTextChunker:
         """
         Test the chunks generated by TextChunker.
         """
-        chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len)
+        chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len, min_chunk_size=0)
         chunker = TextChunker(config=chunker_config)
         text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
         # Data type must be set manually in the test
         chunker.set_data_type(DataType.TEXT)
-        result = chunker.create_chunks(MockLoader(), text)
+        result = chunker.create_chunks(MockLoader(), text, chunker_config)
         documents = result["documents"]
         assert len(documents) > 5
@@ -23,11 +23,11 @@ class TestTextChunker:
         """
         Test the chunks generated by TextChunker with app_id
         """
-        chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len)
+        chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len, min_chunk_size=0)
         chunker = TextChunker(config=chunker_config)
         text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
         chunker.set_data_type(DataType.TEXT)
-        result = chunker.create_chunks(MockLoader(), text)
+        result = chunker.create_chunks(MockLoader(), text, chunker_config)
         documents = result["documents"]
         assert len(documents) > 5
@@ -35,12 +35,12 @@ class TestTextChunker:
         """
         Test that if an infinitely high chunk size is used, only one chunk is returned.
         """
-        chunker_config = ChunkerConfig(chunk_size=9999999999, chunk_overlap=0, length_function=len)
+        chunker_config = ChunkerConfig(chunk_size=9999999999, chunk_overlap=0, length_function=len, min_chunk_size=0)
         chunker = TextChunker(config=chunker_config)
         text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
         # Data type must be set manually in the test
         chunker.set_data_type(DataType.TEXT)
-        result = chunker.create_chunks(MockLoader(), text)
+        result = chunker.create_chunks(MockLoader(), text, chunker_config)
         documents = result["documents"]
         assert len(documents) == 1
@@ -48,18 +48,18 @@ class TestTextChunker:
         """
         Test that if a chunk size of one is used, every character is a chunk.
         """
-        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
+        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0)
         chunker = TextChunker(config=chunker_config)
         # We can't test with lorem ipsum because chunks are deduped, so would be recurring characters.
         text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c"""
         # Data type must be set manually in the test
         chunker.set_data_type(DataType.TEXT)
-        result = chunker.create_chunks(MockLoader(), text)
+        result = chunker.create_chunks(MockLoader(), text, chunker_config)
         documents = result["documents"]
         assert len(documents) == len(text)
 
     def test_word_count(self):
-        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
+        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0)
         chunker = TextChunker(config=chunker_config)
         chunker.set_data_type(DataType.TEXT)

View File

@@ -33,7 +33,7 @@ def test_add_forced_type(app):
 def test_dry_run(app):
-    chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0)
+    chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, min_chunk_size=0)
     text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"""
     result = app.add(source=text, config=AddConfig(chunker=chunker_config), dry_run=True)