Add yaml config validation (#890)
@@ -11,6 +11,7 @@ from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
 from embedchain.helper.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 from embedchain.llm.openai import OpenAILlm
+from embedchain.utils import validate_yaml_config
 from embedchain.vectordb.base import BaseVectorDB
 from embedchain.vectordb.chroma import ChromaDB
 
@@ -127,6 +128,11 @@ class App(EmbedChain):
         with open(yaml_path, "r") as file:
             config_data = yaml.safe_load(file)
 
+        try:
+            validate_yaml_config(config_data)
+        except Exception as e:
+            raise Exception(f"❌ Error occurred while validating the YAML config. Error: {str(e)}")
+
         app_config_data = config_data.get("app", {})
         llm_config_data = config_data.get("llm", {})
         db_config_data = config_data.get("vectordb", {})
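For context: with this change, App.from_config validates the parsed YAML against the schema added in utils (last hunk below) before any config values reach the factories, so a malformed file fails fast with one wrapped error. A minimal sketch of that behaviour, assuming an illustrative config whose log_level is not one of the allowed values (the YAML content is made up for the example):

import yaml

from embedchain.utils import validate_yaml_config

# Illustrative config: "TRACE" is not among the allowed log levels,
# so schema validation raises and the error gets wrapped.
config_text = """
app:
  config:
    log_level: TRACE
llm:
  provider: openai
"""

config_data = yaml.safe_load(config_text)
try:
    validate_yaml_config(config_data)
except Exception as e:
    # Same wrapping as in App.from_config / Pipeline.from_config.
    raise Exception(f"❌ Error occurred while validating the YAML config. Error: {str(e)}")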
@@ -19,6 +19,7 @@ from embedchain.helper.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 from embedchain.llm.openai import OpenAILlm
 from embedchain.telemetry.posthog import AnonymousTelemetry
+from embedchain.utils import validate_yaml_config
 from embedchain.vectordb.base import BaseVectorDB
 from embedchain.vectordb.chroma import ChromaDB
 
@@ -357,6 +358,11 @@ class Pipeline(EmbedChain):
         with open(yaml_path, "r") as file:
             config_data = yaml.safe_load(file)
 
+        try:
+            validate_yaml_config(config_data)
+        except Exception as e:
+            raise Exception(f"❌ Error occurred while validating the YAML config. Error: {str(e)}")
+
         pipeline_config_data = config_data.get("app", {}).get("config", {})
         db_config_data = config_data.get("vectordb", {})
         embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
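Note that Pipeline.from_config reads the embedder section from either key, preferring "embedding_model" and falling back to "embedder", and the schema added below accepts both spellings with the same provider choices. A small sketch, with provider and model values chosen only for illustration:

from embedchain.utils import validate_yaml_config

# Both spellings validate; Pipeline.from_config prefers "embedding_model"
# and falls back to "embedder" when it is absent.
validate_yaml_config({"embedding_model": {"provider": "openai", "config": {"model": "text-embedding-ada-002"}}})
validate_yaml_config({"embedder": {"provider": "gpt4all"}})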
@@ -5,6 +5,8 @@ import re
 import string
 from typing import Any
 
+from schema import Optional, Or, Schema
+
 from embedchain.models.data_type import DataType
 
 
@@ -136,8 +138,7 @@ def detect_datatype(source: Any) -> DataType:
     formatted_source = format_source(str(source), 30)
 
     if url:
-        from langchain.document_loaders.youtube import \
-            ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
+        from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
 
         if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
             logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
@@ -283,3 +284,73 @@ def is_valid_json_string(source: str):
             Check the docs to see the supported formats - `https://docs.embedchain.ai/data-sources/json`"
         )
         return False
+
+
+def validate_yaml_config(config_data):
+    schema = Schema(
+        {
+            Optional("app"): {
+                Optional("config"): {
+                    Optional("id"): str,
+                    Optional("name"): str,
+                    Optional("log_level"): Or("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"),
+                    Optional("collect_metrics"): bool,
+                    Optional("collection_name"): str,
+                }
+            },
+            Optional("llm"): {
+                Optional("provider"): Or(
+                    "openai",
+                    "azure_openai",
+                    "anthropic",
+                    "huggingface",
+                    "cohere",
+                    "gpt4all",
+                    "jina",
+                    "llama2",
+                    "vertex_ai",
+                ),
+                Optional("config"): {
+                    Optional("model"): str,
+                    Optional("number_documents"): int,
+                    Optional("temperature"): float,
+                    Optional("max_tokens"): int,
+                    Optional("top_p"): Or(float, int),
+                    Optional("stream"): bool,
+                    Optional("template"): str,
+                    Optional("system_prompt"): str,
+                    Optional("deployment_name"): str,
+                    Optional("where"): dict,
+                    Optional("query_type"): str,
+                },
+            },
+            Optional("vectordb"): {
+                Optional("provider"): Or(
+                    "chroma", "elasticsearch", "opensearch", "pinecone", "qdrant", "weaviate", "zilliz"
+                ),
+                Optional("config"): {
+                    Optional("collection_name"): str,
+                    Optional("dir"): str,
+                    Optional("allow_reset"): bool,
+                    Optional("host"): str,
+                    Optional("port"): str,
+                },
+            },
+            Optional("embedder"): {
+                Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai"),
+                Optional("config"): {
+                    Optional("model"): Optional(str),
+                    Optional("deployment_name"): Optional(str),
+                },
+            },
+            Optional("embedding_model"): {
+                Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai"),
+                Optional("config"): {
+                    Optional("model"): str,
+                    Optional("deployment_name"): str,
+                },
+            },
+        }
+    )
+
+    return schema.validate(config_data)
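A quick sketch of how the new validator behaves end to end, using made-up values. Every section and key in the schema is Optional, so any subset of it (including an empty mapping) passes, while keys the schema does not list are expected to be rejected, since the schema library does not ignore extra keys by default:

from embedchain.utils import validate_yaml_config

# Representative config touching each section; all values are illustrative.
sample = {
    "app": {"config": {"name": "demo-app", "log_level": "INFO", "collect_metrics": False}},
    "llm": {"provider": "openai", "config": {"model": "gpt-3.5-turbo", "temperature": 0.5, "stream": False}},
    "vectordb": {"provider": "chroma", "config": {"collection_name": "demo", "dir": "db", "allow_reset": True}},
    "embedder": {"provider": "openai"},
}

print(validate_yaml_config(sample))  # returns the validated dict

validate_yaml_config({})  # an empty config also passes, since every key is Optional

# A key outside the schema (e.g. a hypothetical "chunker" section) is expected
# to raise a SchemaError:
# validate_yaml_config({"chunker": {"chunk_size": 500}})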