[Improvements] Add support for creating app from YAML string config (#980)
This commit is contained in:
@@ -4,6 +4,7 @@ import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import requests
|
||||
import yaml
|
||||
@@ -19,7 +20,7 @@ from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.llm.openai import OpenAILlm
|
||||
from embedchain.telemetry.posthog import AnonymousTelemetry
|
||||
from embedchain.utils import validate_yaml_config
|
||||
from embedchain.utils import validate_config
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
from embedchain.vectordb.chroma import ChromaDB
|
||||
|
||||
@@ -43,7 +44,7 @@ class Pipeline(EmbedChain):
|
||||
db: BaseVectorDB = None,
|
||||
embedding_model: BaseEmbedder = None,
|
||||
llm: BaseLlm = None,
|
||||
yaml_path: str = None,
|
||||
config_data: dict = None,
|
||||
log_level=logging.WARN,
|
||||
auto_deploy: bool = False,
|
||||
chunker: ChunkerConfig = None,
|
||||
@@ -59,15 +60,15 @@ class Pipeline(EmbedChain):
|
||||
:type embedding_model: BaseEmbedder, optional
|
||||
:param llm: The LLM model used to calculate embeddings, defaults to None
|
||||
:type llm: BaseLlm, optional
|
||||
:param yaml_path: Path to the YAML configuration file, defaults to None
|
||||
:type yaml_path: str, optional
|
||||
:param config_data: Config dictionary, defaults to None
|
||||
:type config_data: dict, optional
|
||||
:param log_level: Log level to use, defaults to logging.WARN
|
||||
:type log_level: int, optional
|
||||
:param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
|
||||
:type auto_deploy: bool, optional
|
||||
:raises Exception: If an error occurs while creating the pipeline
|
||||
"""
|
||||
if id and yaml_path:
|
||||
if id and config_data:
|
||||
raise Exception("Cannot provide both id and config. Please provide only one of them.")
|
||||
|
||||
if id and name:
|
||||
@@ -79,8 +80,8 @@ class Pipeline(EmbedChain):
|
||||
logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.auto_deploy = auto_deploy
|
||||
# Store the yaml config as an attribute to be able to send it
|
||||
self.yaml_config = None
|
||||
# Store the dict config as an attribute to be able to send it
|
||||
self.config_data = config_data if (config_data and validate_config(config_data)) else None
|
||||
self.client = None
|
||||
# pipeline_id from the backend
|
||||
self.id = None
|
||||
@@ -92,11 +93,6 @@ class Pipeline(EmbedChain):
|
||||
self.name = self.config.name
|
||||
self.config.id = self.local_id = str(uuid.uuid4()) if self.config.id is None else self.config.id
|
||||
|
||||
if yaml_path:
|
||||
with open(yaml_path, "r") as file:
|
||||
config_data = yaml.safe_load(file)
|
||||
self.yaml_config = config_data
|
||||
|
||||
if id is not None:
|
||||
# Init client first since user is trying to fetch the pipeline
|
||||
# details from the platform
|
||||
@@ -187,9 +183,9 @@ class Pipeline(EmbedChain):
|
||||
Create a pipeline on the platform.
|
||||
"""
|
||||
print("🛠️ Creating pipeline on the platform...")
|
||||
# self.yaml_config is a dict. Pass it inside the key 'yaml_config' to the backend
|
||||
# self.config_data is a dict. Pass it inside the key 'yaml_config' to the backend
|
||||
payload = {
|
||||
"yaml_config": json.dumps(self.yaml_config),
|
||||
"yaml_config": json.dumps(self.config_data),
|
||||
"name": self.name,
|
||||
"local_id": self.local_id,
|
||||
}
|
||||
@@ -346,24 +342,57 @@ class Pipeline(EmbedChain):
|
||||
self.telemetry.capture(event_name="deploy", properties=self._telemetry_props)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, yaml_path: str, auto_deploy: bool = False):
|
||||
def from_config(
|
||||
cls,
|
||||
config_path: Optional[str] = None,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
auto_deploy: bool = False,
|
||||
yaml_path: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Instantiate a Pipeline object from a YAML configuration file.
|
||||
Instantiate a Pipeline object from a configuration.
|
||||
|
||||
:param yaml_path: Path to the YAML configuration file.
|
||||
:type yaml_path: str
|
||||
:param config_path: Path to the YAML or JSON configuration file.
|
||||
:type config_path: Optional[str]
|
||||
:param config: A dictionary containing the configuration.
|
||||
:type config: Optional[Dict[str, Any]]
|
||||
:param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
|
||||
:type auto_deploy: bool, optional
|
||||
:param yaml_path: (Deprecated) Path to the YAML configuration file. Use config_path instead.
|
||||
:type yaml_path: Optional[str]
|
||||
:return: An instance of the Pipeline class.
|
||||
:rtype: Pipeline
|
||||
"""
|
||||
with open(yaml_path, "r") as file:
|
||||
config_data = yaml.safe_load(file)
|
||||
# Backward compatibility for yaml_path
|
||||
if yaml_path and not config_path:
|
||||
config_path = yaml_path
|
||||
|
||||
if config_path and config:
|
||||
raise ValueError("Please provide only one of config_path or config.")
|
||||
|
||||
config_data = None
|
||||
|
||||
if config_path:
|
||||
file_extension = os.path.splitext(config_path)[1]
|
||||
with open(config_path, "r") as file:
|
||||
if file_extension in [".yaml", ".yml"]:
|
||||
config_data = yaml.safe_load(file)
|
||||
elif file_extension == ".json":
|
||||
config_data = json.load(file)
|
||||
else:
|
||||
raise ValueError("config_path must be a path to a YAML or JSON file.")
|
||||
elif config and isinstance(config, dict):
|
||||
config_data = config
|
||||
else:
|
||||
logging.error(
|
||||
"Please provide either a config file path (YAML or JSON) or a config dictionary. Falling back to defaults because no config is provided.", # noqa: E501
|
||||
)
|
||||
config_data = {}
|
||||
|
||||
try:
|
||||
validate_yaml_config(config_data)
|
||||
validate_config(config_data)
|
||||
except Exception as e:
|
||||
raise Exception(f"❌ Error occurred while validating the YAML config. Error: {str(e)}")
|
||||
raise Exception(f"Error occurred while validating the config. Error: {str(e)}")
|
||||
|
||||
pipeline_config_data = config_data.get("app", {}).get("config", {})
|
||||
db_config_data = config_data.get("vectordb", {})
|
||||
@@ -388,7 +417,7 @@ class Pipeline(EmbedChain):
|
||||
)
|
||||
|
||||
# Send anonymous telemetry
|
||||
event_properties = {"init_type": "yaml_config"}
|
||||
event_properties = {"init_type": "config_data"}
|
||||
AnonymousTelemetry().capture(event_name="init", properties=event_properties)
|
||||
|
||||
return cls(
|
||||
@@ -396,7 +425,7 @@ class Pipeline(EmbedChain):
|
||||
llm=llm,
|
||||
db=db,
|
||||
embedding_model=embedding_model,
|
||||
yaml_path=yaml_path,
|
||||
config_data=config_data,
|
||||
auto_deploy=auto_deploy,
|
||||
chunker=chunker_config_data,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user