Add yaml config validation (#890)

This commit is contained in:
Sidharth Mohanty
2023-11-05 10:53:55 +05:30
committed by GitHub
parent 5428765329
commit 830a7397ef
7 changed files with 206 additions and 19 deletions

View File

@@ -11,6 +11,7 @@ from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helper.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
from embedchain.llm.openai import OpenAILlm
from embedchain.utils import validate_yaml_config
from embedchain.vectordb.base import BaseVectorDB
from embedchain.vectordb.chroma import ChromaDB
@@ -127,6 +128,11 @@ class App(EmbedChain):
with open(yaml_path, "r") as file:
config_data = yaml.safe_load(file)
try:
validate_yaml_config(config_data)
except Exception as e:
raise Exception(f"❌ Error occurred while validating the YAML config. Error: {str(e)}")
app_config_data = config_data.get("app", {})
llm_config_data = config_data.get("llm", {})
db_config_data = config_data.get("vectordb", {})

View File

@@ -19,6 +19,7 @@ from embedchain.helper.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
from embedchain.llm.openai import OpenAILlm
from embedchain.telemetry.posthog import AnonymousTelemetry
from embedchain.utils import validate_yaml_config
from embedchain.vectordb.base import BaseVectorDB
from embedchain.vectordb.chroma import ChromaDB
@@ -357,6 +358,11 @@ class Pipeline(EmbedChain):
with open(yaml_path, "r") as file:
config_data = yaml.safe_load(file)
try:
validate_yaml_config(config_data)
except Exception as e:
raise Exception(f"❌ Error occurred while validating the YAML config. Error: {str(e)}")
pipeline_config_data = config_data.get("app", {}).get("config", {})
db_config_data = config_data.get("vectordb", {})
embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))

View File

@@ -5,6 +5,8 @@ import re
import string
from typing import Any
from schema import Optional, Or, Schema
from embedchain.models.data_type import DataType
@@ -136,8 +138,7 @@ def detect_datatype(source: Any) -> DataType:
formatted_source = format_source(str(source), 30)
if url:
from langchain.document_loaders.youtube import \
ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
@@ -283,3 +284,73 @@ def is_valid_json_string(source: str):
Check the docs to see the supported formats - `https://docs.embedchain.ai/data-sources/json`"
)
return False
def validate_yaml_config(config_data):
schema = Schema(
{
Optional("app"): {
Optional("config"): {
Optional("id"): str,
Optional("name"): str,
Optional("log_level"): Or("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"),
Optional("collect_metrics"): bool,
Optional("collection_name"): str,
}
},
Optional("llm"): {
Optional("provider"): Or(
"openai",
"azure_openai",
"anthropic",
"huggingface",
"cohere",
"gpt4all",
"jina",
"llama2",
"vertex_ai",
),
Optional("config"): {
Optional("model"): str,
Optional("number_documents"): int,
Optional("temperature"): float,
Optional("max_tokens"): int,
Optional("top_p"): Or(float, int),
Optional("stream"): bool,
Optional("template"): str,
Optional("system_prompt"): str,
Optional("deployment_name"): str,
Optional("where"): dict,
Optional("query_type"): str,
},
},
Optional("vectordb"): {
Optional("provider"): Or(
"chroma", "elasticsearch", "opensearch", "pinecone", "qdrant", "weaviate", "zilliz"
),
Optional("config"): {
Optional("collection_name"): str,
Optional("dir"): str,
Optional("allow_reset"): bool,
Optional("host"): str,
Optional("port"): str,
},
},
Optional("embedder"): {
Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai"),
Optional("config"): {
Optional("model"): Optional(str),
Optional("deployment_name"): Optional(str),
},
},
Optional("embedding_model"): {
Optional("provider"): Or("openai", "gpt4all", "huggingface", "vertexai"),
Optional("config"): {
Optional("model"): str,
Optional("deployment_name"): str,
},
},
}
)
return schema.validate(config_data)