Introduce chunker config in yaml config (#907)
This commit is contained in:
@@ -2,7 +2,7 @@ from typing import Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from embedchain.config import AppConfig, BaseEmbedderConfig, BaseLlmConfig
|
||||
from embedchain.config import AppConfig, BaseEmbedderConfig, BaseLlmConfig, ChunkerConfig
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.embedchain import EmbedChain
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
@@ -38,6 +38,7 @@ class App(EmbedChain):
|
||||
embedder: BaseEmbedder = None,
|
||||
embedder_config: Optional[BaseEmbedderConfig] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
chunker: Optional[ChunkerConfig] = None,
|
||||
):
|
||||
"""
|
||||
Initialize a new `App` instance.
|
||||
@@ -97,6 +98,9 @@ class App(EmbedChain):
|
||||
if embedder is None:
|
||||
embedder = OpenAIEmbedder(config=embedder_config)
|
||||
|
||||
self.chunker = None
|
||||
if chunker:
|
||||
self.chunker = ChunkerConfig(**chunker)
|
||||
# Type check assignments
|
||||
if not isinstance(llm, BaseLlm):
|
||||
raise TypeError(
|
||||
@@ -137,6 +141,7 @@ class App(EmbedChain):
|
||||
llm_config_data = config_data.get("llm", {})
|
||||
db_config_data = config_data.get("vectordb", {})
|
||||
embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
|
||||
chunker_config_data = config_data.get("chunker", {})
|
||||
|
||||
app_config = AppConfig(**app_config_data.get("config", {}))
|
||||
|
||||
@@ -148,4 +153,4 @@ class App(EmbedChain):
|
||||
|
||||
embedder_provider = embedding_model_config_data.get("provider", "openai")
|
||||
embedder = EmbedderFactory.create(embedder_provider, embedding_model_config_data.get("config", {}))
|
||||
return cls(config=app_config, llm=llm, db=db, embedder=embedder)
|
||||
return cls(config=app_config, llm=llm, db=db, embedder=embedder, chunker=chunker_config_data)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import builtins
|
||||
from importlib import import_module
|
||||
from typing import Callable, Optional
|
||||
|
||||
from embedchain.config.base_config import BaseConfig
|
||||
@@ -18,7 +20,18 @@ class ChunkerConfig(BaseConfig):
|
||||
):
|
||||
self.chunk_size = chunk_size if chunk_size else 2000
|
||||
self.chunk_overlap = chunk_overlap if chunk_overlap else 0
|
||||
self.length_function = length_function if length_function else len
|
||||
if isinstance(length_function, str):
|
||||
self.length_function = self.load_func(length_function)
|
||||
else:
|
||||
self.length_function = length_function if length_function else len
|
||||
|
||||
def load_func(self, dotpath: str):
|
||||
if "." not in dotpath:
|
||||
return getattr(builtins, dotpath)
|
||||
else:
|
||||
module_, func = dotpath.rsplit(".", maxsplit=1)
|
||||
m = import_module(module_)
|
||||
return getattr(m, func)
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -10,15 +10,14 @@ from dotenv import load_dotenv
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config import AddConfig, BaseLlmConfig
|
||||
from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
|
||||
from embedchain.config.apps.base_app_config import BaseAppConfig
|
||||
from embedchain.data_formatter import DataFormatter
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.helper.json_serializable import JSONSerializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.models.data_type import (DataType, DirectDataType,
|
||||
IndirectDataType, SpecialDataType)
|
||||
from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
|
||||
from embedchain.telemetry.posthog import AnonymousTelemetry
|
||||
from embedchain.utils import detect_datatype, is_valid_json_string
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
@@ -84,6 +83,7 @@ class EmbedChain(JSONSerializable):
|
||||
# Attributes that aren't subclass related.
|
||||
self.user_asks = []
|
||||
|
||||
self.chunker: ChunkerConfig = None
|
||||
# Send anonymous telemetry
|
||||
self._telemetry_props = {"class": self.__class__.__name__}
|
||||
self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
|
||||
@@ -157,7 +157,11 @@ class EmbedChain(JSONSerializable):
|
||||
:return: source_hash, a md5-hash of the source, in hexadecimal representation.
|
||||
:rtype: str
|
||||
"""
|
||||
if config is None:
|
||||
if config is not None:
|
||||
pass
|
||||
elif self.chunker is not None:
|
||||
config = AddConfig(chunker=self.chunker)
|
||||
else:
|
||||
config = AddConfig()
|
||||
|
||||
try:
|
||||
|
||||
@@ -9,7 +9,7 @@ import requests
|
||||
import yaml
|
||||
|
||||
from embedchain import Client
|
||||
from embedchain.config import PipelineConfig
|
||||
from embedchain.config import PipelineConfig, ChunkerConfig
|
||||
from embedchain.embedchain import CONFIG_DIR, EmbedChain
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.embedder.openai import OpenAIEmbedder
|
||||
@@ -44,6 +44,7 @@ class Pipeline(EmbedChain):
|
||||
yaml_path: str = None,
|
||||
log_level=logging.INFO,
|
||||
auto_deploy: bool = False,
|
||||
chunker: ChunkerConfig = None,
|
||||
):
|
||||
"""
|
||||
Initialize a new `App` instance.
|
||||
@@ -84,6 +85,10 @@ class Pipeline(EmbedChain):
|
||||
# pipeline_id from the backend
|
||||
self.id = None
|
||||
|
||||
self.chunker = None
|
||||
if chunker:
|
||||
self.chunker = ChunkerConfig(**chunker)
|
||||
|
||||
self.config = config or PipelineConfig()
|
||||
self.name = self.config.name
|
||||
|
||||
@@ -366,6 +371,7 @@ class Pipeline(EmbedChain):
|
||||
db_config_data = config_data.get("vectordb", {})
|
||||
embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
|
||||
llm_config_data = config_data.get("llm", {})
|
||||
chunker_config_data = config_data.get("chunker", {})
|
||||
|
||||
pipeline_config = PipelineConfig(**pipeline_config_data)
|
||||
|
||||
@@ -394,4 +400,5 @@ class Pipeline(EmbedChain):
|
||||
embedding_model=embedding_model,
|
||||
yaml_path=yaml_path,
|
||||
auto_deploy=auto_deploy,
|
||||
chunker=chunker_config_data,
|
||||
)
|
||||
|
||||
@@ -350,6 +350,11 @@ def validate_yaml_config(config_data):
|
||||
Optional("deployment_name"): str,
|
||||
},
|
||||
},
|
||||
Optional("chunker"): {
|
||||
Optional("chunk_size"): int,
|
||||
Optional("chunk_overlap"): int,
|
||||
Optional("length_function"): str,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user