[Improvement] Add support for min chunk size (#1007)

This commit is contained in:
Deven Patel
2023-12-15 05:59:15 +05:30
committed by GitHub
parent 9303a1bf81
commit c0ee680546
11 changed files with 59 additions and 25 deletions

View File

@@ -1,5 +1,8 @@
import hashlib
import logging
from typing import Optional
from embedchain.config.add_config import ChunkerConfig
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.models.data_type import DataType
@@ -10,7 +13,7 @@ class BaseChunker(JSONSerializable):
self.text_splitter = text_splitter
self.data_type = None
def create_chunks(self, loader, src, app_id=None):
def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig] = None):
"""
Loads data and chunks it.
@@ -23,6 +26,8 @@ class BaseChunker(JSONSerializable):
documents = []
chunk_ids = []
idMap = {}
min_chunk_size = config.min_chunk_size if config is not None else 1
logging.info(f"[INFO] Skipping chunks smaller than {min_chunk_size} characters")
data_result = loader.load_data(src)
data_records = data_result["data"]
doc_id = data_result["doc_id"]
@@ -44,7 +49,7 @@ class BaseChunker(JSONSerializable):
for chunk in chunks:
chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
chunk_id = f"{app_id}--{chunk_id}" if app_id is not None else chunk_id
if idMap.get(chunk_id) is None:
if idMap.get(chunk_id) is None and len(chunk) >= min_chunk_size:
idMap[chunk_id] = True
chunk_ids.append(chunk_id)
documents.append(chunk)

View File

@@ -1,4 +1,5 @@
import hashlib
import logging
from typing import Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -20,7 +21,7 @@ class ImagesChunker(BaseChunker):
)
super().__init__(image_splitter)
def create_chunks(self, loader, src, app_id=None):
def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig] = None):
"""
Loads the image(s), and creates their corresponding embedding. This creates one chunk for each image
@@ -32,6 +33,8 @@ class ImagesChunker(BaseChunker):
documents = []
embeddings = []
ids = []
min_chunk_size = config.min_chunk_size if config is not None else 0
logging.info(f"[INFO] Skipping chunks smaller than {min_chunk_size} characters")
data_result = loader.load_data(src)
data_records = data_result["data"]
doc_id = data_result["doc_id"]

View File

@@ -1,4 +1,5 @@
import builtins
import logging
from importlib import import_module
from typing import Callable, Optional
@@ -14,12 +15,21 @@ class ChunkerConfig(BaseConfig):
def __init__(
self,
chunk_size: Optional[int] = None,
chunk_overlap: Optional[int] = None,
chunk_size: Optional[int] = 2000,
chunk_overlap: Optional[int] = 0,
length_function: Optional[Callable[[str], int]] = None,
min_chunk_size: Optional[int] = 0,
):
self.chunk_size = chunk_size if chunk_size else 2000
self.chunk_overlap = chunk_overlap if chunk_overlap else 0
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.min_chunk_size = min_chunk_size
if self.min_chunk_size >= self.chunk_size:
raise ValueError(f"min_chunk_size {min_chunk_size} should be less than chunk_size {chunk_size}")
if self.min_chunk_size <= self.chunk_overlap:
logging.warn(
f"min_chunk_size {min_chunk_size} should be greater than chunk_overlap {chunk_overlap}, otherwise it is redundant." # noqa:E501
)
if isinstance(length_function, str):
self.length_function = self.load_func(length_function)
else:
@@ -37,7 +47,7 @@ class ChunkerConfig(BaseConfig):
@register_deserializable
class LoaderConfig(BaseConfig):
"""
Config for the chunker used in `add` method
Config for the loader used in `add` method
"""
def __init__(self):

View File

@@ -196,7 +196,7 @@ class EmbedChain(JSONSerializable):
data_formatter = DataFormatter(data_type, config, loader, chunker)
documents, metadatas, _ids, new_chunks = self._load_and_embed(
data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run, **kwargs
data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, config, dry_run, **kwargs
)
if data_type in {DataType.DOCS_SITE}:
self.is_docs_site_instance = True
@@ -339,6 +339,7 @@ class EmbedChain(JSONSerializable):
src: Any,
metadata: Optional[Dict[str, Any]] = None,
source_hash: Optional[str] = None,
add_config: Optional[AddConfig] = None,
dry_run=False,
**kwargs: Optional[Dict[str, Any]],
):
@@ -359,12 +360,13 @@ class EmbedChain(JSONSerializable):
app_id = self.config.id if self.config is not None else None
# Create chunks
embeddings_data = chunker.create_chunks(loader, src, app_id=app_id)
embeddings_data = chunker.create_chunks(loader, src, app_id=app_id, config=add_config.chunker)
# spread chunking results
documents = embeddings_data["documents"]
metadatas = embeddings_data["metadatas"]
ids = embeddings_data["ids"]
new_doc_id = embeddings_data["doc_id"]
embeddings = embeddings_data.get("embeddings")
if existing_doc_id and existing_doc_id == new_doc_id:
print("Doc content has not changed. Skipping creating chunks and embeddings")
return [], [], [], 0
@@ -429,7 +431,7 @@ class EmbedChain(JSONSerializable):
chunks_before_addition = self.db.count()
self.db.add(
embeddings=embeddings_data.get("embeddings", None),
embeddings=embeddings,
documents=documents,
metadatas=metadatas,
ids=ids,