[Feature]: Add support for creating app using yaml config (#787)
This commit is contained in:
@@ -1,8 +1,11 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import yaml
|
||||
|
||||
from embedchain import App, CustomApp, Llama2App, OpenSourceApp
|
||||
from embedchain.config import ChromaDbConfig, AppConfig, AddConfig, BaseLlmConfig, BaseEmbedderConfig
|
||||
from embedchain.config import (AddConfig, AppConfig, BaseEmbedderConfig,
|
||||
BaseLlmConfig, ChromaDbConfig)
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig
|
||||
@@ -100,3 +103,74 @@ class TestConfigForAppComponents(unittest.TestCase):
|
||||
App(embedder_config=wrong_embedder_config)
|
||||
|
||||
self.assertIsInstance(embedder_config, BaseEmbedderConfig)
|
||||
|
||||
|
||||
class TestAppFromConfig:
|
||||
def load_config_data(self, yaml_path):
|
||||
with open(yaml_path, "r") as file:
|
||||
return yaml.safe_load(file)
|
||||
|
||||
def test_from_chroma_config(self):
|
||||
yaml_path = "embedchain/yaml/chroma.yaml"
|
||||
config_data = self.load_config_data(yaml_path)
|
||||
|
||||
app = App.from_config(yaml_path)
|
||||
|
||||
# Check if the App instance and its components were created correctly
|
||||
assert isinstance(app, App)
|
||||
|
||||
# Validate the AppConfig values
|
||||
assert app.config.id == config_data["app"]["config"]["id"]
|
||||
assert app.config.collection_name == config_data["app"]["config"]["collection_name"]
|
||||
# Even though not present in the config, the default value is used
|
||||
assert app.config.collect_metrics is True
|
||||
|
||||
# Validate the LLM config values
|
||||
llm_config = config_data["llm"]["config"]
|
||||
assert app.llm.config.temperature == llm_config["temperature"]
|
||||
assert app.llm.config.max_tokens == llm_config["max_tokens"]
|
||||
assert app.llm.config.top_p == llm_config["top_p"]
|
||||
assert app.llm.config.stream == llm_config["stream"]
|
||||
|
||||
# Validate the VectorDB config values
|
||||
db_config = config_data["vectordb"]["config"]
|
||||
assert app.db.config.collection_name == db_config["collection_name"]
|
||||
assert app.db.config.dir == db_config["dir"]
|
||||
assert app.db.config.allow_reset == db_config["allow_reset"]
|
||||
|
||||
# Validate the Embedder config values
|
||||
embedder_config = config_data["embedder"]["config"]
|
||||
assert app.embedder.config.model == embedder_config["model"]
|
||||
assert app.embedder.config.deployment_name == embedder_config["deployment_name"]
|
||||
|
||||
def test_from_opensource_config(self):
|
||||
yaml_path = "embedchain/yaml/opensource.yaml"
|
||||
config_data = self.load_config_data(yaml_path)
|
||||
|
||||
app = App.from_config(yaml_path)
|
||||
|
||||
# Check if the App instance and its components were created correctly
|
||||
assert isinstance(app, App)
|
||||
|
||||
# Validate the AppConfig values
|
||||
assert app.config.id == config_data["app"]["config"]["id"]
|
||||
assert app.config.collection_name == config_data["app"]["config"]["collection_name"]
|
||||
assert app.config.collect_metrics == config_data["app"]["config"]["collect_metrics"]
|
||||
|
||||
# Validate the LLM config values
|
||||
llm_config = config_data["llm"]["config"]
|
||||
assert app.llm.config.temperature == llm_config["temperature"]
|
||||
assert app.llm.config.max_tokens == llm_config["max_tokens"]
|
||||
assert app.llm.config.top_p == llm_config["top_p"]
|
||||
assert app.llm.config.stream == llm_config["stream"]
|
||||
|
||||
# Validate the VectorDB config values
|
||||
db_config = config_data["vectordb"]["config"]
|
||||
assert app.db.config.collection_name == db_config["collection_name"]
|
||||
assert app.db.config.dir == db_config["dir"]
|
||||
assert app.db.config.allow_reset == db_config["allow_reset"]
|
||||
|
||||
# Validate the Embedder config values
|
||||
embedder_config = config_data["embedder"]["config"]
|
||||
assert app.embedder.config.model == embedder_config["model"]
|
||||
assert app.embedder.config.deployment_name == embedder_config["deployment_name"]
|
||||
|
||||
Reference in New Issue
Block a user