[Improvement] Add support for min chunk size (#1007)
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helpers.json_serializable import JSONSerializable
|
||||
from embedchain.models.data_type import DataType
|
||||
|
||||
@@ -10,7 +13,7 @@ class BaseChunker(JSONSerializable):
|
||||
self.text_splitter = text_splitter
|
||||
self.data_type = None
|
||||
|
||||
def create_chunks(self, loader, src, app_id=None):
|
||||
def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig] = None):
|
||||
"""
|
||||
Loads data and chunks it.
|
||||
|
||||
@@ -23,6 +26,8 @@ class BaseChunker(JSONSerializable):
|
||||
documents = []
|
||||
chunk_ids = []
|
||||
idMap = {}
|
||||
min_chunk_size = config.min_chunk_size if config is not None else 1
|
||||
logging.info(f"[INFO] Skipping chunks smaller than {min_chunk_size} characters")
|
||||
data_result = loader.load_data(src)
|
||||
data_records = data_result["data"]
|
||||
doc_id = data_result["doc_id"]
|
||||
@@ -44,7 +49,7 @@ class BaseChunker(JSONSerializable):
|
||||
for chunk in chunks:
|
||||
chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
|
||||
chunk_id = f"{app_id}--{chunk_id}" if app_id is not None else chunk_id
|
||||
if idMap.get(chunk_id) is None:
|
||||
if idMap.get(chunk_id) is None and len(chunk) >= min_chunk_size:
|
||||
idMap[chunk_id] = True
|
||||
chunk_ids.append(chunk_id)
|
||||
documents.append(chunk)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
@@ -20,7 +21,7 @@ class ImagesChunker(BaseChunker):
|
||||
)
|
||||
super().__init__(image_splitter)
|
||||
|
||||
def create_chunks(self, loader, src, app_id=None):
|
||||
def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig] = None):
|
||||
"""
|
||||
Loads the image(s), and creates their corresponding embedding. This creates one chunk for each image
|
||||
|
||||
@@ -32,6 +33,8 @@ class ImagesChunker(BaseChunker):
|
||||
documents = []
|
||||
embeddings = []
|
||||
ids = []
|
||||
min_chunk_size = config.min_chunk_size if config is not None else 0
|
||||
logging.info(f"[INFO] Skipping chunks smaller than {min_chunk_size} characters")
|
||||
data_result = loader.load_data(src)
|
||||
data_records = data_result["data"]
|
||||
doc_id = data_result["doc_id"]
|
||||
|
||||
Reference in New Issue
Block a user