From eda28cc49161864f49296f3ee583153e8acf7b41 Mon Sep 17 00:00:00 2001 From: Anupam Singh Date: Tue, 11 Jul 2023 04:23:56 +0530 Subject: [PATCH] featL AddConfig should allow configuring Chunker (#200) --- README.md | 41 +++++++++++++++++++-- embedchain/chunkers/docx_file.py | 9 ++++- embedchain/chunkers/pdf_file.py | 10 +++-- embedchain/chunkers/qna_pair.py | 8 +++- embedchain/chunkers/text.py | 8 +++- embedchain/chunkers/web_page.py | 8 +++- embedchain/chunkers/youtube_video.py | 10 +++-- embedchain/config/AddConfig.py | 28 +++++++++++++- embedchain/data_formatter/data_formatter.py | 26 ++++++------- embedchain/embedchain.py | 11 ++---- 10 files changed, 120 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index 1f63ffc4..165a0c6b 100644 --- a/README.md +++ b/README.md @@ -377,8 +377,17 @@ config = InitConfig(ef=embedding_functions.OpenAIEmbeddingFunction( )) naval_chat_bot = App(config) -add_config = AddConfig() # Currently no options -naval_chat_bot.add("youtube_video", "https://www.youtube.com/watch?v=3qHkcs3kG44", add_config) +# Example: define your own chunker config for `youtube_video` +youtube_add_config = { + "chunker": { + "chunk_size": 1000, + "chunk_overlap": 100, + "length_function": len, + } +} +naval_chat_bot.add("youtube_video", "https://www.youtube.com/watch?v=3qHkcs3kG44", AddConfig(**youtube_add_config)) + +add_config = AddConfig() naval_chat_bot.add("pdf_file", "https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Almanack-of-Naval-Ravikant_Final.pdf", add_config) naval_chat_bot.add("web_page", "https://nav.al/feedback", add_config) naval_chat_bot.add("web_page", "https://nav.al/agi", add_config) @@ -450,13 +459,39 @@ This section describes all possible config options. #### **Add Config** +|option|description|type|default| +|---|---|---|---| +|chunker|chunker config|ChunkerConfig|Default values for chunker depends on the `data_type`. Please refer [ChunkerConfig](#chunker-config)| +|loader|loader config|LoaderConfig|None| + +##### **Chunker Config** + +|option|description|type|default| +|---|---|---|---| +|chunk_size|Maximum size of chunks to return|int|Default value for various `data_type` mentioned below| +|chunk_overlap|Overlap in characters between chunks|int|Default value for various `data_type` mentioned below| +|length_function|Function that measures the length of given chunks|typing.Callable|Default value for various `data_type` mentioned below| + +Default values of chunker config parameters for different `data_type`: + +|data_type|chunk_size|chunk_overlap|length_function| +|---|---|---|---| +|docx|1000|0|len| +|text|300|0|len| +|qna_pair|300|0|len| +|web_page|500|0|len| +|pdf_file|1000|0|len| +|youtube_video|2000|0|len| + +##### **Loader Config** + _coming soon_ #### **Query Config** |option|description|type|default| |---|---|---|---| -|template|custom template for prompt|Template|Template("Use the following pieces of context to answer the query at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. \$context Query: $query Helpful Answer:")| +|template|custom template for prompt|Template|Template("Use the following pieces of context to answer the query at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. \$context Query: \$query Helpful Answer:")| |history|include conversation history from your client or database|any (recommendation: list[str])|None |stream|control if response is streamed back to the user|bool|False| diff --git a/embedchain/chunkers/docx_file.py b/embedchain/chunkers/docx_file.py index 03db7bea..55e186dd 100644 --- a/embedchain/chunkers/docx_file.py +++ b/embedchain/chunkers/docx_file.py @@ -1,8 +1,11 @@ +from typing import Optional from embedchain.chunkers.base_chunker import BaseChunker +from embedchain.config.AddConfig import ChunkerConfig from langchain.text_splitter import RecursiveCharacterTextSplitter + TEXT_SPLITTER_CHUNK_PARAMS = { "chunk_size": 1000, "chunk_overlap": 0, @@ -11,6 +14,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = { class DocxFileChunker(BaseChunker): - def __init__(self): - text_splitter = RecursiveCharacterTextSplitter(**TEXT_SPLITTER_CHUNK_PARAMS) + def __init__(self, config: Optional[ChunkerConfig] = None): + if config is None: + config = TEXT_SPLITTER_CHUNK_PARAMS + text_splitter = RecursiveCharacterTextSplitter(**config) super().__init__(text_splitter) diff --git a/embedchain/chunkers/pdf_file.py b/embedchain/chunkers/pdf_file.py index 47a23c7a..a6e1afcb 100644 --- a/embedchain/chunkers/pdf_file.py +++ b/embedchain/chunkers/pdf_file.py @@ -1,4 +1,6 @@ +from typing import Optional from embedchain.chunkers.base_chunker import BaseChunker +from embedchain.config.AddConfig import ChunkerConfig from langchain.text_splitter import RecursiveCharacterTextSplitter @@ -11,6 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = { class PdfFileChunker(BaseChunker): - def __init__(self): - text_splitter = RecursiveCharacterTextSplitter(**TEXT_SPLITTER_CHUNK_PARAMS) - super().__init__(text_splitter) \ No newline at end of file + def __init__(self, config: Optional[ChunkerConfig] = None): + if config is None: + config = TEXT_SPLITTER_CHUNK_PARAMS + text_splitter = RecursiveCharacterTextSplitter(**config) + super().__init__(text_splitter) diff --git a/embedchain/chunkers/qna_pair.py b/embedchain/chunkers/qna_pair.py index 7fe9e57b..f3352ea5 100644 --- a/embedchain/chunkers/qna_pair.py +++ b/embedchain/chunkers/qna_pair.py @@ -1,4 +1,6 @@ +from typing import Optional from embedchain.chunkers.base_chunker import BaseChunker +from embedchain.config.AddConfig import ChunkerConfig from langchain.text_splitter import RecursiveCharacterTextSplitter @@ -11,6 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = { class QnaPairChunker(BaseChunker): - def __init__(self): - text_splitter = RecursiveCharacterTextSplitter(**TEXT_SPLITTER_CHUNK_PARAMS) + def __init__(self, config: Optional[ChunkerConfig] = None): + if config is None: + config = TEXT_SPLITTER_CHUNK_PARAMS + text_splitter = RecursiveCharacterTextSplitter(**config) super().__init__(text_splitter) diff --git a/embedchain/chunkers/text.py b/embedchain/chunkers/text.py index bbf8e6b6..95fa8eee 100644 --- a/embedchain/chunkers/text.py +++ b/embedchain/chunkers/text.py @@ -1,4 +1,6 @@ +from typing import Optional from embedchain.chunkers.base_chunker import BaseChunker +from embedchain.config.AddConfig import ChunkerConfig from langchain.text_splitter import RecursiveCharacterTextSplitter @@ -11,6 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = { class TextChunker(BaseChunker): - def __init__(self): - text_splitter = RecursiveCharacterTextSplitter(**TEXT_SPLITTER_CHUNK_PARAMS) + def __init__(self, config: Optional[ChunkerConfig] = None): + if config is None: + config = TEXT_SPLITTER_CHUNK_PARAMS + text_splitter = RecursiveCharacterTextSplitter(**config) super().__init__(text_splitter) diff --git a/embedchain/chunkers/web_page.py b/embedchain/chunkers/web_page.py index fd308ccb..a442556f 100644 --- a/embedchain/chunkers/web_page.py +++ b/embedchain/chunkers/web_page.py @@ -1,4 +1,6 @@ +from typing import Optional from embedchain.chunkers.base_chunker import BaseChunker +from embedchain.config.AddConfig import ChunkerConfig from langchain.text_splitter import RecursiveCharacterTextSplitter @@ -11,6 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = { class WebPageChunker(BaseChunker): - def __init__(self): - text_splitter = RecursiveCharacterTextSplitter(**TEXT_SPLITTER_CHUNK_PARAMS) + def __init__(self, config: Optional[ChunkerConfig] = None): + if config is None: + config = TEXT_SPLITTER_CHUNK_PARAMS + text_splitter = RecursiveCharacterTextSplitter(**config) super().__init__(text_splitter) diff --git a/embedchain/chunkers/youtube_video.py b/embedchain/chunkers/youtube_video.py index 7435c02d..a1406ca7 100644 --- a/embedchain/chunkers/youtube_video.py +++ b/embedchain/chunkers/youtube_video.py @@ -1,4 +1,6 @@ +from typing import Optional from embedchain.chunkers.base_chunker import BaseChunker +from embedchain.config.AddConfig import ChunkerConfig from langchain.text_splitter import RecursiveCharacterTextSplitter @@ -11,6 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = { class YoutubeVideoChunker(BaseChunker): - def __init__(self): - text_splitter = RecursiveCharacterTextSplitter(**TEXT_SPLITTER_CHUNK_PARAMS) - super().__init__(text_splitter) \ No newline at end of file + def __init__(self, config: Optional[ChunkerConfig] = None): + if config is None: + config = TEXT_SPLITTER_CHUNK_PARAMS + text_splitter = RecursiveCharacterTextSplitter(**config) + super().__init__(text_splitter) diff --git a/embedchain/config/AddConfig.py b/embedchain/config/AddConfig.py index f72ae952..f0e7b434 100644 --- a/embedchain/config/AddConfig.py +++ b/embedchain/config/AddConfig.py @@ -1,8 +1,32 @@ +from typing import Callable, Optional from embedchain.config.BaseConfig import BaseConfig + +class ChunkerConfig(BaseConfig): + """ + Config for the chunker used in `add` method + """ + def __init__(self, + chunk_size: Optional[int] = 4000, + chunk_overlap: Optional[int] = 200, + length_function: Optional[Callable[[str], int]] = len): + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + self.length_function = length_function + +class LoaderConfig(BaseConfig): + """ + Config for the chunker used in `add` method + """ + def __init__(self): + pass + class AddConfig(BaseConfig): """ Config for the `add` method. """ - def __init__(self): - pass \ No newline at end of file + def __init__(self, + chunker: Optional[ChunkerConfig] = None, + loader: Optional[LoaderConfig] = None): + self.loader = loader + self.chunker = chunker diff --git a/embedchain/data_formatter/data_formatter.py b/embedchain/data_formatter/data_formatter.py index 0cbf34f4..a6165da6 100644 --- a/embedchain/data_formatter/data_formatter.py +++ b/embedchain/data_formatter/data_formatter.py @@ -1,3 +1,4 @@ +from embedchain.config import AddConfig from embedchain.loaders.youtube_video import YoutubeVideoLoader from embedchain.loaders.pdf_file import PdfFileLoader from embedchain.loaders.web_page import WebPageLoader @@ -18,11 +19,11 @@ class DataFormatter: loaders and chunkers to the data_type entered by the user in their .add or .add_local method call """ - def __init__(self, data_type): - self.loader = self._get_loader(data_type) - self.chunker = self._get_chunker(data_type) - - def _get_loader(self, data_type): + def __init__(self, data_type: str, config: AddConfig): + self.loader = self._get_loader(data_type, config.loader) + self.chunker = self._get_chunker(data_type, config.chunker) + + def _get_loader(self, data_type, config): """ Returns the appropriate data loader for the given data type. @@ -43,7 +44,7 @@ class DataFormatter: else: raise ValueError(f"Unsupported data type: {data_type}") - def _get_chunker(self, data_type): + def _get_chunker(self, data_type, config): """ Returns the appropriate chunker for the given data type. @@ -52,15 +53,14 @@ class DataFormatter: :raises ValueError: If an unsupported data type is provided. """ chunkers = { - 'youtube_video': YoutubeVideoChunker(), - 'pdf_file': PdfFileChunker(), - 'web_page': WebPageChunker(), - 'qna_pair': QnaPairChunker(), - 'text': TextChunker(), - 'docx': DocxFileChunker(), + 'youtube_video': YoutubeVideoChunker(config), + 'pdf_file': PdfFileChunker(config), + 'web_page': WebPageChunker(config), + 'qna_pair': QnaPairChunker(config), + 'text': TextChunker(config), + 'docx': DocxFileChunker(config), } if data_type in chunkers: return chunkers[data_type] else: raise ValueError(f"Unsupported data type: {data_type}") - diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index e219d348..3685f328 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -37,9 +37,6 @@ class EmbedChain: self.collection = self.config.db.collection self.user_asks = [] - - - def add(self, data_type, url, config: AddConfig = None): """ Adds the data from the given URL to the vector db. @@ -52,8 +49,8 @@ class EmbedChain: """ if config is None: config = AddConfig() - - data_formatter = DataFormatter(data_type) + + data_formatter = DataFormatter(data_type, config) self.user_asks.append([data_type, url]) self.load_and_embed(data_formatter.loader, data_formatter.chunker, url) @@ -69,8 +66,8 @@ class EmbedChain: """ if config is None: config = AddConfig() - - data_formatter = DataFormatter(data_type) + + data_formatter = DataFormatter(data_type, config) self.user_asks.append([data_type, content]) self.load_and_embed(data_formatter.loader, data_formatter.chunker, content)