[Improvements] Add support for creating app from YAML string config (#980)

Deven Patel
2023-11-29 12:25:30 -08:00
committed by GitHub
parent e35eaf1bfc
commit 406c46e7f4
34 changed files with 351 additions and 179 deletions
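In short, this change lets an app be created from an in-memory config dict (or a JSON file), not only a YAML file path. A minimal usage sketch, assuming the Pipeline class is importable from embedchain.pipeline as in this repository; the config keys below are illustrative (only the top-level app/vectordb keys appear in the schema touched by this commit):

from embedchain.pipeline import Pipeline

# Illustrative config dict; the llm section and its values are placeholders.
config = {
    "app": {"config": {"name": "demo-app"}},
    "llm": {"provider": "openai", "config": {"model": "gpt-3.5-turbo"}},
    "vectordb": {"provider": "chroma"},
}

app = Pipeline.from_config(config=config)                # new: dict config
# app = Pipeline.from_config(config_path="config.yaml")  # file path still works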

View File

@@ -12,7 +12,7 @@ from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
from embedchain.llm.openai import OpenAILlm
-from embedchain.utils import validate_yaml_config
+from embedchain.utils import validate_config
from embedchain.vectordb.base import BaseVectorDB
from embedchain.vectordb.chroma import ChromaDB
@@ -134,7 +134,7 @@ class App(EmbedChain):
config_data = yaml.safe_load(file)
try:
-validate_yaml_config(config_data)
+validate_config(config_data)
except Exception as e:
raise Exception(f"❌ Error occurred while validating the YAML config. Error: {str(e)}")

View File

@@ -1,4 +1,5 @@
import importlib
+import logging
import os
from typing import Optional
@@ -42,9 +43,11 @@ class HuggingFaceLlm(BaseLlm):
else:
raise ValueError("`top_p` must be > 0.0 and < 1.0")
+model = config.model or "google/flan-t5-xxl"
+logging.info(f"Using HuggingFaceHub with model {model}")
llm = HuggingFaceHub(
huggingfacehub_api_token=os.environ["HUGGINGFACE_ACCESS_TOKEN"],
-repo_id=config.model or "google/flan-t5-xxl",
+repo_id=model,
model_kwargs=model_kwargs,
)
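Tying this refactor back to the commit's theme, a hypothetical dict config selecting this LLM; the provider key name is an assumption based on embedchain's factory naming, while the default model, the top_p bound, and the HUGGINGFACE_ACCESS_TOKEN requirement come from the code above:

import os

os.environ["HUGGINGFACE_ACCESS_TOKEN"] = "hf_..."  # placeholder token

config = {
    "llm": {
        "provider": "huggingface",  # assumed provider key
        "config": {
            "model": "google/flan-t5-xxl",  # default used when model is unset
            "top_p": 0.5,                   # must be > 0.0 and < 1.0
        },
    }
}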

View File

@@ -4,6 +4,7 @@ import logging
import os
import sqlite3
import uuid
+from typing import Any, Dict, Optional
import requests
import yaml
@@ -19,7 +20,7 @@ from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
from embedchain.llm.openai import OpenAILlm
from embedchain.telemetry.posthog import AnonymousTelemetry
-from embedchain.utils import validate_yaml_config
+from embedchain.utils import validate_config
from embedchain.vectordb.base import BaseVectorDB
from embedchain.vectordb.chroma import ChromaDB
@@ -43,7 +44,7 @@ class Pipeline(EmbedChain):
db: BaseVectorDB = None,
embedding_model: BaseEmbedder = None,
llm: BaseLlm = None,
-yaml_path: str = None,
+config_data: dict = None,
log_level=logging.WARN,
auto_deploy: bool = False,
chunker: ChunkerConfig = None,
@@ -59,15 +60,15 @@ class Pipeline(EmbedChain):
:type embedding_model: BaseEmbedder, optional
:param llm: The LLM model used to calculate embeddings, defaults to None
:type llm: BaseLlm, optional
-:param yaml_path: Path to the YAML configuration file, defaults to None
-:type yaml_path: str, optional
+:param config_data: Config dictionary, defaults to None
+:type config_data: dict, optional
:param log_level: Log level to use, defaults to logging.WARN
:type log_level: int, optional
:param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
:type auto_deploy: bool, optional
:raises Exception: If an error occurs while creating the pipeline
"""
-if id and yaml_path:
+if id and config_data:
raise Exception("Cannot provide both id and config. Please provide only one of them.")
if id and name:
@@ -79,8 +80,8 @@ class Pipeline(EmbedChain):
logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
self.logger = logging.getLogger(__name__)
self.auto_deploy = auto_deploy
-# Store the yaml config as an attribute to be able to send it
-self.yaml_config = None
+# Store the dict config as an attribute to be able to send it
+self.config_data = config_data if (config_data and validate_config(config_data)) else None
self.client = None
# pipeline_id from the backend
self.id = None
@@ -92,11 +93,6 @@ class Pipeline(EmbedChain):
self.name = self.config.name
self.config.id = self.local_id = str(uuid.uuid4()) if self.config.id is None else self.config.id
-if yaml_path:
-    with open(yaml_path, "r") as file:
-        config_data = yaml.safe_load(file)
-    self.yaml_config = config_data
if id is not None:
# Init client first since user is trying to fetch the pipeline
# details from the platform
@@ -187,9 +183,9 @@ class Pipeline(EmbedChain):
Create a pipeline on the platform.
"""
print("🛠️ Creating pipeline on the platform...")
-# self.yaml_config is a dict. Pass it inside the key 'yaml_config' to the backend
+# self.config_data is a dict. Pass it inside the key 'yaml_config' to the backend
payload = {
"yaml_config": json.dumps(self.yaml_config),
"yaml_config": json.dumps(self.config_data),
"name": self.name,
"local_id": self.local_id,
}
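A small sketch of what the payload above ends up containing; note the design choice that the backend key stays "yaml_config" for compatibility even though the value is now a JSON-serialised dict. The sample values are placeholders:

import json

config_data = {"app": {"config": {"name": "demo-app"}}}  # placeholder config
payload = {
    "yaml_config": json.dumps(config_data),  # legacy key name, JSON string value
    "name": "demo-app",
    "local_id": "00000000-0000-0000-0000-000000000000",  # placeholder for the uuid4 local_id
}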
@@ -346,24 +342,57 @@ class Pipeline(EmbedChain):
self.telemetry.capture(event_name="deploy", properties=self._telemetry_props)
@classmethod
-def from_config(cls, yaml_path: str, auto_deploy: bool = False):
+def from_config(
+    cls,
+    config_path: Optional[str] = None,
+    config: Optional[Dict[str, Any]] = None,
+    auto_deploy: bool = False,
+    yaml_path: Optional[str] = None,
+):
"""
-Instantiate a Pipeline object from a YAML configuration file.
+Instantiate a Pipeline object from a configuration.
-:param yaml_path: Path to the YAML configuration file.
-:type yaml_path: str
+:param config_path: Path to the YAML or JSON configuration file.
+:type config_path: Optional[str]
+:param config: A dictionary containing the configuration.
+:type config: Optional[Dict[str, Any]]
:param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
:type auto_deploy: bool, optional
+:param yaml_path: (Deprecated) Path to the YAML configuration file. Use config_path instead.
+:type yaml_path: Optional[str]
:return: An instance of the Pipeline class.
:rtype: Pipeline
"""
-with open(yaml_path, "r") as file:
-    config_data = yaml.safe_load(file)
+# Backward compatibility for yaml_path
+if yaml_path and not config_path:
+    config_path = yaml_path
+if config_path and config:
+    raise ValueError("Please provide only one of config_path or config.")
+config_data = None
+if config_path:
+    file_extension = os.path.splitext(config_path)[1]
+    with open(config_path, "r") as file:
+        if file_extension in [".yaml", ".yml"]:
+            config_data = yaml.safe_load(file)
+        elif file_extension == ".json":
+            config_data = json.load(file)
+        else:
+            raise ValueError("config_path must be a path to a YAML or JSON file.")
+elif config and isinstance(config, dict):
+    config_data = config
+else:
+    logging.error(
+        "Please provide either a config file path (YAML or JSON) or a config dictionary. Falling back to defaults because no config is provided.", # noqa: E501
+    )
+    config_data = {}
try:
-validate_yaml_config(config_data)
+validate_config(config_data)
except Exception as e:
raise Exception(f"Error occurred while validating the YAML config. Error: {str(e)}")
raise Exception(f"Error occurred while validating the config. Error: {str(e)}")
pipeline_config_data = config_data.get("app", {}).get("config", {})
db_config_data = config_data.get("vectordb", {})
@@ -388,7 +417,7 @@ class Pipeline(EmbedChain):
)
# Send anonymous telemetry
event_properties = {"init_type": "yaml_config"}
event_properties = {"init_type": "config_data"}
AnonymousTelemetry().capture(event_name="init", properties=event_properties)
return cls(
@@ -396,7 +425,7 @@ class Pipeline(EmbedChain):
llm=llm,
db=db,
embedding_model=embedding_model,
-yaml_path=yaml_path,
+config_data=config_data,
auto_deploy=auto_deploy,
chunker=chunker_config_data,
)
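Putting the new from_config branches together, a usage sketch; the module path is an assumption (file names are not shown in this view) and the dict keys are illustrative:

from embedchain.pipeline import Pipeline

app = Pipeline.from_config(config_path="config.yaml")  # YAML file, as before
app = Pipeline.from_config(config_path="config.json")  # JSON file, newly supported
app = Pipeline.from_config(config={"app": {"config": {"name": "demo-app"}}})  # dict, newly supported
app = Pipeline.from_config(yaml_path="config.yaml")    # deprecated spelling, routed to config_path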

View File

@@ -165,7 +165,7 @@ class AIAssistant:
self.instructions = instructions
self.assistant_id = assistant_id or str(uuid.uuid4())
self.thread_id = thread_id or str(uuid.uuid4())
-self.pipeline = Pipeline.from_config(yaml_path=yaml_path) if yaml_path else Pipeline()
+self.pipeline = Pipeline.from_config(config_path=yaml_path) if yaml_path else Pipeline()
self.pipeline.local_id = self.pipeline.config.id = self.thread_id
if self.instructions:
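A hedged sketch of how this plays out from the assistant side; the import path and the instructions argument are assumptions based on the surrounding constructor, only the yaml_path-to-config_path forwarding is shown in the diff:

from embedchain.store.assistants import AIAssistant  # assumed module path

assistant = AIAssistant(
    instructions="Answer questions using the attached data sources.",
    yaml_path="config.yaml",  # still accepted here, forwarded as config_path
)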

View File

@@ -355,7 +355,7 @@ def is_valid_json_string(source: str):
return False
-def validate_yaml_config(config_data):
+def validate_config(config_data):
schema = Schema(
{
Optional("app"): {