tests: added tests (#250)

This commit is contained in:
cachho
2023-07-16 02:28:51 +02:00
committed by GitHub
parent d12aeec1ff
commit 3f71050c47
9 changed files with 388 additions and 29 deletions

View File

@@ -0,0 +1,42 @@
# ruff: noqa: E501
import unittest
from embedchain.chunkers.text import TextChunker
class TestTextChunker(unittest.TestCase):
def test_chunks(self):
"""
Test the chunks generated by TextChunker.
# TODO: Not a very precise test.
"""
chunker_config = {
"chunk_size": 10,
"chunk_overlap": 0,
"length_function": len,
}
chunker = TextChunker(config=chunker_config)
text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
result = chunker.create_chunks(MockLoader(), text)
documents = result["documents"]
self.assertGreaterEqual(len(documents), 5)
# Additional test cases can be added to cover different scenarios
class MockLoader:
def load_data(self, src):
"""
Mock loader that returns a list of data dictionaries.
Adjust this method to return different data for testing.
"""
return [
{
"content": src,
"meta_data": {"url": "none"},
}
]

View File

@@ -0,0 +1,26 @@
import os
import unittest
from unittest.mock import MagicMock, patch
from embedchain import App
class TestApp(unittest.TestCase):
os.environ["OPENAI_API_KEY"] = "test_key"
def setUp(self):
self.app = App()
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_add(self):
"""
This test checks the functionality of the 'add' method in the App class.
It begins by simulating the addition of a web page with a specific URL to the application instance.
The 'add' method is expected to append the input type and URL to the 'user_asks' attribute of the App instance.
By asserting that 'user_asks' is updated correctly after the 'add' method is called, we can confirm that the
method is working as intended.
The Collection.add method from the chromadb library is mocked during this test to isolate the behavior of the
'add' method.
"""
self.app.add("web_page", "https://example.com", {"meta": "meta-data"})
self.assertEqual(self.app.user_asks, [["web_page", "https://example.com", {"meta": "meta-data"}]])

View File

@@ -0,0 +1,48 @@
import os
import unittest
from unittest.mock import patch
from embedchain import App
class TestApp(unittest.TestCase):
os.environ["OPENAI_API_KEY"] = "test_key"
def setUp(self):
self.app = App()
@patch("embedchain.embedchain.memory", autospec=True)
@patch.object(App, "retrieve_from_database", return_value=["Test context"])
@patch.object(App, "get_answer_from_llm", return_value="Test answer")
def test_chat_with_memory(self, mock_answer, mock_retrieve, mock_memory):
"""
This test checks the functionality of the 'chat' method in the App class with respect to the chat history
memory.
The 'chat' method is called twice. The first call initializes the chat history memory.
The second call is expected to use the chat history from the first call.
Key assumptions tested:
- After the first call, 'memory.chat_memory.add_user_message' and 'memory.chat_memory.add_ai_message' are
called with correct arguments, adding the correct chat history.
- During the second call, the 'chat' method uses the chat history from the first call.
The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database', 'get_answer_from_llm' and
'memory' methods.
"""
mock_memory.load_memory_variables.return_value = {"history": []}
app = App()
# First call to chat
first_answer = app.chat("Test query 1")
self.assertEqual(first_answer, "Test answer")
mock_memory.chat_memory.add_user_message.assert_called_once_with("Test query 1")
mock_memory.chat_memory.add_ai_message.assert_called_once_with("Test answer")
mock_memory.chat_memory.add_user_message.reset_mock()
mock_memory.chat_memory.add_ai_message.reset_mock()
# Second call to chat
second_answer = app.chat("Test query 2")
self.assertEqual(second_answer, "Test answer")
mock_memory.chat_memory.add_user_message.assert_called_once_with("Test query 2")
mock_memory.chat_memory.add_ai_message.assert_called_once_with("Test answer")

View File

@@ -0,0 +1,52 @@
import os
import unittest
from string import Template
from unittest.mock import patch
from embedchain import App
from embedchain.embedchain import QueryConfig
class TestApp(unittest.TestCase):
os.environ["OPENAI_API_KEY"] = "test_key"
def setUp(self):
self.app = App()
@patch("logging.info")
def test_query_logs_same_prompt_as_dry_run(self, mock_logging_info):
"""
Test that the 'query' method logs the same prompt as the 'dry_run' method.
This is the only way I found to test the prompt in query, that's not returned.
"""
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
input_query = "Test query"
config = QueryConfig(
number_documents=3,
template=Template("Question: $query, context: $context, history: $history"),
history=["Past context 1", "Past context 2"],
)
with patch.object(self.app, "get_answer_from_llm"):
self.app.dry_run(input_query, config)
self.app.query(input_query, config)
# Access the log messages captured during the execution
logged_messages = [call[0][0] for call in mock_logging_info.call_args_list]
# Extract the prompts from the log messages
dry_run_prompt = self.extract_prompt(logged_messages[0])
query_prompt = self.extract_prompt(logged_messages[1])
# Perform assertions on the prompts
self.assertEqual(dry_run_prompt, query_prompt)
def extract_prompt(self, log_message):
"""
Extracts the prompt value from the log message.
Adjust this method based on the log message format in your implementation.
"""
# Modify this logic based on your log message format
prefix = "Prompt: "
return log_message.split(prefix, 1)[1]

View File

@@ -0,0 +1,39 @@
import os
import unittest
from unittest.mock import patch
from embedchain import App
from embedchain.config import InitConfig
class TestChromaDbHostsLoglevel(unittest.TestCase):
os.environ["OPENAI_API_KEY"] = "test_key"
@patch("chromadb.api.models.Collection.Collection.add")
@patch("chromadb.api.models.Collection.Collection.get")
@patch("embedchain.embedchain.EmbedChain.retrieve_from_database")
@patch("embedchain.embedchain.EmbedChain.get_answer_from_llm")
@patch("embedchain.embedchain.EmbedChain.get_llm_model_answer")
def test_whole_app(
self,
_mock_get,
_mock_add,
_mock_ec_retrieve_from_database,
_mock_get_answer_from_llm,
mock_ec_get_llm_model_answer,
):
"""
Test if the `App` instance is initialized without a config that does not contain default hosts and ports.
"""
config = InitConfig(log_level="DEBUG")
app = App(config)
knowledge = "lorem ipsum dolor sit amet, consectetur adipiscing"
app.add_local("text", knowledge)
app.query("What text did I give you?")
app.chat("What text did I give you?")
self.assertEqual(mock_ec_get_llm_model_answer.call_args[1]["documents"], [knowledge])

View File

@@ -0,0 +1,66 @@
import unittest
from string import Template
from embedchain import App
from embedchain.embedchain import QueryConfig
class TestGeneratePrompt(unittest.TestCase):
def setUp(self):
self.app = App()
def test_generate_prompt_with_template(self):
"""
Tests that the generate_prompt method correctly formats the prompt using
a custom template provided in the QueryConfig instance.
This test sets up a scenario with an input query and a list of contexts,
and a custom template, and then calls generate_prompt. It checks that the
returned prompt correctly incorporates all the contexts and the query into
the format specified by the template.
"""
# Setup
input_query = "Test query"
contexts = ["Context 1", "Context 2", "Context 3"]
template = "You are a bot. Context: ${context} - Query: ${query} - Helpful answer:"
config = QueryConfig(template=Template(template))
# Execute
result = self.app.generate_prompt(input_query, contexts, config)
# Assert
expected_result = (
"You are a bot. Context: Context 1 | Context 2 | Context 3 - Query: Test query - Helpful answer:"
)
self.assertEqual(result, expected_result)
def test_generate_prompt_with_contexts_list(self):
"""
Tests that the generate_prompt method correctly handles a list of contexts.
This test sets up a scenario with an input query and a list of contexts,
and then calls generate_prompt. It checks that the returned prompt
correctly includes all the contexts and the query.
"""
# Setup
input_query = "Test query"
contexts = ["Context 1", "Context 2", "Context 3"]
config = QueryConfig()
# Execute
result = self.app.generate_prompt(input_query, contexts, config)
# Assert
expected_result = config.template.substitute(context="Context 1 | Context 2 | Context 3", query=input_query)
self.assertEqual(result, expected_result)
def test_generate_prompt_with_history(self):
"""
Test the 'generate_prompt' method with QueryConfig containing a history attribute.
"""
config = QueryConfig(history=["Past context 1", "Past context 2"])
config.template = Template("Context: $context | Query: $query | History: $history")
prompt = self.app.generate_prompt("Test query", ["Test context"], config)
expected_prompt = "Context: Test context | Query: Test query | History: ['Past context 1', 'Past context 2']"
self.assertEqual(prompt, expected_prompt)

View File

@@ -0,0 +1,43 @@
import os
import unittest
from unittest.mock import MagicMock, patch
from embedchain import App
from embedchain.embedchain import QueryConfig
class TestApp(unittest.TestCase):
os.environ["OPENAI_API_KEY"] = "test_key"
def setUp(self):
self.app = App()
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query(self):
"""
This test checks the functionality of the 'query' method in the App class.
It simulates a scenario where the 'retrieve_from_database' method returns a context list and
'get_llm_model_answer' returns an expected answer string.
The 'query' method is expected to call 'retrieve_from_database' and 'get_llm_model_answer' methods
appropriately and return the right answer.
Key assumptions tested:
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
QueryConfig.
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
- 'query' method returns the value it received from 'get_llm_model_answer'.
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
'get_llm_model_answer' methods.
"""
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
answer = self.app.query("Test query")
self.assertEqual(answer, "Test answer")
self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
self.assertIsInstance(mock_retrieve.call_args[0][1], QueryConfig)
mock_answer.assert_called_once()

View File

@@ -1,29 +0,0 @@
import os
import unittest
from unittest.mock import MagicMock, patch
from embedchain import App
class TestApp(unittest.TestCase):
os.environ["OPENAI_API_KEY"] = "test_key"
def setUp(self):
self.app = App()
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_add(self):
self.app.add("web_page", "https://example.com")
self.assertEqual(self.app.user_asks, [["web_page", "https://example.com"]])
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query(self):
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = "Test context"
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
answer = self.app.query("Test query")
self.assertEqual(answer, "Test answer")
mock_retrieve.assert_called_once_with("Test query")
mock_answer.assert_called_once()

View File

@@ -0,0 +1,72 @@
# ruff: noqa: E501
import unittest
from unittest.mock import patch
from embedchain import App
from embedchain.config import InitConfig
from embedchain.vectordb.chroma_db import ChromaDB, chromadb
class TestChromaDbHosts(unittest.TestCase):
def test_init_with_host_and_port(self):
"""
Test if the `ChromaDB` instance is initialized with the correct host and port values.
"""
host = "test-host"
port = "1234"
with patch.object(chromadb, "Client") as mock_client:
_db = ChromaDB(host=host, port=port)
expected_settings = chromadb.config.Settings(
chroma_api_impl="rest",
chroma_server_host=host,
chroma_server_http_port=port,
)
mock_client.assert_called_once_with(expected_settings)
class TestChromaDbHostsInit(unittest.TestCase):
@patch("embedchain.vectordb.chroma_db.chromadb.Client")
def test_init_with_host_and_port(self, mock_client):
"""
Test if the `App` instance is initialized with the correct host and port values.
"""
host = "test-host"
port = "1234"
config = InitConfig(host=host, port=port)
_app = App(config)
self.assertEqual(mock_client.call_args[0][0].chroma_server_host, host)
self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, port)
class TestChromaDbHostsNone(unittest.TestCase):
@patch("embedchain.vectordb.chroma_db.chromadb.Client")
def test_init_with_host_and_port(self, mock_client):
"""
Test if the `App` instance is initialized without default hosts and ports.
"""
_app = App()
self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)
class TestChromaDbHostsLoglevel(unittest.TestCase):
@patch("embedchain.vectordb.chroma_db.chromadb.Client")
def test_init_with_host_and_port(self, mock_client):
"""
Test if the `App` instance is initialized without a config that does not contain default hosts and ports.
"""
config = InitConfig(log_level="DEBUG")
_app = App(config)
self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)