diff --git a/embedchain/apps/app.py b/embedchain/apps/app.py index 8bb6e5bf..1508027b 100644 --- a/embedchain/apps/app.py +++ b/embedchain/apps/app.py @@ -2,7 +2,7 @@ from typing import Optional import yaml -from embedchain.config import AppConfig, BaseEmbedderConfig, BaseLlmConfig +from embedchain.config import AppConfig, BaseEmbedderConfig, BaseLlmConfig, ChunkerConfig from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.embedchain import EmbedChain from embedchain.embedder.base import BaseEmbedder @@ -38,6 +38,7 @@ class App(EmbedChain): embedder: BaseEmbedder = None, embedder_config: Optional[BaseEmbedderConfig] = None, system_prompt: Optional[str] = None, + chunker: Optional[ChunkerConfig] = None, ): """ Initialize a new `App` instance. @@ -97,6 +98,9 @@ class App(EmbedChain): if embedder is None: embedder = OpenAIEmbedder(config=embedder_config) + self.chunker = None + if chunker: + self.chunker = ChunkerConfig(**chunker) # Type check assignments if not isinstance(llm, BaseLlm): raise TypeError( @@ -137,6 +141,7 @@ class App(EmbedChain): llm_config_data = config_data.get("llm", {}) db_config_data = config_data.get("vectordb", {}) embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {})) + chunker_config_data = config_data.get("chunker", {}) app_config = AppConfig(**app_config_data.get("config", {})) @@ -148,4 +153,4 @@ class App(EmbedChain): embedder_provider = embedding_model_config_data.get("provider", "openai") embedder = EmbedderFactory.create(embedder_provider, embedding_model_config_data.get("config", {})) - return cls(config=app_config, llm=llm, db=db, embedder=embedder) + return cls(config=app_config, llm=llm, db=db, embedder=embedder, chunker=chunker_config_data) diff --git a/embedchain/config/add_config.py b/embedchain/config/add_config.py index 945ec09b..16e99d68 100644 --- a/embedchain/config/add_config.py +++ b/embedchain/config/add_config.py @@ -1,3 +1,5 @@ +import builtins +from importlib import import_module from typing import Callable, Optional from embedchain.config.base_config import BaseConfig @@ -18,7 +20,18 @@ class ChunkerConfig(BaseConfig): ): self.chunk_size = chunk_size if chunk_size else 2000 self.chunk_overlap = chunk_overlap if chunk_overlap else 0 - self.length_function = length_function if length_function else len + if isinstance(length_function, str): + self.length_function = self.load_func(length_function) + else: + self.length_function = length_function if length_function else len + + def load_func(self, dotpath: str): + if "." not in dotpath: + return getattr(builtins, dotpath) + else: + module_, func = dotpath.rsplit(".", maxsplit=1) + m = import_module(module_) + return getattr(m, func) @register_deserializable diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index 0577bc3b..eb0dc62a 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -10,15 +10,14 @@ from dotenv import load_dotenv from langchain.docstore.document import Document from embedchain.chunkers.base_chunker import BaseChunker -from embedchain.config import AddConfig, BaseLlmConfig +from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig from embedchain.config.apps.base_app_config import BaseAppConfig from embedchain.data_formatter import DataFormatter from embedchain.embedder.base import BaseEmbedder from embedchain.helper.json_serializable import JSONSerializable from embedchain.llm.base import BaseLlm from embedchain.loaders.base_loader import BaseLoader -from embedchain.models.data_type import (DataType, DirectDataType, - IndirectDataType, SpecialDataType) +from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType from embedchain.telemetry.posthog import AnonymousTelemetry from embedchain.utils import detect_datatype, is_valid_json_string from embedchain.vectordb.base import BaseVectorDB @@ -84,6 +83,7 @@ class EmbedChain(JSONSerializable): # Attributes that aren't subclass related. self.user_asks = [] + self.chunker: ChunkerConfig = None # Send anonymous telemetry self._telemetry_props = {"class": self.__class__.__name__} self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics) @@ -157,7 +157,11 @@ class EmbedChain(JSONSerializable): :return: source_hash, a md5-hash of the source, in hexadecimal representation. :rtype: str """ - if config is None: + if config is not None: + pass + elif self.chunker is not None: + config = AddConfig(chunker=self.chunker) + else: config = AddConfig() try: diff --git a/embedchain/pipeline.py b/embedchain/pipeline.py index 7278f6e9..64e38e8b 100644 --- a/embedchain/pipeline.py +++ b/embedchain/pipeline.py @@ -9,7 +9,7 @@ import requests import yaml from embedchain import Client -from embedchain.config import PipelineConfig +from embedchain.config import PipelineConfig, ChunkerConfig from embedchain.embedchain import CONFIG_DIR, EmbedChain from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.openai import OpenAIEmbedder @@ -44,6 +44,7 @@ class Pipeline(EmbedChain): yaml_path: str = None, log_level=logging.INFO, auto_deploy: bool = False, + chunker: ChunkerConfig = None, ): """ Initialize a new `App` instance. @@ -84,6 +85,10 @@ class Pipeline(EmbedChain): # pipeline_id from the backend self.id = None + self.chunker = None + if chunker: + self.chunker = ChunkerConfig(**chunker) + self.config = config or PipelineConfig() self.name = self.config.name @@ -366,6 +371,7 @@ class Pipeline(EmbedChain): db_config_data = config_data.get("vectordb", {}) embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {})) llm_config_data = config_data.get("llm", {}) + chunker_config_data = config_data.get("chunker", {}) pipeline_config = PipelineConfig(**pipeline_config_data) @@ -394,4 +400,5 @@ class Pipeline(EmbedChain): embedding_model=embedding_model, yaml_path=yaml_path, auto_deploy=auto_deploy, + chunker=chunker_config_data, ) diff --git a/embedchain/utils.py b/embedchain/utils.py index 87eed4d3..99a01f8b 100644 --- a/embedchain/utils.py +++ b/embedchain/utils.py @@ -350,6 +350,11 @@ def validate_yaml_config(config_data): Optional("deployment_name"): str, }, }, + Optional("chunker"): { + Optional("chunk_size"): int, + Optional("chunk_overlap"): int, + Optional("length_function"): str, + }, } )