chore: load chunker from config (#270)
This commit is contained in:
@@ -5,18 +5,16 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.AddConfig import ChunkerConfig
|
||||
|
||||
TEXT_SPLITTER_CHUNK_PARAMS = {
|
||||
"chunk_size": 1000,
|
||||
"chunk_overlap": 0,
|
||||
"length_function": len,
|
||||
}
|
||||
|
||||
|
||||
class DocxFileChunker(BaseChunker):
|
||||
"""Chunker for .docx file."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = TEXT_SPLITTER_CHUNK_PARAMS
|
||||
text_splitter = RecursiveCharacterTextSplitter(**config)
|
||||
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
|
||||
@@ -5,18 +5,16 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.AddConfig import ChunkerConfig
|
||||
|
||||
TEXT_SPLITTER_CHUNK_PARAMS = {
|
||||
"chunk_size": 1000,
|
||||
"chunk_overlap": 0,
|
||||
"length_function": len,
|
||||
}
|
||||
|
||||
|
||||
class PdfFileChunker(BaseChunker):
|
||||
"""Chunker for PDF file."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = TEXT_SPLITTER_CHUNK_PARAMS
|
||||
text_splitter = RecursiveCharacterTextSplitter(**config)
|
||||
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
|
||||
@@ -5,18 +5,16 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.AddConfig import ChunkerConfig
|
||||
|
||||
TEXT_SPLITTER_CHUNK_PARAMS = {
|
||||
"chunk_size": 300,
|
||||
"chunk_overlap": 0,
|
||||
"length_function": len,
|
||||
}
|
||||
|
||||
|
||||
class QnaPairChunker(BaseChunker):
|
||||
"""Chunker for QnA pair."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = TEXT_SPLITTER_CHUNK_PARAMS
|
||||
text_splitter = RecursiveCharacterTextSplitter(**config)
|
||||
config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
|
||||
@@ -5,18 +5,16 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.AddConfig import ChunkerConfig
|
||||
|
||||
TEXT_SPLITTER_CHUNK_PARAMS = {
|
||||
"chunk_size": 300,
|
||||
"chunk_overlap": 0,
|
||||
"length_function": len,
|
||||
}
|
||||
|
||||
|
||||
class TextChunker(BaseChunker):
|
||||
"""Chunker for text."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = TEXT_SPLITTER_CHUNK_PARAMS
|
||||
text_splitter = RecursiveCharacterTextSplitter(**config)
|
||||
config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
|
||||
@@ -5,18 +5,16 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.AddConfig import ChunkerConfig
|
||||
|
||||
TEXT_SPLITTER_CHUNK_PARAMS = {
|
||||
"chunk_size": 500,
|
||||
"chunk_overlap": 0,
|
||||
"length_function": len,
|
||||
}
|
||||
|
||||
|
||||
class WebPageChunker(BaseChunker):
|
||||
"""Chunker for web page."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = TEXT_SPLITTER_CHUNK_PARAMS
|
||||
text_splitter = RecursiveCharacterTextSplitter(**config)
|
||||
config = ChunkerConfig(chunk_size=500, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
|
||||
@@ -5,18 +5,16 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.AddConfig import ChunkerConfig
|
||||
|
||||
TEXT_SPLITTER_CHUNK_PARAMS = {
|
||||
"chunk_size": 2000,
|
||||
"chunk_overlap": 0,
|
||||
"length_function": len,
|
||||
}
|
||||
|
||||
|
||||
class YoutubeVideoChunker(BaseChunker):
|
||||
"""Chunker for Youtube video."""
|
||||
|
||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||
if config is None:
|
||||
config = TEXT_SPLITTER_CHUNK_PARAMS
|
||||
text_splitter = RecursiveCharacterTextSplitter(**config)
|
||||
config = ChunkerConfig(chunk_size=2000, chunk_overlap=0, length_function=len)
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=config.chunk_size,
|
||||
chunk_overlap=config.chunk_overlap,
|
||||
length_function=config.length_function,
|
||||
)
|
||||
super().__init__(text_splitter)
|
||||
|
||||
@@ -10,13 +10,13 @@ class ChunkerConfig(BaseConfig):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size: Optional[int] = 4000,
|
||||
chunk_overlap: Optional[int] = 200,
|
||||
length_function: Optional[Callable[[str], int]] = len,
|
||||
chunk_size: Optional[int] = None,
|
||||
chunk_overlap: Optional[int] = None,
|
||||
length_function: Optional[Callable[[str], int]] = None,
|
||||
):
|
||||
self.chunk_size = chunk_size
|
||||
self.chunk_overlap = chunk_overlap
|
||||
self.length_function = length_function
|
||||
self.chunk_size = chunk_size if chunk_size else 2000
|
||||
self.chunk_overlap = chunk_overlap if chunk_overlap else 0
|
||||
self.length_function = length_function if length_function else len
|
||||
|
||||
|
||||
class LoaderConfig(BaseConfig):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .AddConfig import AddConfig # noqa: F401
|
||||
from .AddConfig import AddConfig, ChunkerConfig # noqa: F401
|
||||
from .BaseConfig import BaseConfig # noqa: F401
|
||||
from .ChatConfig import ChatConfig # noqa: F401
|
||||
from .InitConfig import InitConfig # noqa: F401
|
||||
|
||||
Reference in New Issue
Block a user