Improve tests (#800)

This commit is contained in:
Sidharth Mohanty
2023-10-15 07:46:27 +05:30
committed by GitHub
parent 77c90a308e
commit 5ec12212e4
6 changed files with 287 additions and 130 deletions

View File

@@ -1,108 +1,106 @@
import os
import unittest
import pytest
import yaml
from embedchain import App, CustomApp, Llama2App, OpenSourceApp
from embedchain.config import (AddConfig, AppConfig, BaseEmbedderConfig,
BaseLlmConfig, ChromaDbConfig)
from embedchain.config import AddConfig, AppConfig, BaseEmbedderConfig, BaseLlmConfig, ChromaDbConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.llm.base import BaseLlm
from embedchain.vectordb.base import BaseVectorDB, BaseVectorDbConfig
from embedchain.vectordb.chroma import ChromaDB
class TestApps(unittest.TestCase):
@pytest.fixture
def app():
os.environ["OPENAI_API_KEY"] = "test_api_key"
def test_app(self):
    """Check the default App's component types and that wrongly typed components are rejected."""
    default_app = App()
    self.assertIsInstance(default_app.llm, BaseLlm)
    self.assertIsInstance(default_app.db, BaseVectorDB)
    self.assertIsInstance(default_app.embedder, BaseEmbedder)

    # A plain string is not a valid component for any of the three keywords.
    with self.assertRaises(TypeError):
        App(llm="wrong_llm")

    with self.assertRaises(TypeError):
        App(db="wrong_db")

    with self.assertRaises(TypeError):
        App(embedder="wrong_embedder")
def test_custom_app(self):
    """Check that a default CustomApp exposes base-typed llm, db and embedder."""
    custom = CustomApp()
    self.assertIsInstance(custom.llm, BaseLlm)
    self.assertIsInstance(custom.db, BaseVectorDB)
    self.assertIsInstance(custom.embedder, BaseEmbedder)
def test_opensource_app(self):
    """Check that a default OpenSourceApp exposes base-typed llm, db and embedder."""
    open_source = OpenSourceApp()
    self.assertIsInstance(open_source.llm, BaseLlm)
    self.assertIsInstance(open_source.db, BaseVectorDB)
    self.assertIsInstance(open_source.embedder, BaseEmbedder)
def test_llama2_app(self):
    """Check that Llama2App exposes base-typed llm, db and embedder."""
    # A placeholder token is put in the environment before the app is constructed.
    os.environ["REPLICATE_API_TOKEN"] = "-"
    llama2 = Llama2App()
    self.assertIsInstance(llama2.llm, BaseLlm)
    self.assertIsInstance(llama2.db, BaseVectorDB)
    self.assertIsInstance(llama2.embedder, BaseEmbedder)
return App()
class TestConfigForAppComponents(unittest.TestCase):
collection_name = "my-test-collection"
@pytest.fixture
def custom_app():
    """Provide a CustomApp created after a placeholder OpenAI key is put in the environment."""
    # Written straight to os.environ; this fixture does not restore the previous value.
    os.environ["OPENAI_API_KEY"] = "test_api_key"
    instance = CustomApp()
    return instance
@pytest.fixture
def opensource_app():
    """Provide an OpenSourceApp created after a placeholder OpenAI key is put in the environment."""
    # Written straight to os.environ; this fixture does not restore the previous value.
    os.environ["OPENAI_API_KEY"] = "test_api_key"
    instance = OpenSourceApp()
    return instance
@pytest.fixture
def llama2_app():
    """Provide a Llama2App created after placeholder OpenAI and Replicate credentials are set."""
    # Both variables are written straight to os.environ and are not restored by this fixture.
    os.environ.update({"OPENAI_API_KEY": "test_api_key", "REPLICATE_API_TOKEN": "-"})
    instance = Llama2App()
    return instance
def test_app(app):
    """Check that the default App exposes llm, db and embedder built on the base classes."""
    expected_bases = {"llm": BaseLlm, "db": BaseVectorDB, "embedder": BaseEmbedder}
    for attribute, base in expected_bases.items():
        assert isinstance(getattr(app, attribute), base)
def test_custom_app(custom_app):
    """Check that CustomApp exposes llm, db and embedder built on the base classes."""
    expected_bases = {"llm": BaseLlm, "db": BaseVectorDB, "embedder": BaseEmbedder}
    for attribute, base in expected_bases.items():
        assert isinstance(getattr(custom_app, attribute), base)
def test_opensource_app(opensource_app):
    """Check that OpenSourceApp exposes llm, db and embedder built on the base classes."""
    expected_bases = {"llm": BaseLlm, "db": BaseVectorDB, "embedder": BaseEmbedder}
    for attribute, base in expected_bases.items():
        assert isinstance(getattr(opensource_app, attribute), base)
def test_llama2_app(llama2_app):
    """Check that Llama2App exposes llm, db and embedder built on the base classes."""
    expected_bases = {"llm": BaseLlm, "db": BaseVectorDB, "embedder": BaseEmbedder}
    for attribute, base in expected_bases.items():
        assert isinstance(getattr(llama2_app, attribute), base)
class TestConfigForAppComponents:
def test_constructor_config(self):
"""
Test that app can be configured through the app constructor.
"""
app = App(db_config=ChromaDbConfig(collection_name=self.collection_name))
self.assertEqual(app.db.config.collection_name, self.collection_name)
collection_name = "my-test-collection"
app = App(db_config=ChromaDbConfig(collection_name=collection_name))
assert app.db.config.collection_name == collection_name
def test_component_config(self):
"""
Test that app can also be configured by passing a configured component to the app
"""
database = ChromaDB(config=ChromaDbConfig(collection_name=self.collection_name))
collection_name = "my-test-collection"
database = ChromaDB(config=ChromaDbConfig(collection_name=collection_name))
app = App(db=database)
self.assertEqual(app.db.config.collection_name, self.collection_name)
assert app.db.config.collection_name == collection_name
def test_different_configs_are_proper_instances(self):
config = AppConfig()
wrong_app_config = AddConfig()
app_config = AppConfig()
wrong_config = AddConfig()
with pytest.raises(TypeError):
App(config=wrong_config)
with self.assertRaises(TypeError):
App(config=wrong_app_config)
self.assertIsInstance(config, AppConfig)
assert isinstance(app_config, AppConfig)
llm_config = BaseLlmConfig()
wrong_llm_config = "wrong_llm_config"
with self.assertRaises(TypeError):
with pytest.raises(TypeError):
App(llm_config=wrong_llm_config)
self.assertIsInstance(llm_config, BaseLlmConfig)
assert isinstance(llm_config, BaseLlmConfig)
db_config = BaseVectorDbConfig()
wrong_db_config = "wrong_db_config"
with self.assertRaises(TypeError):
with pytest.raises(TypeError):
App(db_config=wrong_db_config)
self.assertIsInstance(db_config, BaseVectorDbConfig)
assert isinstance(db_config, BaseVectorDbConfig)
embedder_config = BaseEmbedderConfig()
wrong_embedder_config = "wrong_embedder_config"
with self.assertRaises(TypeError):
with pytest.raises(TypeError):
App(embedder_config=wrong_embedder_config)
self.assertIsInstance(embedder_config, BaseEmbedderConfig)
assert isinstance(embedder_config, BaseEmbedderConfig)
class TestAppFromConfig:

View File

@@ -0,0 +1,80 @@
import pytest
from embedchain.apps.app import App
from embedchain.apps.person_app import PersonApp, PersonOpenSourceApp
from embedchain.config import BaseLlmConfig, AppConfig
from embedchain.config.llm.base import DEFAULT_PROMPT
@pytest.fixture
def person_app():
    """Provide a PersonApp for the person "John Doe" built with a default AppConfig."""
    return PersonApp("John Doe", AppConfig())
@pytest.fixture
def opensource_person_app():
    """Provide a PersonOpenSourceApp for the person "John Doe" built with a default AppConfig."""
    return PersonOpenSourceApp("John Doe", AppConfig())
def test_person_app_initialization(person_app):
    """Check the person name, the person prompt and the config type after construction."""
    assert person_app.person == "John Doe"
    prompt_opening = f"You are {person_app.person}"
    assert prompt_opening in person_app.person_prompt
    assert isinstance(person_app.config, AppConfig)
def test_person_app_add_person_template_to_config_with_invalid_template():
    """Check that a template lacking the required placeholders is rejected with ValueError."""
    person = PersonApp("John Doe")
    template_without_placeholders = "Input Prompt"
    # The template contains neither $context nor $query, so it is expected to be refused.
    with pytest.raises(ValueError):
        person.add_person_template_to_config(template_without_placeholders)
def test_person_app_add_person_template_to_config_with_valid_template():
    """Check that the person preamble is placed in front of the default prompt."""
    person = PersonApp("John Doe")
    llm_config = person.add_person_template_to_config(DEFAULT_PROMPT)
    rendered = llm_config.template.template
    expected = f"You are John Doe. Whatever you say, you will always say in John Doe style. {DEFAULT_PROMPT}"
    assert rendered == expected
def test_person_app_query(mocker, person_app):
    """Check that PersonApp.query returns the value produced by the patched App.query."""
    llm_config = BaseLlmConfig()
    mocker.patch.object(App, "query", return_value="Mocked response")
    answer = person_app.query("Hello, how are you?", llm_config)
    assert answer == "Mocked response"
def test_person_app_chat(mocker, person_app):
    """Check that PersonApp.chat returns the value produced by the patched App.chat."""
    llm_config = BaseLlmConfig()
    mocker.patch.object(App, "chat", return_value="Mocked chat response")
    answer = person_app.chat("Hello, how are you?", llm_config)
    assert answer == "Mocked chat response"
def test_opensource_person_app_query(mocker, opensource_person_app):
    """Check that PersonOpenSourceApp.query returns the value produced by the patched App.query."""
    llm_config = BaseLlmConfig()
    mocker.patch.object(App, "query", return_value="Mocked response")
    answer = opensource_person_app.query("Hello, how are you?", llm_config)
    assert answer == "Mocked response"
def test_opensource_person_app_chat(mocker, opensource_person_app):
    """Check that PersonOpenSourceApp.chat returns the value produced by the patched App.chat."""
    llm_config = BaseLlmConfig()
    mocker.patch.object(App, "chat", return_value="Mocked chat response")
    answer = opensource_person_app.chat("Hello, how are you?", llm_config)
    assert answer == "Mocked chat response"

50
tests/bots/test_poe.py Normal file
View File

@@ -0,0 +1,50 @@
import argparse
import pytest
from embedchain.bots.poe import PoeBot, start_command
from fastapi_poe.types import QueryRequest, ProtocolMessage
@pytest.fixture
def poe_bot(mocker):
    """Provide a PoeBot instance, with `fastapi_poe.run` patched out after construction."""
    instance = PoeBot()
    mocker.patch("fastapi_poe.run")
    return instance
@pytest.mark.asyncio
async def test_poe_bot_get_response(poe_bot, mocker):
    """Check that pulling the first item from get_response calls llm.set_history exactly once."""
    system_message = ProtocolMessage(role="system", content="Test content")
    request = QueryRequest(
        version="test",
        type="query",
        query=[system_message],
        user_id="test_user_id",
        conversation_id="test_conversation_id",
        message_id="test_message_id",
    )

    mocker.patch.object(poe_bot.app.llm, "set_history")

    # Only the first item of the async generator is consumed.
    stream = poe_bot.get_response(request)
    await stream.__anext__()

    poe_bot.app.llm.set_history.assert_called_once()
def test_poe_bot_handle_message(poe_bot, mocker):
    """Check that a plain question is answered with the value returned by the patched ask_bot."""
    mocker.patch.object(poe_bot, "ask_bot", return_value="Answer from the bot")
    answer = poe_bot.handle_message("What is the answer?")
    assert answer == "Answer from the bot"

    # TODO: This test will fail because the add_data method is commented out.
    # mocker.patch.object(poe_bot, 'add_data', return_value="Added data from: some_data")
    # response_add = poe_bot.handle_message("/add some_data")
    # assert response_add == "Added data from: some_data"
def test_start_command(mocker):
    """
    Check that start_command launches the bot runner.

    Argument parsing is patched to return a fixed namespace, and `embedchain.bots.poe.run` is replaced by a
    mock so nothing is actually started.
    """
    parsed_args = argparse.Namespace(api_key="test_api_key")
    mocker.patch("argparse.ArgumentParser.parse_args", return_value=parsed_args)
    run_mock = mocker.patch("embedchain.bots.poe.run")

    start_command()

    # Without this check the test would pass even if start_command never invoked the runner.
    run_mock.assert_called_once()

View File

@@ -1,70 +1,49 @@
import os
import unittest
from unittest.mock import MagicMock, patch
import pytest
from embedchain import App
from embedchain.config import AddConfig, AppConfig, ChunkerConfig
from embedchain.models.data_type import DataType
os.environ["OPENAI_API_KEY"] = "test_key"
class TestApp(unittest.TestCase):
os.environ["OPENAI_API_KEY"] = "test_key"
def setUp(self):
    """Create the App under test with metric collection switched off."""
    app_config = AppConfig(collect_metrics=False)
    self.app = App(config=app_config)
@pytest.fixture
def app(mocker):
    """Provide an App with metrics disabled and chromadb's `Collection.add` patched out."""
    mocker.patch("chromadb.api.models.Collection.Collection.add")
    app_config = AppConfig(collect_metrics=False)
    return App(config=app_config)
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_add(self):
    """
    Verify that `add` records the added source in `user_asks`.

    A web page URL is added together with some metadata. Afterwards `user_asks` is expected to hold exactly one
    entry made of the URL, the "web_page" type and the metadata.
    chromadb's `Collection.add` is replaced by `MagicMock` for the duration of the test.
    """
    expected_asks = [["https://example.com", "web_page", {"meta": "meta-data"}]]
    self.app.add("https://example.com", metadata={"meta": "meta-data"})
    self.assertEqual(self.app.user_asks, expected_asks)
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_add_sitemap(self):
    """
    Verify that a sitemap URL is recorded in `user_asks` with the "sitemap" type.
    """
    expected_asks = [["https://www.google.com/sitemap.xml", "sitemap", {"meta": "meta-data"}]]
    self.app.add("https://www.google.com/sitemap.xml", metadata={"meta": "meta-data"})
    self.assertEqual(self.app.user_asks, expected_asks)
def test_add(app):
    """Check that adding a web page URL records it in user_asks with the "web_page" type."""
    expected_asks = [["https://example.com", "web_page", {"meta": "meta-data"}]]
    app.add("https://example.com", metadata={"meta": "meta-data"})
    assert app.user_asks == expected_asks
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_add_forced_type(self):
    """
    Verify that a `data_type` passed to `add` is the one recorded in `user_asks`.
    """
    forced_type = "text"
    self.app.add("https://example.com", data_type=forced_type, metadata={"meta": "meta-data"})
    expected_asks = [["https://example.com", forced_type, {"meta": "meta-data"}]]
    self.assertEqual(self.app.user_asks, expected_asks)
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_dry_run(self):
"""
Test that if dry_run == True then data chunks are returned.
"""
def test_add_sitemap(app):
    """Check that adding a sitemap URL records it in user_asks with the "sitemap" type."""
    expected_asks = [["https://www.google.com/sitemap.xml", "sitemap", {"meta": "meta-data"}]]
    app.add("https://www.google.com/sitemap.xml", metadata={"meta": "meta-data"})
    assert app.user_asks == expected_asks
chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0)
# We can't test with lorem ipsum because chunks are deduped, so would be recurring characters.
text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"""
result = self.app.add(source=text, config=AddConfig(chunker=chunker_config), dry_run=True)
def test_add_forced_type(app):
    """Check that an explicitly passed data_type is the one recorded in user_asks."""
    forced_type = "text"
    app.add("https://example.com", data_type=forced_type, metadata={"meta": "meta-data"})
    expected_asks = [["https://example.com", forced_type, {"meta": "meta-data"}]]
    assert app.user_asks == expected_asks
chunks = result["chunks"]
metadata = result["metadata"]
count = result["count"]
data_type = result["type"]
self.assertEqual(len(chunks), len(text))
self.assertEqual(count, len(text))
self.assertEqual(data_type, DataType.TEXT)
for item in metadata:
self.assertIsInstance(item, dict)
self.assertIn(item["url"], "local")
self.assertIn(item["data_type"], "text")
def test_dry_run(app):
    """Check that `add(..., dry_run=True)` returns one chunk per character with matching metadata."""
    # Every character in the sample is distinct and the chunker is configured with size 1 and no overlap.
    text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"""
    add_config = AddConfig(chunker=ChunkerConfig(chunk_size=1, chunk_overlap=0))

    result = app.add(source=text, config=add_config, dry_run=True)
    chunks, metadata, count, data_type = (result[key] for key in ("chunks", "metadata", "count", "type"))

    expected_count = len(text)
    assert len(chunks) == expected_count
    assert count == expected_count
    assert data_type == DataType.TEXT

    for entry in metadata:
        assert isinstance(entry, dict)
        assert "local" in entry["url"]
        assert "text" in entry["data_type"]

View File

@@ -0,0 +1,52 @@
import pytest
from embedchain.llm.base import BaseLlm, BaseLlmConfig
@pytest.fixture
def base_llm():
    """Provide a BaseLlm constructed from a default BaseLlmConfig."""
    return BaseLlm(config=BaseLlmConfig())
def test_is_get_llm_model_answer_not_implemented(base_llm):
    """Check that calling get_llm_model_answer on the base class raises NotImplementedError."""
    with pytest.raises(NotImplementedError):
        base_llm.get_llm_model_answer()
def test_is_get_llm_model_answer_implemented():
    """Check that a subclass overriding get_llm_model_answer returns its own value."""

    class ImplementedLlm(BaseLlm):
        def get_llm_model_answer(self):
            return "Implemented"

    subclass_llm = ImplementedLlm(config=BaseLlmConfig())
    assert subclass_llm.get_llm_model_answer() == "Implemented"
def test_stream_query_response(base_llm):
    """Check that _stream_query_response yields the given chunks unchanged and in order."""
    chunks = ["Chunk1", "Chunk2", "Chunk3"]
    streamed = [*base_llm._stream_query_response(chunks)]
    assert streamed == chunks
def test_stream_chat_response(base_llm):
    """Check that _stream_chat_response yields the given chunks unchanged and in order."""
    chunks = ["Chunk1", "Chunk2", "Chunk3"]
    streamed = [*base_llm._stream_chat_response(chunks)]
    assert streamed == chunks
def test_append_search_and_context(base_llm):
    """Check the text produced when a web search result is appended to the context."""
    combined = base_llm._append_search_and_context("Context", "Web Search Result")
    assert combined == "Context\nWeb Search Result: Web Search Result"
def test_access_search_and_get_results(base_llm, mocker):
    """
    Check that the patched access_search_and_get_results receives the query and its value is returned.

    The real method is replaced by a mock, so this only exercises the call path on the instance.
    """
    # patch.object already installs the mock on base_llm, so its return value is kept only to inspect the call.
    search_mock = mocker.patch.object(base_llm, "access_search_and_get_results", return_value="Search Results")

    input_query = "Test query"
    result = base_llm.access_search_and_get_results(input_query)

    assert result == "Search Results"
    # Confirm the query was forwarded to the mock exactly once and unchanged.
    search_mock.assert_called_once_with(input_query)