feat: AddConfig should allow configuring Chunker (#200)
This commit is contained in:
@@ -1,8 +1,11 @@
|
||||
from typing import Optional
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.AddConfig import ChunkerConfig
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
|
||||
|
||||
TEXT_SPLITTER_CHUNK_PARAMS = {
|
||||
"chunk_size": 1000,
|
||||
"chunk_overlap": 0,
|
||||
@@ -11,6 +14,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
|
||||
|
||||
|
||||
class DocxFileChunker(BaseChunker):
    """
    Chunker for the `docx` data type.

    Builds a RecursiveCharacterTextSplitter and hands it to BaseChunker.
    """

    def __init__(self, config: Optional[ChunkerConfig] = None):
        """
        :param config: Optional chunker settings. When omitted, the
        module-level TEXT_SPLITTER_CHUNK_PARAMS defaults are used.
        """
        if config is None:
            splitter_kwargs = TEXT_SPLITTER_CHUNK_PARAMS
        elif isinstance(config, dict):
            # Plain dicts can be unpacked directly.
            splitter_kwargs = config
        else:
            # `**` requires a mapping and ChunkerConfig is a plain object,
            # so its fields are forwarded explicitly.
            splitter_kwargs = {
                "chunk_size": config.chunk_size,
                "chunk_overlap": config.chunk_overlap,
                "length_function": config.length_function,
            }
        text_splitter = RecursiveCharacterTextSplitter(**splitter_kwargs)
        super().__init__(text_splitter)
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from typing import Optional
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.AddConfig import ChunkerConfig
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
@@ -11,6 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
|
||||
|
||||
|
||||
class PdfFileChunker(BaseChunker):
    """
    Chunker for the `pdf_file` data type.

    Builds a RecursiveCharacterTextSplitter and hands it to BaseChunker.
    """

    def __init__(self, config: Optional[ChunkerConfig] = None):
        """
        :param config: Optional chunker settings. When omitted, the
        module-level TEXT_SPLITTER_CHUNK_PARAMS defaults are used.
        """
        if config is None:
            splitter_kwargs = TEXT_SPLITTER_CHUNK_PARAMS
        elif isinstance(config, dict):
            # Plain dicts can be unpacked directly.
            splitter_kwargs = config
        else:
            # `**` requires a mapping and ChunkerConfig is a plain object,
            # so its fields are forwarded explicitly.
            splitter_kwargs = {
                "chunk_size": config.chunk_size,
                "chunk_overlap": config.chunk_overlap,
                "length_function": config.length_function,
            }
        text_splitter = RecursiveCharacterTextSplitter(**splitter_kwargs)
        super().__init__(text_splitter)
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from typing import Optional
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.AddConfig import ChunkerConfig
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
@@ -11,6 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
|
||||
|
||||
|
||||
class QnaPairChunker(BaseChunker):
    """
    Chunker for the `qna_pair` data type.

    Builds a RecursiveCharacterTextSplitter and hands it to BaseChunker.
    """

    def __init__(self, config: Optional[ChunkerConfig] = None):
        """
        :param config: Optional chunker settings. When omitted, the
        module-level TEXT_SPLITTER_CHUNK_PARAMS defaults are used.
        """
        if config is None:
            splitter_kwargs = TEXT_SPLITTER_CHUNK_PARAMS
        elif isinstance(config, dict):
            # Plain dicts can be unpacked directly.
            splitter_kwargs = config
        else:
            # `**` requires a mapping and ChunkerConfig is a plain object,
            # so its fields are forwarded explicitly.
            splitter_kwargs = {
                "chunk_size": config.chunk_size,
                "chunk_overlap": config.chunk_overlap,
                "length_function": config.length_function,
            }
        text_splitter = RecursiveCharacterTextSplitter(**splitter_kwargs)
        super().__init__(text_splitter)
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from typing import Optional
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.AddConfig import ChunkerConfig
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
@@ -11,6 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
|
||||
|
||||
|
||||
class TextChunker(BaseChunker):
    """
    Chunker for the `text` data type.

    Builds a RecursiveCharacterTextSplitter and hands it to BaseChunker.
    """

    def __init__(self, config: Optional[ChunkerConfig] = None):
        """
        :param config: Optional chunker settings. When omitted, the
        module-level TEXT_SPLITTER_CHUNK_PARAMS defaults are used.
        """
        if config is None:
            splitter_kwargs = TEXT_SPLITTER_CHUNK_PARAMS
        elif isinstance(config, dict):
            # Plain dicts can be unpacked directly.
            splitter_kwargs = config
        else:
            # `**` requires a mapping and ChunkerConfig is a plain object,
            # so its fields are forwarded explicitly.
            splitter_kwargs = {
                "chunk_size": config.chunk_size,
                "chunk_overlap": config.chunk_overlap,
                "length_function": config.length_function,
            }
        text_splitter = RecursiveCharacterTextSplitter(**splitter_kwargs)
        super().__init__(text_splitter)
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from typing import Optional
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.AddConfig import ChunkerConfig
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
@@ -11,6 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
|
||||
|
||||
|
||||
class WebPageChunker(BaseChunker):
    """
    Chunker for the `web_page` data type.

    Builds a RecursiveCharacterTextSplitter and hands it to BaseChunker.
    """

    def __init__(self, config: Optional[ChunkerConfig] = None):
        """
        :param config: Optional chunker settings. When omitted, the
        module-level TEXT_SPLITTER_CHUNK_PARAMS defaults are used.
        """
        if config is None:
            splitter_kwargs = TEXT_SPLITTER_CHUNK_PARAMS
        elif isinstance(config, dict):
            # Plain dicts can be unpacked directly.
            splitter_kwargs = config
        else:
            # `**` requires a mapping and ChunkerConfig is a plain object,
            # so its fields are forwarded explicitly.
            splitter_kwargs = {
                "chunk_size": config.chunk_size,
                "chunk_overlap": config.chunk_overlap,
                "length_function": config.length_function,
            }
        text_splitter = RecursiveCharacterTextSplitter(**splitter_kwargs)
        super().__init__(text_splitter)
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from typing import Optional
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.AddConfig import ChunkerConfig
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
@@ -11,6 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
|
||||
|
||||
|
||||
class YoutubeVideoChunker(BaseChunker):
    """
    Chunker for the `youtube_video` data type.

    Builds a RecursiveCharacterTextSplitter and hands it to BaseChunker.
    """

    def __init__(self, config: Optional[ChunkerConfig] = None):
        """
        :param config: Optional chunker settings. When omitted, the
        module-level TEXT_SPLITTER_CHUNK_PARAMS defaults are used.
        """
        if config is None:
            splitter_kwargs = TEXT_SPLITTER_CHUNK_PARAMS
        elif isinstance(config, dict):
            # Plain dicts can be unpacked directly.
            splitter_kwargs = config
        else:
            # `**` requires a mapping and ChunkerConfig is a plain object,
            # so its fields are forwarded explicitly.
            splitter_kwargs = {
                "chunk_size": config.chunk_size,
                "chunk_overlap": config.chunk_overlap,
                "length_function": config.length_function,
            }
        text_splitter = RecursiveCharacterTextSplitter(**splitter_kwargs)
        super().__init__(text_splitter)
|
||||
|
||||
@@ -1,8 +1,32 @@
|
||||
from typing import Callable, Optional
|
||||
from embedchain.config.BaseConfig import BaseConfig
|
||||
|
||||
|
||||
class ChunkerConfig(BaseConfig):
    """
    Settings for the chunker that the `add` method uses.

    :param chunk_size: Value stored as `chunk_size`. Defaults to 4000.
    :param chunk_overlap: Value stored as `chunk_overlap`. Defaults to 200.
    :param length_function: Callable taking a string and returning an int,
    stored as `length_function`. Defaults to the builtin `len`.
    """

    def __init__(
        self,
        chunk_size: Optional[int] = 4000,
        chunk_overlap: Optional[int] = 200,
        length_function: Optional[Callable[[str], int]] = len,
    ):
        # The arguments are kept as given; nothing is validated here.
        self.chunk_size, self.chunk_overlap = chunk_size, chunk_overlap
        self.length_function = length_function
|
||||
|
||||
class LoaderConfig(BaseConfig):
    """
    Config for the loader used in `add` method
    """

    def __init__(self) -> None:
        """Create an empty loader config; it takes no options and sets no attributes."""
|
||||
|
||||
class AddConfig(BaseConfig):
    """
    Config for the `add` method.
    """

    def __init__(self,
                 chunker: Optional[ChunkerConfig] = None,
                 loader: Optional[LoaderConfig] = None):
        """
        :param chunker: Optional chunker settings, stored as `chunker`.
        :param loader: Optional loader settings, stored as `loader`.
        """
        # Both values are stored as given, so each attribute is None when
        # its argument is omitted.
        self.loader = loader
        self.chunker = chunker
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from embedchain.config import AddConfig
|
||||
from embedchain.loaders.youtube_video import YoutubeVideoLoader
|
||||
from embedchain.loaders.pdf_file import PdfFileLoader
|
||||
from embedchain.loaders.web_page import WebPageLoader
|
||||
@@ -18,11 +19,11 @@ class DataFormatter:
|
||||
loaders and chunkers to the data_type entered by the user in their
|
||||
.add or .add_local method call
|
||||
"""
|
||||
def __init__(self, data_type):
|
||||
self.loader = self._get_loader(data_type)
|
||||
self.chunker = self._get_chunker(data_type)
|
||||
|
||||
def _get_loader(self, data_type):
|
||||
def __init__(self, data_type: str, config: AddConfig):
|
||||
self.loader = self._get_loader(data_type, config.loader)
|
||||
self.chunker = self._get_chunker(data_type, config.chunker)
|
||||
|
||||
def _get_loader(self, data_type, config):
|
||||
"""
|
||||
Returns the appropriate data loader for the given data type.
|
||||
|
||||
@@ -43,7 +44,7 @@ class DataFormatter:
|
||||
else:
|
||||
raise ValueError(f"Unsupported data type: {data_type}")
|
||||
|
||||
def _get_chunker(self, data_type):
|
||||
def _get_chunker(self, data_type, config):
|
||||
"""
|
||||
Returns the appropriate chunker for the given data type.
|
||||
|
||||
@@ -52,15 +53,14 @@ class DataFormatter:
|
||||
:raises ValueError: If an unsupported data type is provided.
|
||||
"""
|
||||
chunkers = {
|
||||
'youtube_video': YoutubeVideoChunker(),
|
||||
'pdf_file': PdfFileChunker(),
|
||||
'web_page': WebPageChunker(),
|
||||
'qna_pair': QnaPairChunker(),
|
||||
'text': TextChunker(),
|
||||
'docx': DocxFileChunker(),
|
||||
'youtube_video': YoutubeVideoChunker(config),
|
||||
'pdf_file': PdfFileChunker(config),
|
||||
'web_page': WebPageChunker(config),
|
||||
'qna_pair': QnaPairChunker(config),
|
||||
'text': TextChunker(config),
|
||||
'docx': DocxFileChunker(config),
|
||||
}
|
||||
if data_type in chunkers:
|
||||
return chunkers[data_type]
|
||||
else:
|
||||
raise ValueError(f"Unsupported data type: {data_type}")
|
||||
|
||||
|
||||
@@ -37,9 +37,6 @@ class EmbedChain:
|
||||
self.collection = self.config.db.collection
|
||||
self.user_asks = []
|
||||
|
||||
|
||||
|
||||
|
||||
def add(self, data_type, url, config: AddConfig = None):
|
||||
"""
|
||||
Adds the data from the given URL to the vector db.
|
||||
@@ -52,8 +49,8 @@ class EmbedChain:
|
||||
"""
|
||||
if config is None:
|
||||
config = AddConfig()
|
||||
|
||||
data_formatter = DataFormatter(data_type)
|
||||
|
||||
data_formatter = DataFormatter(data_type, config)
|
||||
self.user_asks.append([data_type, url])
|
||||
self.load_and_embed(data_formatter.loader, data_formatter.chunker, url)
|
||||
|
||||
@@ -69,8 +66,8 @@ class EmbedChain:
|
||||
"""
|
||||
if config is None:
|
||||
config = AddConfig()
|
||||
|
||||
data_formatter = DataFormatter(data_type)
|
||||
|
||||
data_formatter = DataFormatter(data_type, config)
|
||||
self.user_asks.append([data_type, content])
|
||||
self.load_and_embed(data_formatter.loader, data_formatter.chunker, content)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user