[Bugfix] fix config validation for google llm config (#1088)

Co-authored-by: Deven Patel <deven298@yahoo.com>
This commit is contained in:
Deven Patel
2023-12-30 16:25:41 +05:30
committed by GitHub
parent a54dde0509
commit a304ded500
2 changed files with 9 additions and 2 deletions

View File

@@ -7,7 +7,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union
from dotenv import load_dotenv from dotenv import load_dotenv
from langchain.docstore.document import Document from langchain.docstore.document import Document
from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback from embedchain.cache import (adapt, get_gptcache_session,
gptcache_data_convert,
gptcache_update_cache_callback)
from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
from embedchain.config.base_app_config import BaseAppConfig from embedchain.config.base_app_config import BaseAppConfig
@@ -17,7 +19,8 @@ from embedchain.embedder.base import BaseEmbedder
from embedchain.helpers.json_serializable import JSONSerializable from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.llm.base import BaseLlm from embedchain.llm.base import BaseLlm
from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.base_loader import BaseLoader
from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType from embedchain.models.data_type import (DataType, DirectDataType,
IndirectDataType, SpecialDataType)
from embedchain.telemetry.posthog import AnonymousTelemetry from embedchain.telemetry.posthog import AnonymousTelemetry
from embedchain.utils import detect_datatype, is_valid_json_string from embedchain.utils import detect_datatype, is_valid_json_string
from embedchain.vectordb.base import BaseVectorDB from embedchain.vectordb.base import BaseVectorDB

View File

@@ -420,6 +420,8 @@ def validate_config(config_data):
Optional("model"): Optional(str), Optional("model"): Optional(str),
Optional("deployment_name"): Optional(str), Optional("deployment_name"): Optional(str),
Optional("api_key"): str, Optional("api_key"): str,
Optional("title"): str,
Optional("task_type"): str,
}, },
}, },
Optional("embedding_model"): { Optional("embedding_model"): {
@@ -428,6 +430,8 @@ def validate_config(config_data):
Optional("model"): str, Optional("model"): str,
Optional("deployment_name"): str, Optional("deployment_name"): str,
Optional("api_key"): str, Optional("api_key"): str,
Optional("title"): str,
Optional("task_type"): str,
}, },
}, },
Optional("chunker"): { Optional("chunker"): {