Improve tests (#800)

This commit is contained in:
Sidharth Mohanty
2023-10-15 07:46:27 +05:30
committed by GitHub
parent 77c90a308e
commit 5ec12212e4
6 changed files with 287 additions and 130 deletions

View File

@@ -1,108 +1,106 @@
import os
import unittest
import pytest
import yaml
from embedchain import App, CustomApp, Llama2App, OpenSourceApp
from embedchain.config import (AddConfig, AppConfig, BaseEmbedderConfig,
BaseLlmConfig, ChromaDbConfig)
from embedchain.config import AddConfig, AppConfig, BaseEmbedderConfig, BaseLlmConfig, ChromaDbConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.llm.base import BaseLlm
from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig
from embedchain.vectordb.chroma import ChromaDB
class TestApps(unittest.TestCase):
@pytest.fixture
def app():
os.environ["OPENAI_API_KEY"] = "test_api_key"
def test_app(self):
app = App()
self.assertIsInstance(app.llm, BaseLlm)
self.assertIsInstance(app.db, BaseVectorDB)
self.assertIsInstance(app.embedder, BaseEmbedder)
wrong_llm = "wrong_llm"
with self.assertRaises(TypeError):
App(llm=wrong_llm)
wrong_db = "wrong_db"
with self.assertRaises(TypeError):
App(db=wrong_db)
wrong_embedder = "wrong_embedder"
with self.assertRaises(TypeError):
App(embedder=wrong_embedder)
def test_custom_app(self):
app = CustomApp()
self.assertIsInstance(app.llm, BaseLlm)
self.assertIsInstance(app.db, BaseVectorDB)
self.assertIsInstance(app.embedder, BaseEmbedder)
def test_opensource_app(self):
app = OpenSourceApp()
self.assertIsInstance(app.llm, BaseLlm)
self.assertIsInstance(app.db, BaseVectorDB)
self.assertIsInstance(app.embedder, BaseEmbedder)
def test_llama2_app(self):
os.environ["REPLICATE_API_TOKEN"] = "-"
app = Llama2App()
self.assertIsInstance(app.llm, BaseLlm)
self.assertIsInstance(app.db, BaseVectorDB)
self.assertIsInstance(app.embedder, BaseEmbedder)
return App()
class TestConfigForAppComponents(unittest.TestCase):
collection_name = "my-test-collection"
@pytest.fixture
def custom_app():
os.environ["OPENAI_API_KEY"] = "test_api_key"
return CustomApp()
@pytest.fixture
def opensource_app():
os.environ["OPENAI_API_KEY"] = "test_api_key"
return OpenSourceApp()
@pytest.fixture
def llama2_app():
os.environ["OPENAI_API_KEY"] = "test_api_key"
os.environ["REPLICATE_API_TOKEN"] = "-"
return Llama2App()
def test_app(app):
assert isinstance(app.llm, BaseLlm)
assert isinstance(app.db, BaseVectorDB)
assert isinstance(app.embedder, BaseEmbedder)
def test_custom_app(custom_app):
assert isinstance(custom_app.llm, BaseLlm)
assert isinstance(custom_app.db, BaseVectorDB)
assert isinstance(custom_app.embedder, BaseEmbedder)
def test_opensource_app(opensource_app):
assert isinstance(opensource_app.llm, BaseLlm)
assert isinstance(opensource_app.db, BaseVectorDB)
assert isinstance(opensource_app.embedder, BaseEmbedder)
def test_llama2_app(llama2_app):
assert isinstance(llama2_app.llm, BaseLlm)
assert isinstance(llama2_app.db, BaseVectorDB)
assert isinstance(llama2_app.embedder, BaseEmbedder)
class TestConfigForAppComponents:
def test_constructor_config(self):
"""
Test that app can be configured through the app constructor.
"""
app = App(db_config=ChromaDbConfig(collection_name=self.collection_name))
self.assertEqual(app.db.config.collection_name, self.collection_name)
collection_name = "my-test-collection"
app = App(db_config=ChromaDbConfig(collection_name=collection_name))
assert app.db.config.collection_name == collection_name
def test_component_config(self):
"""
Test that app can also be configured by passing a configured component to the app
"""
database = ChromaDB(config=ChromaDbConfig(collection_name=self.collection_name))
collection_name = "my-test-collection"
database = ChromaDB(config=ChromaDbConfig(collection_name=collection_name))
app = App(db=database)
self.assertEqual(app.db.config.collection_name, self.collection_name)
assert app.db.config.collection_name == collection_name
def test_different_configs_are_proper_instances(self):
config = AppConfig()
wrong_app_config = AddConfig()
app_config = AppConfig()
wrong_config = AddConfig()
with pytest.raises(TypeError):
App(config=wrong_config)
with self.assertRaises(TypeError):
App(config=wrong_app_config)
self.assertIsInstance(config, AppConfig)
assert isinstance(app_config, AppConfig)
llm_config = BaseLlmConfig()
wrong_llm_config = "wrong_llm_config"
with self.assertRaises(TypeError):
with pytest.raises(TypeError):
App(llm_config=wrong_llm_config)
self.assertIsInstance(llm_config, BaseLlmConfig)
assert isinstance(llm_config, BaseLlmConfig)
db_config = BaseVectorDbConfig()
wrong_db_config = "wrong_db_config"
with self.assertRaises(TypeError):
with pytest.raises(TypeError):
App(db_config=wrong_db_config)
self.assertIsInstance(db_config, BaseVectorDbConfig)
assert isinstance(db_config, BaseVectorDbConfig)
embedder_config = BaseEmbedderConfig()
wrong_embedder_config = "wrong_embedder_config"
with self.assertRaises(TypeError):
with pytest.raises(TypeError):
App(embedder_config=wrong_embedder_config)
self.assertIsInstance(embedder_config, BaseEmbedderConfig)
assert isinstance(embedder_config, BaseEmbedderConfig)
class TestAppFromConfig: