Introduce chunker config in yaml config (#907)

This commit is contained in:
Sidharth Mohanty
2023-11-06 23:13:15 +05:30
committed by GitHub
parent f0d112254b
commit a1de238716
5 changed files with 42 additions and 8 deletions

View File

@@ -2,7 +2,7 @@ from typing import Optional
import yaml import yaml
from embedchain.config import AppConfig, BaseEmbedderConfig, BaseLlmConfig from embedchain.config import AppConfig, BaseEmbedderConfig, BaseLlmConfig, ChunkerConfig
from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.embedchain import EmbedChain from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.base import BaseEmbedder
@@ -38,6 +38,7 @@ class App(EmbedChain):
embedder: BaseEmbedder = None, embedder: BaseEmbedder = None,
embedder_config: Optional[BaseEmbedderConfig] = None, embedder_config: Optional[BaseEmbedderConfig] = None,
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
chunker: Optional[ChunkerConfig] = None,
): ):
""" """
Initialize a new `App` instance. Initialize a new `App` instance.
@@ -97,6 +98,9 @@ class App(EmbedChain):
if embedder is None: if embedder is None:
embedder = OpenAIEmbedder(config=embedder_config) embedder = OpenAIEmbedder(config=embedder_config)
self.chunker = None
if chunker:
self.chunker = ChunkerConfig(**chunker)
# Type check assignments # Type check assignments
if not isinstance(llm, BaseLlm): if not isinstance(llm, BaseLlm):
raise TypeError( raise TypeError(
@@ -137,6 +141,7 @@ class App(EmbedChain):
llm_config_data = config_data.get("llm", {}) llm_config_data = config_data.get("llm", {})
db_config_data = config_data.get("vectordb", {}) db_config_data = config_data.get("vectordb", {})
embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {})) embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
chunker_config_data = config_data.get("chunker", {})
app_config = AppConfig(**app_config_data.get("config", {})) app_config = AppConfig(**app_config_data.get("config", {}))
@@ -148,4 +153,4 @@ class App(EmbedChain):
embedder_provider = embedding_model_config_data.get("provider", "openai") embedder_provider = embedding_model_config_data.get("provider", "openai")
embedder = EmbedderFactory.create(embedder_provider, embedding_model_config_data.get("config", {})) embedder = EmbedderFactory.create(embedder_provider, embedding_model_config_data.get("config", {}))
return cls(config=app_config, llm=llm, db=db, embedder=embedder) return cls(config=app_config, llm=llm, db=db, embedder=embedder, chunker=chunker_config_data)

View File

@@ -1,3 +1,5 @@
import builtins
from importlib import import_module
from typing import Callable, Optional from typing import Callable, Optional
from embedchain.config.base_config import BaseConfig from embedchain.config.base_config import BaseConfig
@@ -18,7 +20,18 @@ class ChunkerConfig(BaseConfig):
): ):
self.chunk_size = chunk_size if chunk_size else 2000 self.chunk_size = chunk_size if chunk_size else 2000
self.chunk_overlap = chunk_overlap if chunk_overlap else 0 self.chunk_overlap = chunk_overlap if chunk_overlap else 0
self.length_function = length_function if length_function else len if isinstance(length_function, str):
self.length_function = self.load_func(length_function)
else:
self.length_function = length_function if length_function else len
def load_func(self, dotpath: str):
if "." not in dotpath:
return getattr(builtins, dotpath)
else:
module_, func = dotpath.rsplit(".", maxsplit=1)
m = import_module(module_)
return getattr(m, func)
@register_deserializable @register_deserializable

View File

@@ -10,15 +10,14 @@ from dotenv import load_dotenv
from langchain.docstore.document import Document from langchain.docstore.document import Document
from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig, BaseLlmConfig from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
from embedchain.config.apps.base_app_config import BaseAppConfig from embedchain.config.apps.base_app_config import BaseAppConfig
from embedchain.data_formatter import DataFormatter from embedchain.data_formatter import DataFormatter
from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.base import BaseEmbedder
from embedchain.helper.json_serializable import JSONSerializable from embedchain.helper.json_serializable import JSONSerializable
from embedchain.llm.base import BaseLlm from embedchain.llm.base import BaseLlm
from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.base_loader import BaseLoader
from embedchain.models.data_type import (DataType, DirectDataType, from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
IndirectDataType, SpecialDataType)
from embedchain.telemetry.posthog import AnonymousTelemetry from embedchain.telemetry.posthog import AnonymousTelemetry
from embedchain.utils import detect_datatype, is_valid_json_string from embedchain.utils import detect_datatype, is_valid_json_string
from embedchain.vectordb.base import BaseVectorDB from embedchain.vectordb.base import BaseVectorDB
@@ -84,6 +83,7 @@ class EmbedChain(JSONSerializable):
# Attributes that aren't subclass related. # Attributes that aren't subclass related.
self.user_asks = [] self.user_asks = []
self.chunker: ChunkerConfig = None
# Send anonymous telemetry # Send anonymous telemetry
self._telemetry_props = {"class": self.__class__.__name__} self._telemetry_props = {"class": self.__class__.__name__}
self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics) self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
@@ -157,7 +157,11 @@ class EmbedChain(JSONSerializable):
:return: source_hash, a md5-hash of the source, in hexadecimal representation. :return: source_hash, a md5-hash of the source, in hexadecimal representation.
:rtype: str :rtype: str
""" """
if config is None: if config is not None:
pass
elif self.chunker is not None:
config = AddConfig(chunker=self.chunker)
else:
config = AddConfig() config = AddConfig()
try: try:

View File

@@ -9,7 +9,7 @@ import requests
import yaml import yaml
from embedchain import Client from embedchain import Client
from embedchain.config import PipelineConfig from embedchain.config import PipelineConfig, ChunkerConfig
from embedchain.embedchain import CONFIG_DIR, EmbedChain from embedchain.embedchain import CONFIG_DIR, EmbedChain
from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder from embedchain.embedder.openai import OpenAIEmbedder
@@ -44,6 +44,7 @@ class Pipeline(EmbedChain):
yaml_path: str = None, yaml_path: str = None,
log_level=logging.INFO, log_level=logging.INFO,
auto_deploy: bool = False, auto_deploy: bool = False,
chunker: ChunkerConfig = None,
): ):
""" """
Initialize a new `App` instance. Initialize a new `App` instance.
@@ -84,6 +85,10 @@ class Pipeline(EmbedChain):
# pipeline_id from the backend # pipeline_id from the backend
self.id = None self.id = None
self.chunker = None
if chunker:
self.chunker = ChunkerConfig(**chunker)
self.config = config or PipelineConfig() self.config = config or PipelineConfig()
self.name = self.config.name self.name = self.config.name
@@ -366,6 +371,7 @@ class Pipeline(EmbedChain):
db_config_data = config_data.get("vectordb", {}) db_config_data = config_data.get("vectordb", {})
embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {})) embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
llm_config_data = config_data.get("llm", {}) llm_config_data = config_data.get("llm", {})
chunker_config_data = config_data.get("chunker", {})
pipeline_config = PipelineConfig(**pipeline_config_data) pipeline_config = PipelineConfig(**pipeline_config_data)
@@ -394,4 +400,5 @@ class Pipeline(EmbedChain):
embedding_model=embedding_model, embedding_model=embedding_model,
yaml_path=yaml_path, yaml_path=yaml_path,
auto_deploy=auto_deploy, auto_deploy=auto_deploy,
chunker=chunker_config_data,
) )

View File

@@ -350,6 +350,11 @@ def validate_yaml_config(config_data):
Optional("deployment_name"): str, Optional("deployment_name"): str,
}, },
}, },
Optional("chunker"): {
Optional("chunk_size"): int,
Optional("chunk_overlap"): int,
Optional("length_function"): str,
},
} }
) )