featL AddConfig should allow configuring Chunker (#200)

This commit is contained in:
Anupam Singh
2023-07-11 04:23:56 +05:30
committed by GitHub
parent ae87dc4a6d
commit eda28cc491
10 changed files with 120 additions and 39 deletions

View File

@@ -377,8 +377,17 @@ config = InitConfig(ef=embedding_functions.OpenAIEmbeddingFunction(
))
naval_chat_bot = App(config)
add_config = AddConfig() # Currently no options
naval_chat_bot.add("youtube_video", "https://www.youtube.com/watch?v=3qHkcs3kG44", add_config)
# Example: define your own chunker config for `youtube_video`
youtube_add_config = {
"chunker": {
"chunk_size": 1000,
"chunk_overlap": 100,
"length_function": len,
}
}
naval_chat_bot.add("youtube_video", "https://www.youtube.com/watch?v=3qHkcs3kG44", AddConfig(**youtube_add_config))
add_config = AddConfig()
naval_chat_bot.add("pdf_file", "https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Almanack-of-Naval-Ravikant_Final.pdf", add_config)
naval_chat_bot.add("web_page", "https://nav.al/feedback", add_config)
naval_chat_bot.add("web_page", "https://nav.al/agi", add_config)
@@ -450,13 +459,39 @@ This section describes all possible config options.
#### **Add Config**
|option|description|type|default|
|---|---|---|---|
|chunker|chunker config|ChunkerConfig|Default values for chunker depends on the `data_type`. Please refer [ChunkerConfig](#chunker-config)|
|loader|loader config|LoaderConfig|None|
##### **Chunker Config**
|option|description|type|default|
|---|---|---|---|
|chunk_size|Maximum size of chunks to return|int|Default value for various `data_type` mentioned below|
|chunk_overlap|Overlap in characters between chunks|int|Default value for various `data_type` mentioned below|
|length_function|Function that measures the length of given chunks|typing.Callable|Default value for various `data_type` mentioned below|
Default values of chunker config parameters for different `data_type`:
|data_type|chunk_size|chunk_overlap|length_function|
|---|---|---|---|
|docx|1000|0|len|
|text|300|0|len|
|qna_pair|300|0|len|
|web_page|500|0|len|
|pdf_file|1000|0|len|
|youtube_video|2000|0|len|
##### **Loader Config**
_coming soon_
#### **Query Config**
|option|description|type|default|
|---|---|---|---|
|template|custom template for prompt|Template|Template("Use the following pieces of context to answer the query at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. \$context Query: $query Helpful Answer:")|
|template|custom template for prompt|Template|Template("Use the following pieces of context to answer the query at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. \$context Query: \$query Helpful Answer:")|
|history|include conversation history from your client or database|any (recommendation: list[str])|None
|stream|control if response is streamed back to the user|bool|False|

View File

@@ -1,8 +1,11 @@
from typing import Optional
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.AddConfig import ChunkerConfig
from langchain.text_splitter import RecursiveCharacterTextSplitter
TEXT_SPLITTER_CHUNK_PARAMS = {
"chunk_size": 1000,
"chunk_overlap": 0,
@@ -11,6 +14,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
class DocxFileChunker(BaseChunker):
def __init__(self):
text_splitter = RecursiveCharacterTextSplitter(**TEXT_SPLITTER_CHUNK_PARAMS)
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = TEXT_SPLITTER_CHUNK_PARAMS
text_splitter = RecursiveCharacterTextSplitter(**config)
super().__init__(text_splitter)

View File

@@ -1,4 +1,6 @@
from typing import Optional
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.AddConfig import ChunkerConfig
from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -11,6 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
class PdfFileChunker(BaseChunker):
def __init__(self):
text_splitter = RecursiveCharacterTextSplitter(**TEXT_SPLITTER_CHUNK_PARAMS)
super().__init__(text_splitter)
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = TEXT_SPLITTER_CHUNK_PARAMS
text_splitter = RecursiveCharacterTextSplitter(**config)
super().__init__(text_splitter)

View File

@@ -1,4 +1,6 @@
from typing import Optional
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.AddConfig import ChunkerConfig
from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -11,6 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
class QnaPairChunker(BaseChunker):
def __init__(self):
text_splitter = RecursiveCharacterTextSplitter(**TEXT_SPLITTER_CHUNK_PARAMS)
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = TEXT_SPLITTER_CHUNK_PARAMS
text_splitter = RecursiveCharacterTextSplitter(**config)
super().__init__(text_splitter)

View File

@@ -1,4 +1,6 @@
from typing import Optional
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.AddConfig import ChunkerConfig
from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -11,6 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
class TextChunker(BaseChunker):
def __init__(self):
text_splitter = RecursiveCharacterTextSplitter(**TEXT_SPLITTER_CHUNK_PARAMS)
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = TEXT_SPLITTER_CHUNK_PARAMS
text_splitter = RecursiveCharacterTextSplitter(**config)
super().__init__(text_splitter)

View File

@@ -1,4 +1,6 @@
from typing import Optional
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.AddConfig import ChunkerConfig
from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -11,6 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
class WebPageChunker(BaseChunker):
def __init__(self):
text_splitter = RecursiveCharacterTextSplitter(**TEXT_SPLITTER_CHUNK_PARAMS)
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = TEXT_SPLITTER_CHUNK_PARAMS
text_splitter = RecursiveCharacterTextSplitter(**config)
super().__init__(text_splitter)

View File

@@ -1,4 +1,6 @@
from typing import Optional
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.AddConfig import ChunkerConfig
from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -11,6 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
class YoutubeVideoChunker(BaseChunker):
def __init__(self):
text_splitter = RecursiveCharacterTextSplitter(**TEXT_SPLITTER_CHUNK_PARAMS)
super().__init__(text_splitter)
def __init__(self, config: Optional[ChunkerConfig] = None):
if config is None:
config = TEXT_SPLITTER_CHUNK_PARAMS
text_splitter = RecursiveCharacterTextSplitter(**config)
super().__init__(text_splitter)

View File

@@ -1,8 +1,32 @@
from typing import Callable, Optional
from embedchain.config.BaseConfig import BaseConfig
class ChunkerConfig(BaseConfig):
"""
Config for the chunker used in `add` method
"""
def __init__(self,
chunk_size: Optional[int] = 4000,
chunk_overlap: Optional[int] = 200,
length_function: Optional[Callable[[str], int]] = len):
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.length_function = length_function
class LoaderConfig(BaseConfig):
"""
Config for the chunker used in `add` method
"""
def __init__(self):
pass
class AddConfig(BaseConfig):
"""
Config for the `add` method.
"""
def __init__(self):
pass
def __init__(self,
chunker: Optional[ChunkerConfig] = None,
loader: Optional[LoaderConfig] = None):
self.loader = loader
self.chunker = chunker

View File

@@ -1,3 +1,4 @@
from embedchain.config import AddConfig
from embedchain.loaders.youtube_video import YoutubeVideoLoader
from embedchain.loaders.pdf_file import PdfFileLoader
from embedchain.loaders.web_page import WebPageLoader
@@ -18,11 +19,11 @@ class DataFormatter:
loaders and chunkers to the data_type entered by the user in their
.add or .add_local method call
"""
def __init__(self, data_type):
self.loader = self._get_loader(data_type)
self.chunker = self._get_chunker(data_type)
def _get_loader(self, data_type):
def __init__(self, data_type: str, config: AddConfig):
self.loader = self._get_loader(data_type, config.loader)
self.chunker = self._get_chunker(data_type, config.chunker)
def _get_loader(self, data_type, config):
"""
Returns the appropriate data loader for the given data type.
@@ -43,7 +44,7 @@ class DataFormatter:
else:
raise ValueError(f"Unsupported data type: {data_type}")
def _get_chunker(self, data_type):
def _get_chunker(self, data_type, config):
"""
Returns the appropriate chunker for the given data type.
@@ -52,15 +53,14 @@ class DataFormatter:
:raises ValueError: If an unsupported data type is provided.
"""
chunkers = {
'youtube_video': YoutubeVideoChunker(),
'pdf_file': PdfFileChunker(),
'web_page': WebPageChunker(),
'qna_pair': QnaPairChunker(),
'text': TextChunker(),
'docx': DocxFileChunker(),
'youtube_video': YoutubeVideoChunker(config),
'pdf_file': PdfFileChunker(config),
'web_page': WebPageChunker(config),
'qna_pair': QnaPairChunker(config),
'text': TextChunker(config),
'docx': DocxFileChunker(config),
}
if data_type in chunkers:
return chunkers[data_type]
else:
raise ValueError(f"Unsupported data type: {data_type}")

View File

@@ -37,9 +37,6 @@ class EmbedChain:
self.collection = self.config.db.collection
self.user_asks = []
def add(self, data_type, url, config: AddConfig = None):
"""
Adds the data from the given URL to the vector db.
@@ -52,8 +49,8 @@ class EmbedChain:
"""
if config is None:
config = AddConfig()
data_formatter = DataFormatter(data_type)
data_formatter = DataFormatter(data_type, config)
self.user_asks.append([data_type, url])
self.load_and_embed(data_formatter.loader, data_formatter.chunker, url)
@@ -69,8 +66,8 @@ class EmbedChain:
"""
if config is None:
config = AddConfig()
data_formatter = DataFormatter(data_type)
data_formatter = DataFormatter(data_type, config)
self.user_asks.append([data_type, content])
self.load_and_embed(data_formatter.loader, data_formatter.chunker, content)