feat: AddConfig should allow configuring Chunker (#200)

This commit is contained in:
Anupam Singh
2023-07-11 04:23:56 +05:30
committed by GitHub
parent ae87dc4a6d
commit eda28cc491
10 changed files with 120 additions and 39 deletions

View File

@@ -1,8 +1,11 @@
from typing import Optional
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.AddConfig import ChunkerConfig
from langchain.text_splitter import RecursiveCharacterTextSplitter
TEXT_SPLITTER_CHUNK_PARAMS = {
"chunk_size": 1000,
"chunk_overlap": 0,
@@ -11,6 +14,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
class DocxFileChunker(BaseChunker):
    """Chunker that hands a ``RecursiveCharacterTextSplitter`` to ``BaseChunker``."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        """
        Build the text splitter from ``config`` and initialise the base chunker.

        :param config: Chunking options. ``None`` falls back to the
            module-level ``TEXT_SPLITTER_CHUNK_PARAMS``. A mapping of splitter
            keyword arguments is accepted as well.
        """
        if config is None:
            config = TEXT_SPLITTER_CHUNK_PARAMS
        if hasattr(config, "keys"):
            # Mapping-style config (such as the default params dict) can be
            # unpacked into keyword arguments as it is.
            splitter_kwargs = dict(config)
        else:
            # A config object without the mapping protocol cannot be expanded
            # with `**` (that raises TypeError), so read its fields explicitly.
            splitter_kwargs = {
                "chunk_size": config.chunk_size,
                "chunk_overlap": config.chunk_overlap,
                "length_function": config.length_function,
            }
        text_splitter = RecursiveCharacterTextSplitter(**splitter_kwargs)
        super().__init__(text_splitter)

View File

@@ -1,4 +1,6 @@
from typing import Optional
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.AddConfig import ChunkerConfig
from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -11,6 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
class PdfFileChunker(BaseChunker):
    """Chunker that hands a ``RecursiveCharacterTextSplitter`` to ``BaseChunker``."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        """
        Build the text splitter from ``config`` and initialise the base chunker.

        :param config: Chunking options. ``None`` falls back to the
            module-level ``TEXT_SPLITTER_CHUNK_PARAMS``. A mapping of splitter
            keyword arguments is accepted as well.
        """
        if config is None:
            config = TEXT_SPLITTER_CHUNK_PARAMS
        if hasattr(config, "keys"):
            # Mapping-style config (such as the default params dict) can be
            # unpacked into keyword arguments as it is.
            splitter_kwargs = dict(config)
        else:
            # A config object without the mapping protocol cannot be expanded
            # with `**` (that raises TypeError), so read its fields explicitly.
            splitter_kwargs = {
                "chunk_size": config.chunk_size,
                "chunk_overlap": config.chunk_overlap,
                "length_function": config.length_function,
            }
        text_splitter = RecursiveCharacterTextSplitter(**splitter_kwargs)
        super().__init__(text_splitter)

View File

@@ -1,4 +1,6 @@
from typing import Optional
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.AddConfig import ChunkerConfig
from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -11,6 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
class QnaPairChunker(BaseChunker):
    """Chunker that hands a ``RecursiveCharacterTextSplitter`` to ``BaseChunker``."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        """
        Build the text splitter from ``config`` and initialise the base chunker.

        :param config: Chunking options. ``None`` falls back to the
            module-level ``TEXT_SPLITTER_CHUNK_PARAMS``. A mapping of splitter
            keyword arguments is accepted as well.
        """
        if config is None:
            config = TEXT_SPLITTER_CHUNK_PARAMS
        if hasattr(config, "keys"):
            # Mapping-style config (such as the default params dict) can be
            # unpacked into keyword arguments as it is.
            splitter_kwargs = dict(config)
        else:
            # A config object without the mapping protocol cannot be expanded
            # with `**` (that raises TypeError), so read its fields explicitly.
            splitter_kwargs = {
                "chunk_size": config.chunk_size,
                "chunk_overlap": config.chunk_overlap,
                "length_function": config.length_function,
            }
        text_splitter = RecursiveCharacterTextSplitter(**splitter_kwargs)
        super().__init__(text_splitter)

View File

@@ -1,4 +1,6 @@
from typing import Optional
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.AddConfig import ChunkerConfig
from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -11,6 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
class TextChunker(BaseChunker):
    """Chunker that hands a ``RecursiveCharacterTextSplitter`` to ``BaseChunker``."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        """
        Build the text splitter from ``config`` and initialise the base chunker.

        :param config: Chunking options. ``None`` falls back to the
            module-level ``TEXT_SPLITTER_CHUNK_PARAMS``. A mapping of splitter
            keyword arguments is accepted as well.
        """
        if config is None:
            config = TEXT_SPLITTER_CHUNK_PARAMS
        if hasattr(config, "keys"):
            # Mapping-style config (such as the default params dict) can be
            # unpacked into keyword arguments as it is.
            splitter_kwargs = dict(config)
        else:
            # A config object without the mapping protocol cannot be expanded
            # with `**` (that raises TypeError), so read its fields explicitly.
            splitter_kwargs = {
                "chunk_size": config.chunk_size,
                "chunk_overlap": config.chunk_overlap,
                "length_function": config.length_function,
            }
        text_splitter = RecursiveCharacterTextSplitter(**splitter_kwargs)
        super().__init__(text_splitter)

View File

@@ -1,4 +1,6 @@
from typing import Optional
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.AddConfig import ChunkerConfig
from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -11,6 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
class WebPageChunker(BaseChunker):
    """Chunker that hands a ``RecursiveCharacterTextSplitter`` to ``BaseChunker``."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        """
        Build the text splitter from ``config`` and initialise the base chunker.

        :param config: Chunking options. ``None`` falls back to the
            module-level ``TEXT_SPLITTER_CHUNK_PARAMS``. A mapping of splitter
            keyword arguments is accepted as well.
        """
        if config is None:
            config = TEXT_SPLITTER_CHUNK_PARAMS
        if hasattr(config, "keys"):
            # Mapping-style config (such as the default params dict) can be
            # unpacked into keyword arguments as it is.
            splitter_kwargs = dict(config)
        else:
            # A config object without the mapping protocol cannot be expanded
            # with `**` (that raises TypeError), so read its fields explicitly.
            splitter_kwargs = {
                "chunk_size": config.chunk_size,
                "chunk_overlap": config.chunk_overlap,
                "length_function": config.length_function,
            }
        text_splitter = RecursiveCharacterTextSplitter(**splitter_kwargs)
        super().__init__(text_splitter)

View File

@@ -1,4 +1,6 @@
from typing import Optional
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.AddConfig import ChunkerConfig
from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -11,6 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
class YoutubeVideoChunker(BaseChunker):
    """Chunker that hands a ``RecursiveCharacterTextSplitter`` to ``BaseChunker``."""

    def __init__(self, config: Optional[ChunkerConfig] = None):
        """
        Build the text splitter from ``config`` and initialise the base chunker.

        :param config: Chunking options. ``None`` falls back to the
            module-level ``TEXT_SPLITTER_CHUNK_PARAMS``. A mapping of splitter
            keyword arguments is accepted as well.
        """
        if config is None:
            config = TEXT_SPLITTER_CHUNK_PARAMS
        if hasattr(config, "keys"):
            # Mapping-style config (such as the default params dict) can be
            # unpacked into keyword arguments as it is.
            splitter_kwargs = dict(config)
        else:
            # A config object without the mapping protocol cannot be expanded
            # with `**` (that raises TypeError), so read its fields explicitly.
            splitter_kwargs = {
                "chunk_size": config.chunk_size,
                "chunk_overlap": config.chunk_overlap,
                "length_function": config.length_function,
            }
        text_splitter = RecursiveCharacterTextSplitter(**splitter_kwargs)
        super().__init__(text_splitter)

View File

@@ -1,8 +1,32 @@
from typing import Callable, Optional
from embedchain.config.BaseConfig import BaseConfig
class ChunkerConfig(BaseConfig):
    """
    Chunker settings accepted by the `add` method.

    Each constructor argument is stored unchanged on the instance under the
    same name.

    :param chunk_size: stored as ``chunk_size`` (default 4000).
    :param chunk_overlap: stored as ``chunk_overlap`` (default 200).
    :param length_function: callable taking a string and returning an int,
        stored as ``length_function`` (default ``len``).
    """

    def __init__(
        self,
        chunk_size: Optional[int] = 4000,
        chunk_overlap: Optional[int] = 200,
        length_function: Optional[Callable[[str], int]] = len,
    ):
        (
            self.chunk_size,
            self.chunk_overlap,
            self.length_function,
        ) = (chunk_size, chunk_overlap, length_function)
class LoaderConfig(BaseConfig):
    """
    Config for the loader used in `add` method.

    It takes no arguments and sets no attributes.
    """

    def __init__(self):
        # Nothing to initialise: no loader options are defined here.
        return None
class AddConfig(BaseConfig):
    """
    Options bundle for the `add` method.

    :param chunker: optional ``ChunkerConfig``, stored as ``self.chunker``.
    :param loader: optional ``LoaderConfig``, stored as ``self.loader``.
    """

    def __init__(
        self,
        chunker: Optional[ChunkerConfig] = None,
        loader: Optional[LoaderConfig] = None,
    ):
        self.loader, self.chunker = loader, chunker

View File

@@ -1,3 +1,4 @@
from embedchain.config import AddConfig
from embedchain.loaders.youtube_video import YoutubeVideoLoader
from embedchain.loaders.pdf_file import PdfFileLoader
from embedchain.loaders.web_page import WebPageLoader
@@ -18,11 +19,11 @@ class DataFormatter:
loaders and chunkers to the data_type entered by the user in their
.add or .add_local method call
"""
def __init__(self, data_type):
self.loader = self._get_loader(data_type)
self.chunker = self._get_chunker(data_type)
def _get_loader(self, data_type):
def __init__(self, data_type: str, config: AddConfig):
    """
    Resolve and store the loader and chunker for ``data_type``.

    :param data_type: data type entered by the user, used to pick the
        loader and the chunker.
    :param config: object whose ``loader`` and ``chunker`` attributes are
        forwarded to ``_get_loader`` and ``_get_chunker`` respectively.
    """
    loader_config = config.loader
    self.loader = self._get_loader(data_type, loader_config)
    chunker_config = config.chunker
    self.chunker = self._get_chunker(data_type, chunker_config)
def _get_loader(self, data_type, config):
"""
Returns the appropriate data loader for the given data type.
@@ -43,7 +44,7 @@ class DataFormatter:
else:
raise ValueError(f"Unsupported data type: {data_type}")
def _get_chunker(self, data_type, config):
    """
    Returns the appropriate chunker for the given data type.

    :param data_type: The data type key used to select the chunker.
    :param config: Chunker config passed to the selected chunker's constructor.
    :return: A new chunker instance for ``data_type``.
    :raises ValueError: If an unsupported data type is provided.
    """
    # Map to classes rather than instances so that only the requested
    # chunker (and its text splitter) is constructed on each call.
    chunker_classes = {
        'youtube_video': YoutubeVideoChunker,
        'pdf_file': PdfFileChunker,
        'web_page': WebPageChunker,
        'qna_pair': QnaPairChunker,
        'text': TextChunker,
        'docx': DocxFileChunker,
    }
    if data_type not in chunker_classes:
        raise ValueError(f"Unsupported data type: {data_type}")
    return chunker_classes[data_type](config)

View File

@@ -37,9 +37,6 @@ class EmbedChain:
self.collection = self.config.db.collection
self.user_asks = []
def add(self, data_type, url, config: AddConfig = None):
    """
    Adds the data from the given URL to the vector db.

    :param data_type: The data type passed to ``DataFormatter`` to select
        the loader and chunker.
    :param url: The URL the data is loaded from.
    :param config: Optional ``AddConfig``; a default ``AddConfig()`` is
        created when omitted.
    """
    add_config = AddConfig() if config is None else config
    formatter = DataFormatter(data_type, add_config)
    # Keep a record of what was added, as a [data_type, url] pair.
    self.user_asks.append([data_type, url])
    self.load_and_embed(formatter.loader, formatter.chunker, url)
@@ -69,8 +66,8 @@ class EmbedChain:
"""
if config is None:
config = AddConfig()
data_formatter = DataFormatter(data_type)
data_formatter = DataFormatter(data_type, config)
self.user_asks.append([data_type, content])
self.load_and_embed(data_formatter.loader, data_formatter.chunker, content)