From a304ded5002ce46e5eeee96fb5980f521c11be69 Mon Sep 17 00:00:00 2001 From: Deven Patel Date: Sat, 30 Dec 2023 16:25:41 +0530 Subject: [PATCH] [Bugfix] fix config validation for google llm config (#1088) Co-authored-by: Deven Patel --- embedchain/embedchain.py | 7 +++++-- embedchain/utils.py | 4 ++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index 79b0a371..d634f3eb 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -7,7 +7,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union from dotenv import load_dotenv from langchain.docstore.document import Document -from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback +from embedchain.cache import (adapt, get_gptcache_session, + gptcache_data_convert, + gptcache_update_cache_callback) from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig from embedchain.config.base_app_config import BaseAppConfig @@ -17,7 +19,8 @@ from embedchain.embedder.base import BaseEmbedder from embedchain.helpers.json_serializable import JSONSerializable from embedchain.llm.base import BaseLlm from embedchain.loaders.base_loader import BaseLoader -from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType +from embedchain.models.data_type import (DataType, DirectDataType, + IndirectDataType, SpecialDataType) from embedchain.telemetry.posthog import AnonymousTelemetry from embedchain.utils import detect_datatype, is_valid_json_string from embedchain.vectordb.base import BaseVectorDB diff --git a/embedchain/utils.py b/embedchain/utils.py index f60312d7..60cb32f9 100644 --- a/embedchain/utils.py +++ b/embedchain/utils.py @@ -420,6 +420,8 @@ def validate_config(config_data): Optional("model"): Optional(str), Optional("deployment_name"): Optional(str), Optional("api_key"): str, + Optional("title"): str, + Optional("task_type"): str, }, }, Optional("embedding_model"): { @@ -428,6 +430,8 @@ def validate_config(config_data): Optional("model"): str, Optional("deployment_name"): str, Optional("api_key"): str, + Optional("title"): str, + Optional("task_type"): str, }, }, Optional("chunker"): {