Improve tests (#795)

This commit is contained in:
Sidharth Mohanty
2023-10-13 01:45:22 +05:30
committed by GitHub
parent b5de605e2b
commit 4820ea15d6
7 changed files with 373 additions and 19 deletions

View File

@@ -2,18 +2,15 @@ import os
import unittest
from embedchain import App, CustomApp, Llama2App, OpenSourceApp
from embedchain.config import ChromaDbConfig
from embedchain.config import ChromaDbConfig, AppConfig, AddConfig, BaseLlmConfig, BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.llm.base import BaseLlm
from embedchain.vectordb.base import BaseVectorDB
from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig
from embedchain.vectordb.chroma import ChromaDB
class TestApps(unittest.TestCase):
try:
del os.environ["OPENAI_KEY"]
except KeyError:
pass
os.environ["OPENAI_API_KEY"] = "test_api_key"
def test_app(self):
app = App()
@@ -21,6 +18,18 @@ class TestApps(unittest.TestCase):
self.assertIsInstance(app.db, BaseVectorDB)
self.assertIsInstance(app.embedder, BaseEmbedder)
wrong_llm = "wrong_llm"
with self.assertRaises(TypeError):
App(llm=wrong_llm)
wrong_db = "wrong_db"
with self.assertRaises(TypeError):
App(db=wrong_db)
wrong_embedder = "wrong_embedder"
with self.assertRaises(TypeError):
App(embedder=wrong_embedder)
def test_custom_app(self):
app = CustomApp()
self.assertIsInstance(app.llm, BaseLlm)
@@ -58,3 +67,36 @@ class TestConfigForAppComponents(unittest.TestCase):
database = ChromaDB(config=ChromaDbConfig(collection_name=self.collection_name))
app = App(db=database)
self.assertEqual(app.db.config.collection_name, self.collection_name)
def test_different_configs_are_proper_instances(self):
config = AppConfig()
wrong_app_config = AddConfig()
with self.assertRaises(TypeError):
App(config=wrong_app_config)
self.assertIsInstance(config, AppConfig)
llm_config = BaseLlmConfig()
wrong_llm_config = "wrong_llm_config"
with self.assertRaises(TypeError):
App(llm_config=wrong_llm_config)
self.assertIsInstance(llm_config, BaseLlmConfig)
db_config = BaseVectorDbConfig()
wrong_db_config = "wrong_db_config"
with self.assertRaises(TypeError):
App(db_config=wrong_db_config)
self.assertIsInstance(db_config, BaseVectorDbConfig)
embedder_config = BaseEmbedderConfig()
wrong_embedder_config = "wrong_embedder_config"
with self.assertRaises(TypeError):
App(embedder_config=wrong_embedder_config)
self.assertIsInstance(embedder_config, BaseEmbedderConfig)