chore: load chunker from config (#270)
This commit is contained in:
@@ -13,7 +13,7 @@ Here's the readme example with configuration options.
|
|||||||
```python
|
```python
|
||||||
import os
|
import os
|
||||||
from embedchain import App
|
from embedchain import App
|
||||||
from embedchain.config import InitConfig, AddConfig, QueryConfig
|
from embedchain.config import InitConfig, AddConfig, QueryConfig, ChunkerConfig
|
||||||
from chromadb.utils import embedding_functions
|
from chromadb.utils import embedding_functions
|
||||||
|
|
||||||
# Example: use your own embedding function
|
# Example: use your own embedding function
|
||||||
@@ -25,14 +25,8 @@ config = InitConfig(ef=embedding_functions.OpenAIEmbeddingFunction(
|
|||||||
naval_chat_bot = App(config)
|
naval_chat_bot = App(config)
|
||||||
|
|
||||||
# Example: define your own chunker config for `youtube_video`
|
# Example: define your own chunker config for `youtube_video`
|
||||||
youtube_add_config = {
|
chunker_config = ChunkerConfig(chunk_size=1000, chunk_overlap=100, length_function=len)
|
||||||
"chunker": {
|
naval_chat_bot.add("youtube_video", "https://www.youtube.com/watch?v=3qHkcs3kG44", AddConfig(chunker=chunker_config))
|
||||||
"chunk_size": 1000,
|
|
||||||
"chunk_overlap": 100,
|
|
||||||
"length_function": len,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
naval_chat_bot.add("youtube_video", "https://www.youtube.com/watch?v=3qHkcs3kG44", AddConfig(**youtube_add_config))
|
|
||||||
|
|
||||||
add_config = AddConfig()
|
add_config = AddConfig()
|
||||||
naval_chat_bot.add("pdf_file", "https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Almanack-of-Naval-Ravikant_Final.pdf", add_config)
|
naval_chat_bot.add("pdf_file", "https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Almanack-of-Naval-Ravikant_Final.pdf", add_config)
|
||||||
|
|||||||
@@ -5,18 +5,16 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|||||||
from embedchain.chunkers.base_chunker import BaseChunker
|
from embedchain.chunkers.base_chunker import BaseChunker
|
||||||
from embedchain.config.AddConfig import ChunkerConfig
|
from embedchain.config.AddConfig import ChunkerConfig
|
||||||
|
|
||||||
TEXT_SPLITTER_CHUNK_PARAMS = {
|
|
||||||
"chunk_size": 1000,
|
|
||||||
"chunk_overlap": 0,
|
|
||||||
"length_function": len,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class DocxFileChunker(BaseChunker):
|
class DocxFileChunker(BaseChunker):
|
||||||
"""Chunker for .docx file."""
|
"""Chunker for .docx file."""
|
||||||
|
|
||||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||||
if config is None:
|
if config is None:
|
||||||
config = TEXT_SPLITTER_CHUNK_PARAMS
|
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
|
||||||
text_splitter = RecursiveCharacterTextSplitter(**config)
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=config.chunk_size,
|
||||||
|
chunk_overlap=config.chunk_overlap,
|
||||||
|
length_function=config.length_function,
|
||||||
|
)
|
||||||
super().__init__(text_splitter)
|
super().__init__(text_splitter)
|
||||||
|
|||||||
@@ -5,18 +5,16 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|||||||
from embedchain.chunkers.base_chunker import BaseChunker
|
from embedchain.chunkers.base_chunker import BaseChunker
|
||||||
from embedchain.config.AddConfig import ChunkerConfig
|
from embedchain.config.AddConfig import ChunkerConfig
|
||||||
|
|
||||||
TEXT_SPLITTER_CHUNK_PARAMS = {
|
|
||||||
"chunk_size": 1000,
|
|
||||||
"chunk_overlap": 0,
|
|
||||||
"length_function": len,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class PdfFileChunker(BaseChunker):
|
class PdfFileChunker(BaseChunker):
|
||||||
"""Chunker for PDF file."""
|
"""Chunker for PDF file."""
|
||||||
|
|
||||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||||
if config is None:
|
if config is None:
|
||||||
config = TEXT_SPLITTER_CHUNK_PARAMS
|
config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
|
||||||
text_splitter = RecursiveCharacterTextSplitter(**config)
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=config.chunk_size,
|
||||||
|
chunk_overlap=config.chunk_overlap,
|
||||||
|
length_function=config.length_function,
|
||||||
|
)
|
||||||
super().__init__(text_splitter)
|
super().__init__(text_splitter)
|
||||||
|
|||||||
@@ -5,18 +5,16 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|||||||
from embedchain.chunkers.base_chunker import BaseChunker
|
from embedchain.chunkers.base_chunker import BaseChunker
|
||||||
from embedchain.config.AddConfig import ChunkerConfig
|
from embedchain.config.AddConfig import ChunkerConfig
|
||||||
|
|
||||||
TEXT_SPLITTER_CHUNK_PARAMS = {
|
|
||||||
"chunk_size": 300,
|
|
||||||
"chunk_overlap": 0,
|
|
||||||
"length_function": len,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class QnaPairChunker(BaseChunker):
|
class QnaPairChunker(BaseChunker):
|
||||||
"""Chunker for QnA pair."""
|
"""Chunker for QnA pair."""
|
||||||
|
|
||||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||||
if config is None:
|
if config is None:
|
||||||
config = TEXT_SPLITTER_CHUNK_PARAMS
|
config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len)
|
||||||
text_splitter = RecursiveCharacterTextSplitter(**config)
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=config.chunk_size,
|
||||||
|
chunk_overlap=config.chunk_overlap,
|
||||||
|
length_function=config.length_function,
|
||||||
|
)
|
||||||
super().__init__(text_splitter)
|
super().__init__(text_splitter)
|
||||||
|
|||||||
@@ -5,18 +5,16 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|||||||
from embedchain.chunkers.base_chunker import BaseChunker
|
from embedchain.chunkers.base_chunker import BaseChunker
|
||||||
from embedchain.config.AddConfig import ChunkerConfig
|
from embedchain.config.AddConfig import ChunkerConfig
|
||||||
|
|
||||||
TEXT_SPLITTER_CHUNK_PARAMS = {
|
|
||||||
"chunk_size": 300,
|
|
||||||
"chunk_overlap": 0,
|
|
||||||
"length_function": len,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class TextChunker(BaseChunker):
|
class TextChunker(BaseChunker):
|
||||||
"""Chunker for text."""
|
"""Chunker for text."""
|
||||||
|
|
||||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||||
if config is None:
|
if config is None:
|
||||||
config = TEXT_SPLITTER_CHUNK_PARAMS
|
config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len)
|
||||||
text_splitter = RecursiveCharacterTextSplitter(**config)
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=config.chunk_size,
|
||||||
|
chunk_overlap=config.chunk_overlap,
|
||||||
|
length_function=config.length_function,
|
||||||
|
)
|
||||||
super().__init__(text_splitter)
|
super().__init__(text_splitter)
|
||||||
|
|||||||
@@ -5,18 +5,16 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|||||||
from embedchain.chunkers.base_chunker import BaseChunker
|
from embedchain.chunkers.base_chunker import BaseChunker
|
||||||
from embedchain.config.AddConfig import ChunkerConfig
|
from embedchain.config.AddConfig import ChunkerConfig
|
||||||
|
|
||||||
TEXT_SPLITTER_CHUNK_PARAMS = {
|
|
||||||
"chunk_size": 500,
|
|
||||||
"chunk_overlap": 0,
|
|
||||||
"length_function": len,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class WebPageChunker(BaseChunker):
|
class WebPageChunker(BaseChunker):
|
||||||
"""Chunker for web page."""
|
"""Chunker for web page."""
|
||||||
|
|
||||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||||
if config is None:
|
if config is None:
|
||||||
config = TEXT_SPLITTER_CHUNK_PARAMS
|
config = ChunkerConfig(chunk_size=500, chunk_overlap=0, length_function=len)
|
||||||
text_splitter = RecursiveCharacterTextSplitter(**config)
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=config.chunk_size,
|
||||||
|
chunk_overlap=config.chunk_overlap,
|
||||||
|
length_function=config.length_function,
|
||||||
|
)
|
||||||
super().__init__(text_splitter)
|
super().__init__(text_splitter)
|
||||||
|
|||||||
@@ -5,18 +5,16 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|||||||
from embedchain.chunkers.base_chunker import BaseChunker
|
from embedchain.chunkers.base_chunker import BaseChunker
|
||||||
from embedchain.config.AddConfig import ChunkerConfig
|
from embedchain.config.AddConfig import ChunkerConfig
|
||||||
|
|
||||||
TEXT_SPLITTER_CHUNK_PARAMS = {
|
|
||||||
"chunk_size": 2000,
|
|
||||||
"chunk_overlap": 0,
|
|
||||||
"length_function": len,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class YoutubeVideoChunker(BaseChunker):
|
class YoutubeVideoChunker(BaseChunker):
|
||||||
"""Chunker for Youtube video."""
|
"""Chunker for Youtube video."""
|
||||||
|
|
||||||
def __init__(self, config: Optional[ChunkerConfig] = None):
|
def __init__(self, config: Optional[ChunkerConfig] = None):
|
||||||
if config is None:
|
if config is None:
|
||||||
config = TEXT_SPLITTER_CHUNK_PARAMS
|
config = ChunkerConfig(chunk_size=2000, chunk_overlap=0, length_function=len)
|
||||||
text_splitter = RecursiveCharacterTextSplitter(**config)
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=config.chunk_size,
|
||||||
|
chunk_overlap=config.chunk_overlap,
|
||||||
|
length_function=config.length_function,
|
||||||
|
)
|
||||||
super().__init__(text_splitter)
|
super().__init__(text_splitter)
|
||||||
|
|||||||
@@ -10,13 +10,13 @@ class ChunkerConfig(BaseConfig):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
chunk_size: Optional[int] = 4000,
|
chunk_size: Optional[int] = None,
|
||||||
chunk_overlap: Optional[int] = 200,
|
chunk_overlap: Optional[int] = None,
|
||||||
length_function: Optional[Callable[[str], int]] = len,
|
length_function: Optional[Callable[[str], int]] = None,
|
||||||
):
|
):
|
||||||
self.chunk_size = chunk_size
|
self.chunk_size = chunk_size if chunk_size else 2000
|
||||||
self.chunk_overlap = chunk_overlap
|
self.chunk_overlap = chunk_overlap if chunk_overlap else 0
|
||||||
self.length_function = length_function
|
self.length_function = length_function if length_function else len
|
||||||
|
|
||||||
|
|
||||||
class LoaderConfig(BaseConfig):
|
class LoaderConfig(BaseConfig):
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from .AddConfig import AddConfig # noqa: F401
|
from .AddConfig import AddConfig, ChunkerConfig # noqa: F401
|
||||||
from .BaseConfig import BaseConfig # noqa: F401
|
from .BaseConfig import BaseConfig # noqa: F401
|
||||||
from .ChatConfig import ChatConfig # noqa: F401
|
from .ChatConfig import ChatConfig # noqa: F401
|
||||||
from .InitConfig import InitConfig # noqa: F401
|
from .InitConfig import InitConfig # noqa: F401
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from embedchain.chunkers.text import TextChunker
|
from embedchain.chunkers.text import TextChunker
|
||||||
|
from embedchain.config import ChunkerConfig
|
||||||
|
|
||||||
|
|
||||||
class TestTextChunker(unittest.TestCase):
|
class TestTextChunker(unittest.TestCase):
|
||||||
@@ -11,11 +12,7 @@ class TestTextChunker(unittest.TestCase):
|
|||||||
Test the chunks generated by TextChunker.
|
Test the chunks generated by TextChunker.
|
||||||
# TODO: Not a very precise test.
|
# TODO: Not a very precise test.
|
||||||
"""
|
"""
|
||||||
chunker_config = {
|
chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len)
|
||||||
"chunk_size": 10,
|
|
||||||
"chunk_overlap": 0,
|
|
||||||
"length_function": len,
|
|
||||||
}
|
|
||||||
chunker = TextChunker(config=chunker_config)
|
chunker = TextChunker(config=chunker_config)
|
||||||
text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
|
text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user