[Improvement] Add support for min chunk size (#1007)
This commit is contained in:
@@ -71,7 +71,6 @@ elon_bot = App()
|
|||||||
# Embed online resources
|
# Embed online resources
|
||||||
elon_bot.add("https://en.wikipedia.org/wiki/Elon_Musk")
|
elon_bot.add("https://en.wikipedia.org/wiki/Elon_Musk")
|
||||||
elon_bot.add("https://www.forbes.com/profile/elon-musk")
|
elon_bot.add("https://www.forbes.com/profile/elon-musk")
|
||||||
elon_bot.add("https://www.youtube.com/watch?v=RcYjXbSJBN8")
|
|
||||||
|
|
||||||
# Query the bot
|
# Query the bot
|
||||||
elon_bot.query("How many companies does Elon Musk run and name those?")
|
elon_bot.query("How many companies does Elon Musk run and name those?")
|
||||||
|
|||||||
@@ -180,6 +180,8 @@ Alright, let's dive into what each key means in the yaml config above:
|
|||||||
- `chunk_size` (Integer): The size of each chunk of text that is sent to the language model.
|
- `chunk_size` (Integer): The size of each chunk of text that is sent to the language model.
|
||||||
- `chunk_overlap` (Integer): The amount of overlap between each chunk of text.
|
- `chunk_overlap` (Integer): The amount of overlap between each chunk of text.
|
||||||
- `length_function` (String): The function used to calculate the length of each chunk of text. In this case, it's set to 'len'. You can also use any function import directly as a string here.
|
- `length_function` (String): The function used to calculate the length of each chunk of text. In this case, it's set to 'len'. You can also use any function import directly as a string here.
|
||||||
|
- `min_chunk_size` (Integer): The minimum size of each chunk of text that is sent to the language model. Must be less than `chunk_size`, and greater than `chunk_overlap`.
|
||||||
|
|
||||||
If you have questions about the configuration above, please feel free to reach out to us using one of the following methods:
|
If you have questions about the configuration above, please feel free to reach out to us using one of the following methods:
|
||||||
|
|
||||||
<Snippet file="get-help.mdx" />
|
<Snippet file="get-help.mdx" />
|
||||||
@@ -1,5 +1,8 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from embedchain.config.add_config import ChunkerConfig
|
||||||
from embedchain.helpers.json_serializable import JSONSerializable
|
from embedchain.helpers.json_serializable import JSONSerializable
|
||||||
from embedchain.models.data_type import DataType
|
from embedchain.models.data_type import DataType
|
||||||
|
|
||||||
@@ -10,7 +13,7 @@ class BaseChunker(JSONSerializable):
|
|||||||
self.text_splitter = text_splitter
|
self.text_splitter = text_splitter
|
||||||
self.data_type = None
|
self.data_type = None
|
||||||
|
|
||||||
def create_chunks(self, loader, src, app_id=None):
|
def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig] = None):
|
||||||
"""
|
"""
|
||||||
Loads data and chunks it.
|
Loads data and chunks it.
|
||||||
|
|
||||||
@@ -23,6 +26,8 @@ class BaseChunker(JSONSerializable):
|
|||||||
documents = []
|
documents = []
|
||||||
chunk_ids = []
|
chunk_ids = []
|
||||||
idMap = {}
|
idMap = {}
|
||||||
|
min_chunk_size = config.min_chunk_size if config is not None else 1
|
||||||
|
logging.info(f"[INFO] Skipping chunks smaller than {min_chunk_size} characters")
|
||||||
data_result = loader.load_data(src)
|
data_result = loader.load_data(src)
|
||||||
data_records = data_result["data"]
|
data_records = data_result["data"]
|
||||||
doc_id = data_result["doc_id"]
|
doc_id = data_result["doc_id"]
|
||||||
@@ -44,7 +49,7 @@ class BaseChunker(JSONSerializable):
|
|||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
|
chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
|
||||||
chunk_id = f"{app_id}--{chunk_id}" if app_id is not None else chunk_id
|
chunk_id = f"{app_id}--{chunk_id}" if app_id is not None else chunk_id
|
||||||
if idMap.get(chunk_id) is None:
|
if idMap.get(chunk_id) is None and len(chunk) >= min_chunk_size:
|
||||||
idMap[chunk_id] = True
|
idMap[chunk_id] = True
|
||||||
chunk_ids.append(chunk_id)
|
chunk_ids.append(chunk_id)
|
||||||
documents.append(chunk)
|
documents.append(chunk)
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||||
@@ -20,7 +21,7 @@ class ImagesChunker(BaseChunker):
|
|||||||
)
|
)
|
||||||
super().__init__(image_splitter)
|
super().__init__(image_splitter)
|
||||||
|
|
||||||
def create_chunks(self, loader, src, app_id=None):
|
def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig] = None):
|
||||||
"""
|
"""
|
||||||
Loads the image(s), and creates their corresponding embedding. This creates one chunk for each image
|
Loads the image(s), and creates their corresponding embedding. This creates one chunk for each image
|
||||||
|
|
||||||
@@ -32,6 +33,8 @@ class ImagesChunker(BaseChunker):
|
|||||||
documents = []
|
documents = []
|
||||||
embeddings = []
|
embeddings = []
|
||||||
ids = []
|
ids = []
|
||||||
|
min_chunk_size = config.min_chunk_size if config is not None else 0
|
||||||
|
logging.info(f"[INFO] Skipping chunks smaller than {min_chunk_size} characters")
|
||||||
data_result = loader.load_data(src)
|
data_result = loader.load_data(src)
|
||||||
data_records = data_result["data"]
|
data_records = data_result["data"]
|
||||||
doc_id = data_result["doc_id"]
|
doc_id = data_result["doc_id"]
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import builtins
|
import builtins
|
||||||
|
import logging
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
@@ -14,12 +15,21 @@ class ChunkerConfig(BaseConfig):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
chunk_size: Optional[int] = None,
|
chunk_size: Optional[int] = 2000,
|
||||||
chunk_overlap: Optional[int] = None,
|
chunk_overlap: Optional[int] = 0,
|
||||||
length_function: Optional[Callable[[str], int]] = None,
|
length_function: Optional[Callable[[str], int]] = None,
|
||||||
|
min_chunk_size: Optional[int] = 0,
|
||||||
):
|
):
|
||||||
self.chunk_size = chunk_size if chunk_size else 2000
|
self.chunk_size = chunk_size
|
||||||
self.chunk_overlap = chunk_overlap if chunk_overlap else 0
|
self.chunk_overlap = chunk_overlap
|
||||||
|
self.min_chunk_size = min_chunk_size
|
||||||
|
if self.min_chunk_size >= self.chunk_size:
|
||||||
|
raise ValueError(f"min_chunk_size {min_chunk_size} should be less than chunk_size {chunk_size}")
|
||||||
|
if self.min_chunk_size <= self.chunk_overlap:
|
||||||
|
logging.warn(
|
||||||
|
f"min_chunk_size {min_chunk_size} should be greater than chunk_overlap {chunk_overlap}, otherwise it is redundant." # noqa:E501
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(length_function, str):
|
if isinstance(length_function, str):
|
||||||
self.length_function = self.load_func(length_function)
|
self.length_function = self.load_func(length_function)
|
||||||
else:
|
else:
|
||||||
@@ -37,7 +47,7 @@ class ChunkerConfig(BaseConfig):
|
|||||||
@register_deserializable
|
@register_deserializable
|
||||||
class LoaderConfig(BaseConfig):
|
class LoaderConfig(BaseConfig):
|
||||||
"""
|
"""
|
||||||
Config for the chunker used in `add` method
|
Config for the loader used in `add` method
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|||||||
@@ -196,7 +196,7 @@ class EmbedChain(JSONSerializable):
|
|||||||
|
|
||||||
data_formatter = DataFormatter(data_type, config, loader, chunker)
|
data_formatter = DataFormatter(data_type, config, loader, chunker)
|
||||||
documents, metadatas, _ids, new_chunks = self._load_and_embed(
|
documents, metadatas, _ids, new_chunks = self._load_and_embed(
|
||||||
data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run, **kwargs
|
data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, config, dry_run, **kwargs
|
||||||
)
|
)
|
||||||
if data_type in {DataType.DOCS_SITE}:
|
if data_type in {DataType.DOCS_SITE}:
|
||||||
self.is_docs_site_instance = True
|
self.is_docs_site_instance = True
|
||||||
@@ -339,6 +339,7 @@ class EmbedChain(JSONSerializable):
|
|||||||
src: Any,
|
src: Any,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
source_hash: Optional[str] = None,
|
source_hash: Optional[str] = None,
|
||||||
|
add_config: Optional[AddConfig] = None,
|
||||||
dry_run=False,
|
dry_run=False,
|
||||||
**kwargs: Optional[Dict[str, Any]],
|
**kwargs: Optional[Dict[str, Any]],
|
||||||
):
|
):
|
||||||
@@ -359,12 +360,13 @@ class EmbedChain(JSONSerializable):
|
|||||||
app_id = self.config.id if self.config is not None else None
|
app_id = self.config.id if self.config is not None else None
|
||||||
|
|
||||||
# Create chunks
|
# Create chunks
|
||||||
embeddings_data = chunker.create_chunks(loader, src, app_id=app_id)
|
embeddings_data = chunker.create_chunks(loader, src, app_id=app_id, config=add_config.chunker)
|
||||||
# spread chunking results
|
# spread chunking results
|
||||||
documents = embeddings_data["documents"]
|
documents = embeddings_data["documents"]
|
||||||
metadatas = embeddings_data["metadatas"]
|
metadatas = embeddings_data["metadatas"]
|
||||||
ids = embeddings_data["ids"]
|
ids = embeddings_data["ids"]
|
||||||
new_doc_id = embeddings_data["doc_id"]
|
new_doc_id = embeddings_data["doc_id"]
|
||||||
|
embeddings = embeddings_data.get("embeddings")
|
||||||
if existing_doc_id and existing_doc_id == new_doc_id:
|
if existing_doc_id and existing_doc_id == new_doc_id:
|
||||||
print("Doc content has not changed. Skipping creating chunks and embeddings")
|
print("Doc content has not changed. Skipping creating chunks and embeddings")
|
||||||
return [], [], [], 0
|
return [], [], [], 0
|
||||||
@@ -429,7 +431,7 @@ class EmbedChain(JSONSerializable):
|
|||||||
chunks_before_addition = self.db.count()
|
chunks_before_addition = self.db.count()
|
||||||
|
|
||||||
self.db.add(
|
self.db.add(
|
||||||
embeddings=embeddings_data.get("embeddings", None),
|
embeddings=embeddings,
|
||||||
documents=documents,
|
documents=documents,
|
||||||
metadatas=metadatas,
|
metadatas=metadatas,
|
||||||
ids=ids,
|
ids=ids,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "embedchain"
|
name = "embedchain"
|
||||||
version = "0.1.32"
|
version = "0.1.33"
|
||||||
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
|
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
|
||||||
authors = [
|
authors = [
|
||||||
"Taranjeet Singh <taranjeet@embedchain.ai>",
|
"Taranjeet Singh <taranjeet@embedchain.ai>",
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from unittest.mock import MagicMock
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from embedchain.chunkers.base_chunker import BaseChunker
|
from embedchain.chunkers.base_chunker import BaseChunker
|
||||||
|
from embedchain.config.add_config import ChunkerConfig
|
||||||
from embedchain.models.data_type import DataType
|
from embedchain.models.data_type import DataType
|
||||||
|
|
||||||
|
|
||||||
@@ -35,6 +36,18 @@ def chunker(text_splitter_mock, data_type):
|
|||||||
return chunker
|
return chunker
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_chunks_with_config(chunker, text_splitter_mock, loader_mock, app_id, data_type):
|
||||||
|
text_splitter_mock.split_text.return_value = ["Chunk 1", "long chunk"]
|
||||||
|
loader_mock.load_data.return_value = {
|
||||||
|
"data": [{"content": "Content 1", "meta_data": {"url": "URL 1"}}],
|
||||||
|
"doc_id": "DocID",
|
||||||
|
}
|
||||||
|
config = ChunkerConfig(chunk_size=50, chunk_overlap=0, length_function=len, min_chunk_size=10)
|
||||||
|
result = chunker.create_chunks(loader_mock, "test_src", app_id, config)
|
||||||
|
|
||||||
|
assert result["documents"] == ["long chunk"]
|
||||||
|
|
||||||
|
|
||||||
def test_create_chunks(chunker, text_splitter_mock, loader_mock, app_id, data_type):
|
def test_create_chunks(chunker, text_splitter_mock, loader_mock, app_id, data_type):
|
||||||
text_splitter_mock.split_text.return_value = ["Chunk 1", "Chunk 2"]
|
text_splitter_mock.split_text.return_value = ["Chunk 1", "Chunk 2"]
|
||||||
loader_mock.load_data.return_value = {
|
loader_mock.load_data.return_value = {
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ class TestImageChunker(unittest.TestCase):
|
|||||||
Test the chunks generated by TextChunker.
|
Test the chunks generated by TextChunker.
|
||||||
# TODO: Not a very precise test.
|
# TODO: Not a very precise test.
|
||||||
"""
|
"""
|
||||||
chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
|
chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0)
|
||||||
chunker = ImagesChunker(config=chunker_config)
|
chunker = ImagesChunker(config=chunker_config)
|
||||||
# Data type must be set manually in the test
|
# Data type must be set manually in the test
|
||||||
chunker.set_data_type(DataType.IMAGES)
|
chunker.set_data_type(DataType.IMAGES)
|
||||||
@@ -51,7 +51,7 @@ class TestImageChunker(unittest.TestCase):
|
|||||||
self.assertEqual(expected_chunks, result)
|
self.assertEqual(expected_chunks, result)
|
||||||
|
|
||||||
def test_word_count(self):
|
def test_word_count(self):
|
||||||
chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
|
chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0)
|
||||||
chunker = ImagesChunker(config=chunker_config)
|
chunker = ImagesChunker(config=chunker_config)
|
||||||
chunker.set_data_type(DataType.IMAGES)
|
chunker.set_data_type(DataType.IMAGES)
|
||||||
|
|
||||||
|
|||||||
@@ -10,12 +10,12 @@ class TestTextChunker:
|
|||||||
"""
|
"""
|
||||||
Test the chunks generated by TextChunker.
|
Test the chunks generated by TextChunker.
|
||||||
"""
|
"""
|
||||||
chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len)
|
chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len, min_chunk_size=0)
|
||||||
chunker = TextChunker(config=chunker_config)
|
chunker = TextChunker(config=chunker_config)
|
||||||
text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
|
text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
|
||||||
# Data type must be set manually in the test
|
# Data type must be set manually in the test
|
||||||
chunker.set_data_type(DataType.TEXT)
|
chunker.set_data_type(DataType.TEXT)
|
||||||
result = chunker.create_chunks(MockLoader(), text)
|
result = chunker.create_chunks(MockLoader(), text, chunker_config)
|
||||||
documents = result["documents"]
|
documents = result["documents"]
|
||||||
assert len(documents) > 5
|
assert len(documents) > 5
|
||||||
|
|
||||||
@@ -23,11 +23,11 @@ class TestTextChunker:
|
|||||||
"""
|
"""
|
||||||
Test the chunks generated by TextChunker with app_id
|
Test the chunks generated by TextChunker with app_id
|
||||||
"""
|
"""
|
||||||
chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len)
|
chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len, min_chunk_size=0)
|
||||||
chunker = TextChunker(config=chunker_config)
|
chunker = TextChunker(config=chunker_config)
|
||||||
text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
|
text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
|
||||||
chunker.set_data_type(DataType.TEXT)
|
chunker.set_data_type(DataType.TEXT)
|
||||||
result = chunker.create_chunks(MockLoader(), text)
|
result = chunker.create_chunks(MockLoader(), text, chunker_config)
|
||||||
documents = result["documents"]
|
documents = result["documents"]
|
||||||
assert len(documents) > 5
|
assert len(documents) > 5
|
||||||
|
|
||||||
@@ -35,12 +35,12 @@ class TestTextChunker:
|
|||||||
"""
|
"""
|
||||||
Test that if an infinitely high chunk size is used, only one chunk is returned.
|
Test that if an infinitely high chunk size is used, only one chunk is returned.
|
||||||
"""
|
"""
|
||||||
chunker_config = ChunkerConfig(chunk_size=9999999999, chunk_overlap=0, length_function=len)
|
chunker_config = ChunkerConfig(chunk_size=9999999999, chunk_overlap=0, length_function=len, min_chunk_size=0)
|
||||||
chunker = TextChunker(config=chunker_config)
|
chunker = TextChunker(config=chunker_config)
|
||||||
text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
|
text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
|
||||||
# Data type must be set manually in the test
|
# Data type must be set manually in the test
|
||||||
chunker.set_data_type(DataType.TEXT)
|
chunker.set_data_type(DataType.TEXT)
|
||||||
result = chunker.create_chunks(MockLoader(), text)
|
result = chunker.create_chunks(MockLoader(), text, chunker_config)
|
||||||
documents = result["documents"]
|
documents = result["documents"]
|
||||||
assert len(documents) == 1
|
assert len(documents) == 1
|
||||||
|
|
||||||
@@ -48,18 +48,18 @@ class TestTextChunker:
|
|||||||
"""
|
"""
|
||||||
Test that if a chunk size of one is used, every character is a chunk.
|
Test that if a chunk size of one is used, every character is a chunk.
|
||||||
"""
|
"""
|
||||||
chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
|
chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0)
|
||||||
chunker = TextChunker(config=chunker_config)
|
chunker = TextChunker(config=chunker_config)
|
||||||
# We can't test with lorem ipsum because chunks are deduped, so would be recurring characters.
|
# We can't test with lorem ipsum because chunks are deduped, so would be recurring characters.
|
||||||
text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c"""
|
text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c"""
|
||||||
# Data type must be set manually in the test
|
# Data type must be set manually in the test
|
||||||
chunker.set_data_type(DataType.TEXT)
|
chunker.set_data_type(DataType.TEXT)
|
||||||
result = chunker.create_chunks(MockLoader(), text)
|
result = chunker.create_chunks(MockLoader(), text, chunker_config)
|
||||||
documents = result["documents"]
|
documents = result["documents"]
|
||||||
assert len(documents) == len(text)
|
assert len(documents) == len(text)
|
||||||
|
|
||||||
def test_word_count(self):
|
def test_word_count(self):
|
||||||
chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
|
chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0)
|
||||||
chunker = TextChunker(config=chunker_config)
|
chunker = TextChunker(config=chunker_config)
|
||||||
chunker.set_data_type(DataType.TEXT)
|
chunker.set_data_type(DataType.TEXT)
|
||||||
|
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ def test_add_forced_type(app):
|
|||||||
|
|
||||||
|
|
||||||
def test_dry_run(app):
|
def test_dry_run(app):
|
||||||
chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0)
|
chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, min_chunk_size=0)
|
||||||
text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"""
|
text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"""
|
||||||
|
|
||||||
result = app.add(source=text, config=AddConfig(chunker=chunker_config), dry_run=True)
|
result = app.add(source=text, config=AddConfig(chunker=chunker_config), dry_run=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user