Introduce chunker config in yaml config (#907)

Sidharth Mohanty
2023-11-06 23:13:15 +05:30
committed by GitHub
parent f0d112254b
commit a1de238716
5 changed files with 42 additions and 8 deletions

View File

@@ -2,7 +2,7 @@ from typing import Optional
 import yaml
-from embedchain.config import AppConfig, BaseEmbedderConfig, BaseLlmConfig
+from embedchain.config import AppConfig, BaseEmbedderConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.vectordb.base import BaseVectorDbConfig
 from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder
@@ -38,6 +38,7 @@ class App(EmbedChain):
         embedder: BaseEmbedder = None,
         embedder_config: Optional[BaseEmbedderConfig] = None,
         system_prompt: Optional[str] = None,
+        chunker: Optional[ChunkerConfig] = None,
     ):
         """
         Initialize a new `App` instance.
@@ -97,6 +98,9 @@ class App(EmbedChain):
         if embedder is None:
             embedder = OpenAIEmbedder(config=embedder_config)
+        self.chunker = None
+        if chunker:
+            self.chunker = ChunkerConfig(**chunker)
         # Type check assignments
         if not isinstance(llm, BaseLlm):
             raise TypeError(
@@ -137,6 +141,7 @@ class App(EmbedChain):
         llm_config_data = config_data.get("llm", {})
         db_config_data = config_data.get("vectordb", {})
         embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
+        chunker_config_data = config_data.get("chunker", {})

         app_config = AppConfig(**app_config_data.get("config", {}))
@@ -148,4 +153,4 @@ class App(EmbedChain):
         embedder_provider = embedding_model_config_data.get("provider", "openai")
         embedder = EmbedderFactory.create(embedder_provider, embedding_model_config_data.get("config", {}))

-        return cls(config=app_config, llm=llm, db=db, embedder=embedder)
+        return cls(config=app_config, llm=llm, db=db, embedder=embedder, chunker=chunker_config_data)
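Note that the new keyword is annotated Optional[ChunkerConfig] but is immediately unpacked with ** into ChunkerConfig, so what the constructor actually expects is a mapping of chunker settings, which is also what the YAML loader above hands over as chunker_config_data. A minimal usage sketch under that reading, assuming OPENAI_API_KEY is set for the default LLM and embedder:

from embedchain import App

# Sketch only: the dict is unpacked into ChunkerConfig inside App.__init__
# (per the hunk above); key names match the YAML schema added in this commit.
app = App(
    chunker={
        "chunk_size": 500,
        "chunk_overlap": 50,
        "length_function": "len",  # strings are resolved to callables by ChunkerConfig
    }
)

# Hypothetical source; per the embedchain.py hunk below, add() falls back to
# self.chunker when no AddConfig is passed.
app.add("https://docs.embedchain.ai")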

View File

@@ -1,3 +1,5 @@
+import builtins
+from importlib import import_module
 from typing import Callable, Optional

 from embedchain.config.base_config import BaseConfig
@@ -18,7 +20,18 @@ class ChunkerConfig(BaseConfig):
     ):
         self.chunk_size = chunk_size if chunk_size else 2000
         self.chunk_overlap = chunk_overlap if chunk_overlap else 0
-        self.length_function = length_function if length_function else len
+        if isinstance(length_function, str):
+            self.length_function = self.load_func(length_function)
+        else:
+            self.length_function = length_function if length_function else len
+
+    def load_func(self, dotpath: str):
+        if "." not in dotpath:
+            return getattr(builtins, dotpath)
+        else:
+            module_, func = dotpath.rsplit(".", maxsplit=1)
+            m = import_module(module_)
+            return getattr(m, func)

 @register_deserializable
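Because YAML can only carry the length function as a string, load_func turns that string into a callable: a bare name such as "len" is looked up on builtins, while a dotted path is imported and the attribute fetched from the module. A self-contained sketch of the same resolution logic (the method in the hunk above behaves equivalently):

import builtins
from importlib import import_module


def load_func(dotpath: str):
    # Bare name: resolve against builtins, e.g. "len" -> builtins.len.
    if "." not in dotpath:
        return getattr(builtins, dotpath)
    # Dotted path: import the module and fetch the attribute from it.
    module_, func = dotpath.rsplit(".", maxsplit=1)
    return getattr(import_module(module_), func)


print(load_func("len")("hello world"))              # 11
print(load_func("os.path.basename")("/tmp/a.txt"))  # 'a.txt', via the dotted-path branch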

View File

@@ -10,15 +10,14 @@ from dotenv import load_dotenv
 from langchain.docstore.document import Document

 from embedchain.chunkers.base_chunker import BaseChunker
-from embedchain.config import AddConfig, BaseLlmConfig
+from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.apps.base_app_config import BaseAppConfig
 from embedchain.data_formatter import DataFormatter
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.helper.json_serializable import JSONSerializable
 from embedchain.llm.base import BaseLlm
 from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import (DataType, DirectDataType,
-                                         IndirectDataType, SpecialDataType)
+from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
 from embedchain.telemetry.posthog import AnonymousTelemetry
 from embedchain.utils import detect_datatype, is_valid_json_string
 from embedchain.vectordb.base import BaseVectorDB
@@ -84,6 +83,7 @@ class EmbedChain(JSONSerializable):
         # Attributes that aren't subclass related.
         self.user_asks = []
+        self.chunker: ChunkerConfig = None

         # Send anonymous telemetry
         self._telemetry_props = {"class": self.__class__.__name__}
         self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
@@ -157,7 +157,11 @@ class EmbedChain(JSONSerializable):
         :return: source_hash, a md5-hash of the source, in hexadecimal representation.
         :rtype: str
         """
-        if config is None:
+        if config is not None:
+            pass
+        elif self.chunker is not None:
+            config = AddConfig(chunker=self.chunker)
+        else:
             config = AddConfig()

         try:
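The rewritten branch gives add() a three-way precedence: an AddConfig passed to the call is used as-is, otherwise an app-level chunker (populated from YAML or the constructor) is wrapped in an AddConfig, and only when neither exists do the library defaults apply. Roughly, assuming the app built in the earlier sketch and a hypothetical notes.txt source:

from embedchain.config import AddConfig, ChunkerConfig

# 1. Per-call config wins; the app-level chunker is ignored for this add().
app.add("notes.txt", config=AddConfig(chunker=ChunkerConfig(chunk_size=1200)))

# 2. No config, but the app has self.chunker set: add() builds
#    AddConfig(chunker=self.chunker) behind the scenes.
app.add("notes.txt")

# 3. No config and no app-level chunker: plain AddConfig() defaults apply
#    (chunk_size 2000, chunk_overlap 0, length_function len, per ChunkerConfig).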

View File

@@ -9,7 +9,7 @@ import requests
 import yaml

 from embedchain import Client
-from embedchain.config import PipelineConfig
+from embedchain.config import PipelineConfig, ChunkerConfig
 from embedchain.embedchain import CONFIG_DIR, EmbedChain
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.openai import OpenAIEmbedder
@@ -44,6 +44,7 @@ class Pipeline(EmbedChain):
         yaml_path: str = None,
         log_level=logging.INFO,
         auto_deploy: bool = False,
+        chunker: ChunkerConfig = None,
     ):
         """
         Initialize a new `App` instance.
@@ -84,6 +85,10 @@ class Pipeline(EmbedChain):
         # pipeline_id from the backend
         self.id = None
+        self.chunker = None
+        if chunker:
+            self.chunker = ChunkerConfig(**chunker)
+
         self.config = config or PipelineConfig()
         self.name = self.config.name
@@ -366,6 +371,7 @@ class Pipeline(EmbedChain):
         db_config_data = config_data.get("vectordb", {})
         embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
         llm_config_data = config_data.get("llm", {})
+        chunker_config_data = config_data.get("chunker", {})

         pipeline_config = PipelineConfig(**pipeline_config_data)
@@ -394,4 +400,5 @@ class Pipeline(EmbedChain):
             embedding_model=embedding_model,
             yaml_path=yaml_path,
             auto_deploy=auto_deploy,
+            chunker=chunker_config_data,
         )
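The Pipeline changes mirror App: the chunker block is pulled out of the parsed YAML and handed to the constructor as a dict. A sketch of that hand-off under the same assumptions (non-chunker keys and the Pipeline import path are illustrative; constructing a real Pipeline may additionally require OpenAI/Embedchain credentials):

import yaml

from embedchain.pipeline import Pipeline

# Hypothetical pipeline config; the chunker keys follow the schema added below.
raw = """
llm:
  provider: openai
vectordb:
  provider: chroma
chunker:
  chunk_size: 100
  chunk_overlap: 20
  length_function: len
"""
config_data = yaml.safe_load(raw)

# Same extraction and hand-off that the classmethod performs in the hunk above;
# the dict is then unpacked into ChunkerConfig inside Pipeline.__init__.
chunker_config_data = config_data.get("chunker", {})
pipeline = Pipeline(chunker=chunker_config_data)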

View File

@@ -350,6 +350,11 @@ def validate_yaml_config(config_data):
                 Optional("deployment_name"): str,
             },
         },
+        Optional("chunker"): {
+            Optional("chunk_size"): int,
+            Optional("chunk_overlap"): int,
+            Optional("length_function"): str,
+        },
     }
 )
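In the schema, every chunker key is optional: chunk_size and chunk_overlap must be integers and length_function must be a string, the builtin or dotted-path name that load_func later resolves. A quick self-check of a config against it, assuming validate_yaml_config is importable from embedchain.utils as the surrounding module suggests:

import yaml

from embedchain.utils import validate_yaml_config

raw = """
chunker:
  chunk_size: 100
  chunk_overlap: 20
  length_function: len
"""
# Passes the schema above; a wrong type (e.g. chunk_size: "100") fails validation.
validate_yaml_config(yaml.safe_load(raw))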