diff --git a/docs/advanced/configuration.mdx b/docs/advanced/configuration.mdx index 5c912b09..547051b5 100644 --- a/docs/advanced/configuration.mdx +++ b/docs/advanced/configuration.mdx @@ -13,7 +13,7 @@ Here's the readme example with configuration options. ```python import os from embedchain import App -from embedchain.config import InitConfig, AddConfig, QueryConfig +from embedchain.config import InitConfig, AddConfig, QueryConfig, ChunkerConfig from chromadb.utils import embedding_functions # Example: use your own embedding function @@ -25,14 +25,8 @@ config = InitConfig(ef=embedding_functions.OpenAIEmbeddingFunction( naval_chat_bot = App(config) # Example: define your own chunker config for `youtube_video` -youtube_add_config = { - "chunker": { - "chunk_size": 1000, - "chunk_overlap": 100, - "length_function": len, - } -} -naval_chat_bot.add("youtube_video", "https://www.youtube.com/watch?v=3qHkcs3kG44", AddConfig(**youtube_add_config)) +chunker_config = ChunkerConfig(chunk_size=1000, chunk_overlap=100, length_function=len) +naval_chat_bot.add("youtube_video", "https://www.youtube.com/watch?v=3qHkcs3kG44", AddConfig(chunker=chunker_config)) add_config = AddConfig() naval_chat_bot.add("pdf_file", "https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Almanack-of-Naval-Ravikant_Final.pdf", add_config) diff --git a/embedchain/chunkers/docx_file.py b/embedchain/chunkers/docx_file.py index dc5ef305..bccfb589 100644 --- a/embedchain/chunkers/docx_file.py +++ b/embedchain/chunkers/docx_file.py @@ -5,18 +5,16 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.AddConfig import ChunkerConfig -TEXT_SPLITTER_CHUNK_PARAMS = { - "chunk_size": 1000, - "chunk_overlap": 0, - "length_function": len, -} - class DocxFileChunker(BaseChunker): """Chunker for .docx file.""" def __init__(self, config: Optional[ChunkerConfig] = None): if config is None: - config = TEXT_SPLITTER_CHUNK_PARAMS - text_splitter = RecursiveCharacterTextSplitter(**config) + config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len) + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=config.chunk_size, + chunk_overlap=config.chunk_overlap, + length_function=config.length_function, + ) super().__init__(text_splitter) diff --git a/embedchain/chunkers/pdf_file.py b/embedchain/chunkers/pdf_file.py index 34ab7607..ec19166b 100644 --- a/embedchain/chunkers/pdf_file.py +++ b/embedchain/chunkers/pdf_file.py @@ -5,18 +5,16 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.AddConfig import ChunkerConfig -TEXT_SPLITTER_CHUNK_PARAMS = { - "chunk_size": 1000, - "chunk_overlap": 0, - "length_function": len, -} - class PdfFileChunker(BaseChunker): """Chunker for PDF file.""" def __init__(self, config: Optional[ChunkerConfig] = None): if config is None: - config = TEXT_SPLITTER_CHUNK_PARAMS - text_splitter = RecursiveCharacterTextSplitter(**config) + config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len) + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=config.chunk_size, + chunk_overlap=config.chunk_overlap, + length_function=config.length_function, + ) super().__init__(text_splitter) diff --git a/embedchain/chunkers/qna_pair.py b/embedchain/chunkers/qna_pair.py index 37533c93..ba9d0991 100644 --- a/embedchain/chunkers/qna_pair.py +++ b/embedchain/chunkers/qna_pair.py @@ -5,18 +5,16 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.AddConfig import ChunkerConfig -TEXT_SPLITTER_CHUNK_PARAMS = { - "chunk_size": 300, - "chunk_overlap": 0, - "length_function": len, -} - class QnaPairChunker(BaseChunker): """Chunker for QnA pair.""" def __init__(self, config: Optional[ChunkerConfig] = None): if config is None: - config = TEXT_SPLITTER_CHUNK_PARAMS - text_splitter = RecursiveCharacterTextSplitter(**config) + config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len) + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=config.chunk_size, + chunk_overlap=config.chunk_overlap, + length_function=config.length_function, + ) super().__init__(text_splitter) diff --git a/embedchain/chunkers/text.py b/embedchain/chunkers/text.py index 36d07427..44a320d1 100644 --- a/embedchain/chunkers/text.py +++ b/embedchain/chunkers/text.py @@ -5,18 +5,16 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.AddConfig import ChunkerConfig -TEXT_SPLITTER_CHUNK_PARAMS = { - "chunk_size": 300, - "chunk_overlap": 0, - "length_function": len, -} - class TextChunker(BaseChunker): """Chunker for text.""" def __init__(self, config: Optional[ChunkerConfig] = None): if config is None: - config = TEXT_SPLITTER_CHUNK_PARAMS - text_splitter = RecursiveCharacterTextSplitter(**config) + config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len) + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=config.chunk_size, + chunk_overlap=config.chunk_overlap, + length_function=config.length_function, + ) super().__init__(text_splitter) diff --git a/embedchain/chunkers/web_page.py b/embedchain/chunkers/web_page.py index 2ec323b5..fd451d8e 100644 --- a/embedchain/chunkers/web_page.py +++ b/embedchain/chunkers/web_page.py @@ -5,18 +5,16 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.AddConfig import ChunkerConfig -TEXT_SPLITTER_CHUNK_PARAMS = { - "chunk_size": 500, - "chunk_overlap": 0, - "length_function": len, -} - class WebPageChunker(BaseChunker): """Chunker for web page.""" def __init__(self, config: Optional[ChunkerConfig] = None): if config is None: - config = TEXT_SPLITTER_CHUNK_PARAMS - text_splitter = RecursiveCharacterTextSplitter(**config) + config = ChunkerConfig(chunk_size=500, chunk_overlap=0, length_function=len) + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=config.chunk_size, + chunk_overlap=config.chunk_overlap, + length_function=config.length_function, + ) super().__init__(text_splitter) diff --git a/embedchain/chunkers/youtube_video.py b/embedchain/chunkers/youtube_video.py index d63b2e05..4f2ad41f 100644 --- a/embedchain/chunkers/youtube_video.py +++ b/embedchain/chunkers/youtube_video.py @@ -5,18 +5,16 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.AddConfig import ChunkerConfig -TEXT_SPLITTER_CHUNK_PARAMS = { - "chunk_size": 2000, - "chunk_overlap": 0, - "length_function": len, -} - class YoutubeVideoChunker(BaseChunker): """Chunker for Youtube video.""" def __init__(self, config: Optional[ChunkerConfig] = None): if config is None: - config = TEXT_SPLITTER_CHUNK_PARAMS - text_splitter = RecursiveCharacterTextSplitter(**config) + config = ChunkerConfig(chunk_size=2000, chunk_overlap=0, length_function=len) + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=config.chunk_size, + chunk_overlap=config.chunk_overlap, + length_function=config.length_function, + ) super().__init__(text_splitter) diff --git a/embedchain/config/AddConfig.py b/embedchain/config/AddConfig.py index 5ff2b3da..fe527a1d 100644 --- a/embedchain/config/AddConfig.py +++ b/embedchain/config/AddConfig.py @@ -10,13 +10,13 @@ class ChunkerConfig(BaseConfig): def __init__( self, - chunk_size: Optional[int] = 4000, - chunk_overlap: Optional[int] = 200, - length_function: Optional[Callable[[str], int]] = len, + chunk_size: Optional[int] = None, + chunk_overlap: Optional[int] = None, + length_function: Optional[Callable[[str], int]] = None, ): - self.chunk_size = chunk_size - self.chunk_overlap = chunk_overlap - self.length_function = length_function + self.chunk_size = chunk_size if chunk_size else 2000 + self.chunk_overlap = chunk_overlap if chunk_overlap else 0 + self.length_function = length_function if length_function else len class LoaderConfig(BaseConfig): diff --git a/embedchain/config/__init__.py b/embedchain/config/__init__.py index 67d72ff0..8f567a6f 100644 --- a/embedchain/config/__init__.py +++ b/embedchain/config/__init__.py @@ -1,4 +1,4 @@ -from .AddConfig import AddConfig # noqa: F401 +from .AddConfig import AddConfig, ChunkerConfig # noqa: F401 from .BaseConfig import BaseConfig # noqa: F401 from .ChatConfig import ChatConfig # noqa: F401 from .InitConfig import InitConfig # noqa: F401 diff --git a/tests/chunkers/test_text.py b/tests/chunkers/test_text.py index e2aae842..79810c6f 100644 --- a/tests/chunkers/test_text.py +++ b/tests/chunkers/test_text.py @@ -3,6 +3,7 @@ import unittest from embedchain.chunkers.text import TextChunker +from embedchain.config import ChunkerConfig class TestTextChunker(unittest.TestCase): @@ -11,11 +12,7 @@ class TestTextChunker(unittest.TestCase): Test the chunks generated by TextChunker. # TODO: Not a very precise test. """ - chunker_config = { - "chunk_size": 10, - "chunk_overlap": 0, - "length_function": len, - } + chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len) chunker = TextChunker(config=chunker_config) text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."