Introduce chunker config in yaml config (#907)
This commit is contained in:
@@ -2,7 +2,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from embedchain.config import AppConfig, BaseEmbedderConfig, BaseLlmConfig
|
from embedchain.config import AppConfig, BaseEmbedderConfig, BaseLlmConfig, ChunkerConfig
|
||||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||||
from embedchain.embedchain import EmbedChain
|
from embedchain.embedchain import EmbedChain
|
||||||
from embedchain.embedder.base import BaseEmbedder
|
from embedchain.embedder.base import BaseEmbedder
|
||||||
@@ -38,6 +38,7 @@ class App(EmbedChain):
|
|||||||
embedder: BaseEmbedder = None,
|
embedder: BaseEmbedder = None,
|
||||||
embedder_config: Optional[BaseEmbedderConfig] = None,
|
embedder_config: Optional[BaseEmbedderConfig] = None,
|
||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
|
chunker: Optional[ChunkerConfig] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize a new `App` instance.
|
Initialize a new `App` instance.
|
||||||
@@ -97,6 +98,9 @@ class App(EmbedChain):
|
|||||||
if embedder is None:
|
if embedder is None:
|
||||||
embedder = OpenAIEmbedder(config=embedder_config)
|
embedder = OpenAIEmbedder(config=embedder_config)
|
||||||
|
|
||||||
|
self.chunker = None
|
||||||
|
if chunker:
|
||||||
|
self.chunker = ChunkerConfig(**chunker)
|
||||||
# Type check assignments
|
# Type check assignments
|
||||||
if not isinstance(llm, BaseLlm):
|
if not isinstance(llm, BaseLlm):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
@@ -137,6 +141,7 @@ class App(EmbedChain):
|
|||||||
llm_config_data = config_data.get("llm", {})
|
llm_config_data = config_data.get("llm", {})
|
||||||
db_config_data = config_data.get("vectordb", {})
|
db_config_data = config_data.get("vectordb", {})
|
||||||
embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
|
embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
|
||||||
|
chunker_config_data = config_data.get("chunker", {})
|
||||||
|
|
||||||
app_config = AppConfig(**app_config_data.get("config", {}))
|
app_config = AppConfig(**app_config_data.get("config", {}))
|
||||||
|
|
||||||
@@ -148,4 +153,4 @@ class App(EmbedChain):
|
|||||||
|
|
||||||
embedder_provider = embedding_model_config_data.get("provider", "openai")
|
embedder_provider = embedding_model_config_data.get("provider", "openai")
|
||||||
embedder = EmbedderFactory.create(embedder_provider, embedding_model_config_data.get("config", {}))
|
embedder = EmbedderFactory.create(embedder_provider, embedding_model_config_data.get("config", {}))
|
||||||
return cls(config=app_config, llm=llm, db=db, embedder=embedder)
|
return cls(config=app_config, llm=llm, db=db, embedder=embedder, chunker=chunker_config_data)
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
import builtins
|
||||||
|
from importlib import import_module
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
from embedchain.config.base_config import BaseConfig
|
from embedchain.config.base_config import BaseConfig
|
||||||
@@ -18,8 +20,19 @@ class ChunkerConfig(BaseConfig):
|
|||||||
):
|
):
|
||||||
self.chunk_size = chunk_size if chunk_size else 2000
|
self.chunk_size = chunk_size if chunk_size else 2000
|
||||||
self.chunk_overlap = chunk_overlap if chunk_overlap else 0
|
self.chunk_overlap = chunk_overlap if chunk_overlap else 0
|
||||||
|
if isinstance(length_function, str):
|
||||||
|
self.length_function = self.load_func(length_function)
|
||||||
|
else:
|
||||||
self.length_function = length_function if length_function else len
|
self.length_function = length_function if length_function else len
|
||||||
|
|
||||||
|
def load_func(self, dotpath: str):
|
||||||
|
if "." not in dotpath:
|
||||||
|
return getattr(builtins, dotpath)
|
||||||
|
else:
|
||||||
|
module_, func = dotpath.rsplit(".", maxsplit=1)
|
||||||
|
m = import_module(module_)
|
||||||
|
return getattr(m, func)
|
||||||
|
|
||||||
|
|
||||||
@register_deserializable
|
@register_deserializable
|
||||||
class LoaderConfig(BaseConfig):
|
class LoaderConfig(BaseConfig):
|
||||||
|
|||||||
@@ -10,15 +10,14 @@ from dotenv import load_dotenv
|
|||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
|
|
||||||
from embedchain.chunkers.base_chunker import BaseChunker
|
from embedchain.chunkers.base_chunker import BaseChunker
|
||||||
from embedchain.config import AddConfig, BaseLlmConfig
|
from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
|
||||||
from embedchain.config.apps.base_app_config import BaseAppConfig
|
from embedchain.config.apps.base_app_config import BaseAppConfig
|
||||||
from embedchain.data_formatter import DataFormatter
|
from embedchain.data_formatter import DataFormatter
|
||||||
from embedchain.embedder.base import BaseEmbedder
|
from embedchain.embedder.base import BaseEmbedder
|
||||||
from embedchain.helper.json_serializable import JSONSerializable
|
from embedchain.helper.json_serializable import JSONSerializable
|
||||||
from embedchain.llm.base import BaseLlm
|
from embedchain.llm.base import BaseLlm
|
||||||
from embedchain.loaders.base_loader import BaseLoader
|
from embedchain.loaders.base_loader import BaseLoader
|
||||||
from embedchain.models.data_type import (DataType, DirectDataType,
|
from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
|
||||||
IndirectDataType, SpecialDataType)
|
|
||||||
from embedchain.telemetry.posthog import AnonymousTelemetry
|
from embedchain.telemetry.posthog import AnonymousTelemetry
|
||||||
from embedchain.utils import detect_datatype, is_valid_json_string
|
from embedchain.utils import detect_datatype, is_valid_json_string
|
||||||
from embedchain.vectordb.base import BaseVectorDB
|
from embedchain.vectordb.base import BaseVectorDB
|
||||||
@@ -84,6 +83,7 @@ class EmbedChain(JSONSerializable):
|
|||||||
# Attributes that aren't subclass related.
|
# Attributes that aren't subclass related.
|
||||||
self.user_asks = []
|
self.user_asks = []
|
||||||
|
|
||||||
|
self.chunker: ChunkerConfig = None
|
||||||
# Send anonymous telemetry
|
# Send anonymous telemetry
|
||||||
self._telemetry_props = {"class": self.__class__.__name__}
|
self._telemetry_props = {"class": self.__class__.__name__}
|
||||||
self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
|
self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
|
||||||
@@ -157,7 +157,11 @@ class EmbedChain(JSONSerializable):
|
|||||||
:return: source_hash, a md5-hash of the source, in hexadecimal representation.
|
:return: source_hash, a md5-hash of the source, in hexadecimal representation.
|
||||||
:rtype: str
|
:rtype: str
|
||||||
"""
|
"""
|
||||||
if config is None:
|
if config is not None:
|
||||||
|
pass
|
||||||
|
elif self.chunker is not None:
|
||||||
|
config = AddConfig(chunker=self.chunker)
|
||||||
|
else:
|
||||||
config = AddConfig()
|
config = AddConfig()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import requests
|
|||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from embedchain import Client
|
from embedchain import Client
|
||||||
from embedchain.config import PipelineConfig
|
from embedchain.config import PipelineConfig, ChunkerConfig
|
||||||
from embedchain.embedchain import CONFIG_DIR, EmbedChain
|
from embedchain.embedchain import CONFIG_DIR, EmbedChain
|
||||||
from embedchain.embedder.base import BaseEmbedder
|
from embedchain.embedder.base import BaseEmbedder
|
||||||
from embedchain.embedder.openai import OpenAIEmbedder
|
from embedchain.embedder.openai import OpenAIEmbedder
|
||||||
@@ -44,6 +44,7 @@ class Pipeline(EmbedChain):
|
|||||||
yaml_path: str = None,
|
yaml_path: str = None,
|
||||||
log_level=logging.INFO,
|
log_level=logging.INFO,
|
||||||
auto_deploy: bool = False,
|
auto_deploy: bool = False,
|
||||||
|
chunker: ChunkerConfig = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize a new `App` instance.
|
Initialize a new `App` instance.
|
||||||
@@ -84,6 +85,10 @@ class Pipeline(EmbedChain):
|
|||||||
# pipeline_id from the backend
|
# pipeline_id from the backend
|
||||||
self.id = None
|
self.id = None
|
||||||
|
|
||||||
|
self.chunker = None
|
||||||
|
if chunker:
|
||||||
|
self.chunker = ChunkerConfig(**chunker)
|
||||||
|
|
||||||
self.config = config or PipelineConfig()
|
self.config = config or PipelineConfig()
|
||||||
self.name = self.config.name
|
self.name = self.config.name
|
||||||
|
|
||||||
@@ -366,6 +371,7 @@ class Pipeline(EmbedChain):
|
|||||||
db_config_data = config_data.get("vectordb", {})
|
db_config_data = config_data.get("vectordb", {})
|
||||||
embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
|
embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
|
||||||
llm_config_data = config_data.get("llm", {})
|
llm_config_data = config_data.get("llm", {})
|
||||||
|
chunker_config_data = config_data.get("chunker", {})
|
||||||
|
|
||||||
pipeline_config = PipelineConfig(**pipeline_config_data)
|
pipeline_config = PipelineConfig(**pipeline_config_data)
|
||||||
|
|
||||||
@@ -394,4 +400,5 @@ class Pipeline(EmbedChain):
|
|||||||
embedding_model=embedding_model,
|
embedding_model=embedding_model,
|
||||||
yaml_path=yaml_path,
|
yaml_path=yaml_path,
|
||||||
auto_deploy=auto_deploy,
|
auto_deploy=auto_deploy,
|
||||||
|
chunker=chunker_config_data,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -350,6 +350,11 @@ def validate_yaml_config(config_data):
|
|||||||
Optional("deployment_name"): str,
|
Optional("deployment_name"): str,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
Optional("chunker"): {
|
||||||
|
Optional("chunk_size"): int,
|
||||||
|
Optional("chunk_overlap"): int,
|
||||||
|
Optional("length_function"): str,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user