tests: added tests (#250)
This commit is contained in:
42
tests/chunkers/test_text.py
Normal file
42
tests/chunkers/test_text.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
# ruff: noqa: E501
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from embedchain.chunkers.text import TextChunker
|
||||||
|
|
||||||
|
|
||||||
|
class TestTextChunker(unittest.TestCase):
|
||||||
|
def test_chunks(self):
|
||||||
|
"""
|
||||||
|
Test the chunks generated by TextChunker.
|
||||||
|
# TODO: Not a very precise test.
|
||||||
|
"""
|
||||||
|
chunker_config = {
|
||||||
|
"chunk_size": 10,
|
||||||
|
"chunk_overlap": 0,
|
||||||
|
"length_function": len,
|
||||||
|
}
|
||||||
|
chunker = TextChunker(config=chunker_config)
|
||||||
|
text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
|
||||||
|
|
||||||
|
result = chunker.create_chunks(MockLoader(), text)
|
||||||
|
|
||||||
|
documents = result["documents"]
|
||||||
|
|
||||||
|
self.assertGreaterEqual(len(documents), 5)
|
||||||
|
|
||||||
|
# Additional test cases can be added to cover different scenarios
|
||||||
|
|
||||||
|
|
||||||
|
class MockLoader:
|
||||||
|
def load_data(self, src):
|
||||||
|
"""
|
||||||
|
Mock loader that returns a list of data dictionaries.
|
||||||
|
Adjust this method to return different data for testing.
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"content": src,
|
||||||
|
"meta_data": {"url": "none"},
|
||||||
|
}
|
||||||
|
]
|
||||||
26
tests/embedchain/test_add.py
Normal file
26
tests/embedchain/test_add.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from embedchain import App
|
||||||
|
|
||||||
|
|
||||||
|
class TestApp(unittest.TestCase):
|
||||||
|
os.environ["OPENAI_API_KEY"] = "test_key"
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.app = App()
|
||||||
|
|
||||||
|
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||||
|
def test_add(self):
|
||||||
|
"""
|
||||||
|
This test checks the functionality of the 'add' method in the App class.
|
||||||
|
It begins by simulating the addition of a web page with a specific URL to the application instance.
|
||||||
|
The 'add' method is expected to append the input type and URL to the 'user_asks' attribute of the App instance.
|
||||||
|
By asserting that 'user_asks' is updated correctly after the 'add' method is called, we can confirm that the
|
||||||
|
method is working as intended.
|
||||||
|
The Collection.add method from the chromadb library is mocked during this test to isolate the behavior of the
|
||||||
|
'add' method.
|
||||||
|
"""
|
||||||
|
self.app.add("web_page", "https://example.com", {"meta": "meta-data"})
|
||||||
|
self.assertEqual(self.app.user_asks, [["web_page", "https://example.com", {"meta": "meta-data"}]])
|
||||||
48
tests/embedchain/test_chat.py
Normal file
48
tests/embedchain/test_chat.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from embedchain import App
|
||||||
|
|
||||||
|
|
||||||
|
class TestApp(unittest.TestCase):
|
||||||
|
os.environ["OPENAI_API_KEY"] = "test_key"
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.app = App()
|
||||||
|
|
||||||
|
@patch("embedchain.embedchain.memory", autospec=True)
|
||||||
|
@patch.object(App, "retrieve_from_database", return_value=["Test context"])
|
||||||
|
@patch.object(App, "get_answer_from_llm", return_value="Test answer")
|
||||||
|
def test_chat_with_memory(self, mock_answer, mock_retrieve, mock_memory):
|
||||||
|
"""
|
||||||
|
This test checks the functionality of the 'chat' method in the App class with respect to the chat history
|
||||||
|
memory.
|
||||||
|
The 'chat' method is called twice. The first call initializes the chat history memory.
|
||||||
|
The second call is expected to use the chat history from the first call.
|
||||||
|
|
||||||
|
Key assumptions tested:
|
||||||
|
- After the first call, 'memory.chat_memory.add_user_message' and 'memory.chat_memory.add_ai_message' are
|
||||||
|
called with correct arguments, adding the correct chat history.
|
||||||
|
- During the second call, the 'chat' method uses the chat history from the first call.
|
||||||
|
|
||||||
|
The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database', 'get_answer_from_llm' and
|
||||||
|
'memory' methods.
|
||||||
|
"""
|
||||||
|
mock_memory.load_memory_variables.return_value = {"history": []}
|
||||||
|
app = App()
|
||||||
|
|
||||||
|
# First call to chat
|
||||||
|
first_answer = app.chat("Test query 1")
|
||||||
|
self.assertEqual(first_answer, "Test answer")
|
||||||
|
mock_memory.chat_memory.add_user_message.assert_called_once_with("Test query 1")
|
||||||
|
mock_memory.chat_memory.add_ai_message.assert_called_once_with("Test answer")
|
||||||
|
|
||||||
|
mock_memory.chat_memory.add_user_message.reset_mock()
|
||||||
|
mock_memory.chat_memory.add_ai_message.reset_mock()
|
||||||
|
|
||||||
|
# Second call to chat
|
||||||
|
second_answer = app.chat("Test query 2")
|
||||||
|
self.assertEqual(second_answer, "Test answer")
|
||||||
|
mock_memory.chat_memory.add_user_message.assert_called_once_with("Test query 2")
|
||||||
|
mock_memory.chat_memory.add_ai_message.assert_called_once_with("Test answer")
|
||||||
52
tests/embedchain/test_dryrun.py
Normal file
52
tests/embedchain/test_dryrun.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from string import Template
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from embedchain import App
|
||||||
|
from embedchain.embedchain import QueryConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TestApp(unittest.TestCase):
|
||||||
|
os.environ["OPENAI_API_KEY"] = "test_key"
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.app = App()
|
||||||
|
|
||||||
|
@patch("logging.info")
|
||||||
|
def test_query_logs_same_prompt_as_dry_run(self, mock_logging_info):
|
||||||
|
"""
|
||||||
|
Test that the 'query' method logs the same prompt as the 'dry_run' method.
|
||||||
|
This is the only way I found to test the prompt in query, that's not returned.
|
||||||
|
"""
|
||||||
|
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||||
|
mock_retrieve.return_value = ["Test context"]
|
||||||
|
input_query = "Test query"
|
||||||
|
config = QueryConfig(
|
||||||
|
number_documents=3,
|
||||||
|
template=Template("Question: $query, context: $context, history: $history"),
|
||||||
|
history=["Past context 1", "Past context 2"],
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(self.app, "get_answer_from_llm"):
|
||||||
|
self.app.dry_run(input_query, config)
|
||||||
|
self.app.query(input_query, config)
|
||||||
|
|
||||||
|
# Access the log messages captured during the execution
|
||||||
|
logged_messages = [call[0][0] for call in mock_logging_info.call_args_list]
|
||||||
|
|
||||||
|
# Extract the prompts from the log messages
|
||||||
|
dry_run_prompt = self.extract_prompt(logged_messages[0])
|
||||||
|
query_prompt = self.extract_prompt(logged_messages[1])
|
||||||
|
|
||||||
|
# Perform assertions on the prompts
|
||||||
|
self.assertEqual(dry_run_prompt, query_prompt)
|
||||||
|
|
||||||
|
def extract_prompt(self, log_message):
|
||||||
|
"""
|
||||||
|
Extracts the prompt value from the log message.
|
||||||
|
Adjust this method based on the log message format in your implementation.
|
||||||
|
"""
|
||||||
|
# Modify this logic based on your log message format
|
||||||
|
prefix = "Prompt: "
|
||||||
|
return log_message.split(prefix, 1)[1]
|
||||||
39
tests/embedchain/test_embedchain.py
Normal file
39
tests/embedchain/test_embedchain.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from embedchain import App
|
||||||
|
from embedchain.config import InitConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TestChromaDbHostsLoglevel(unittest.TestCase):
|
||||||
|
os.environ["OPENAI_API_KEY"] = "test_key"
|
||||||
|
|
||||||
|
@patch("chromadb.api.models.Collection.Collection.add")
|
||||||
|
@patch("chromadb.api.models.Collection.Collection.get")
|
||||||
|
@patch("embedchain.embedchain.EmbedChain.retrieve_from_database")
|
||||||
|
@patch("embedchain.embedchain.EmbedChain.get_answer_from_llm")
|
||||||
|
@patch("embedchain.embedchain.EmbedChain.get_llm_model_answer")
|
||||||
|
def test_whole_app(
|
||||||
|
self,
|
||||||
|
_mock_get,
|
||||||
|
_mock_add,
|
||||||
|
_mock_ec_retrieve_from_database,
|
||||||
|
_mock_get_answer_from_llm,
|
||||||
|
mock_ec_get_llm_model_answer,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test if the `App` instance is initialized without a config that does not contain default hosts and ports.
|
||||||
|
"""
|
||||||
|
config = InitConfig(log_level="DEBUG")
|
||||||
|
|
||||||
|
app = App(config)
|
||||||
|
|
||||||
|
knowledge = "lorem ipsum dolor sit amet, consectetur adipiscing"
|
||||||
|
|
||||||
|
app.add_local("text", knowledge)
|
||||||
|
|
||||||
|
app.query("What text did I give you?")
|
||||||
|
app.chat("What text did I give you?")
|
||||||
|
|
||||||
|
self.assertEqual(mock_ec_get_llm_model_answer.call_args[1]["documents"], [knowledge])
|
||||||
66
tests/embedchain/test_generate_prompt.py
Normal file
66
tests/embedchain/test_generate_prompt.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
import unittest
|
||||||
|
from string import Template
|
||||||
|
|
||||||
|
from embedchain import App
|
||||||
|
from embedchain.embedchain import QueryConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TestGeneratePrompt(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.app = App()
|
||||||
|
|
||||||
|
def test_generate_prompt_with_template(self):
|
||||||
|
"""
|
||||||
|
Tests that the generate_prompt method correctly formats the prompt using
|
||||||
|
a custom template provided in the QueryConfig instance.
|
||||||
|
|
||||||
|
This test sets up a scenario with an input query and a list of contexts,
|
||||||
|
and a custom template, and then calls generate_prompt. It checks that the
|
||||||
|
returned prompt correctly incorporates all the contexts and the query into
|
||||||
|
the format specified by the template.
|
||||||
|
"""
|
||||||
|
# Setup
|
||||||
|
input_query = "Test query"
|
||||||
|
contexts = ["Context 1", "Context 2", "Context 3"]
|
||||||
|
template = "You are a bot. Context: ${context} - Query: ${query} - Helpful answer:"
|
||||||
|
config = QueryConfig(template=Template(template))
|
||||||
|
|
||||||
|
# Execute
|
||||||
|
result = self.app.generate_prompt(input_query, contexts, config)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
expected_result = (
|
||||||
|
"You are a bot. Context: Context 1 | Context 2 | Context 3 - Query: Test query - Helpful answer:"
|
||||||
|
)
|
||||||
|
self.assertEqual(result, expected_result)
|
||||||
|
|
||||||
|
def test_generate_prompt_with_contexts_list(self):
|
||||||
|
"""
|
||||||
|
Tests that the generate_prompt method correctly handles a list of contexts.
|
||||||
|
|
||||||
|
This test sets up a scenario with an input query and a list of contexts,
|
||||||
|
and then calls generate_prompt. It checks that the returned prompt
|
||||||
|
correctly includes all the contexts and the query.
|
||||||
|
"""
|
||||||
|
# Setup
|
||||||
|
input_query = "Test query"
|
||||||
|
contexts = ["Context 1", "Context 2", "Context 3"]
|
||||||
|
config = QueryConfig()
|
||||||
|
|
||||||
|
# Execute
|
||||||
|
result = self.app.generate_prompt(input_query, contexts, config)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
expected_result = config.template.substitute(context="Context 1 | Context 2 | Context 3", query=input_query)
|
||||||
|
self.assertEqual(result, expected_result)
|
||||||
|
|
||||||
|
def test_generate_prompt_with_history(self):
|
||||||
|
"""
|
||||||
|
Test the 'generate_prompt' method with QueryConfig containing a history attribute.
|
||||||
|
"""
|
||||||
|
config = QueryConfig(history=["Past context 1", "Past context 2"])
|
||||||
|
config.template = Template("Context: $context | Query: $query | History: $history")
|
||||||
|
prompt = self.app.generate_prompt("Test query", ["Test context"], config)
|
||||||
|
|
||||||
|
expected_prompt = "Context: Test context | Query: Test query | History: ['Past context 1', 'Past context 2']"
|
||||||
|
self.assertEqual(prompt, expected_prompt)
|
||||||
43
tests/embedchain/test_query.py
Normal file
43
tests/embedchain/test_query.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from embedchain import App
|
||||||
|
from embedchain.embedchain import QueryConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TestApp(unittest.TestCase):
|
||||||
|
os.environ["OPENAI_API_KEY"] = "test_key"
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.app = App()
|
||||||
|
|
||||||
|
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||||
|
def test_query(self):
|
||||||
|
"""
|
||||||
|
This test checks the functionality of the 'query' method in the App class.
|
||||||
|
It simulates a scenario where the 'retrieve_from_database' method returns a context list and
|
||||||
|
'get_llm_model_answer' returns an expected answer string.
|
||||||
|
|
||||||
|
The 'query' method is expected to call 'retrieve_from_database' and 'get_llm_model_answer' methods
|
||||||
|
appropriately and return the right answer.
|
||||||
|
|
||||||
|
Key assumptions tested:
|
||||||
|
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
|
||||||
|
QueryConfig.
|
||||||
|
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
|
||||||
|
- 'query' method returns the value it received from 'get_llm_model_answer'.
|
||||||
|
|
||||||
|
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
|
||||||
|
'get_llm_model_answer' methods.
|
||||||
|
"""
|
||||||
|
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||||
|
mock_retrieve.return_value = ["Test context"]
|
||||||
|
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
|
||||||
|
mock_answer.return_value = "Test answer"
|
||||||
|
answer = self.app.query("Test query")
|
||||||
|
|
||||||
|
self.assertEqual(answer, "Test answer")
|
||||||
|
self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
|
||||||
|
self.assertIsInstance(mock_retrieve.call_args[0][1], QueryConfig)
|
||||||
|
mock_answer.assert_called_once()
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
import os
|
|
||||||
import unittest
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
from embedchain import App
|
|
||||||
|
|
||||||
|
|
||||||
class TestApp(unittest.TestCase):
|
|
||||||
os.environ["OPENAI_API_KEY"] = "test_key"
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.app = App()
|
|
||||||
|
|
||||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
|
||||||
def test_add(self):
|
|
||||||
self.app.add("web_page", "https://example.com")
|
|
||||||
self.assertEqual(self.app.user_asks, [["web_page", "https://example.com"]])
|
|
||||||
|
|
||||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
|
||||||
def test_query(self):
|
|
||||||
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
|
||||||
mock_retrieve.return_value = "Test context"
|
|
||||||
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
|
|
||||||
mock_answer.return_value = "Test answer"
|
|
||||||
answer = self.app.query("Test query")
|
|
||||||
|
|
||||||
self.assertEqual(answer, "Test answer")
|
|
||||||
mock_retrieve.assert_called_once_with("Test query")
|
|
||||||
mock_answer.assert_called_once()
|
|
||||||
72
tests/vectordb/test_chroma_db.py
Normal file
72
tests/vectordb/test_chroma_db.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
# ruff: noqa: E501
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from embedchain import App
|
||||||
|
from embedchain.config import InitConfig
|
||||||
|
from embedchain.vectordb.chroma_db import ChromaDB, chromadb
|
||||||
|
|
||||||
|
|
||||||
|
class TestChromaDbHosts(unittest.TestCase):
|
||||||
|
def test_init_with_host_and_port(self):
|
||||||
|
"""
|
||||||
|
Test if the `ChromaDB` instance is initialized with the correct host and port values.
|
||||||
|
"""
|
||||||
|
host = "test-host"
|
||||||
|
port = "1234"
|
||||||
|
|
||||||
|
with patch.object(chromadb, "Client") as mock_client:
|
||||||
|
_db = ChromaDB(host=host, port=port)
|
||||||
|
|
||||||
|
expected_settings = chromadb.config.Settings(
|
||||||
|
chroma_api_impl="rest",
|
||||||
|
chroma_server_host=host,
|
||||||
|
chroma_server_http_port=port,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_client.assert_called_once_with(expected_settings)
|
||||||
|
|
||||||
|
|
||||||
|
class TestChromaDbHostsInit(unittest.TestCase):
|
||||||
|
@patch("embedchain.vectordb.chroma_db.chromadb.Client")
|
||||||
|
def test_init_with_host_and_port(self, mock_client):
|
||||||
|
"""
|
||||||
|
Test if the `App` instance is initialized with the correct host and port values.
|
||||||
|
"""
|
||||||
|
host = "test-host"
|
||||||
|
port = "1234"
|
||||||
|
|
||||||
|
config = InitConfig(host=host, port=port)
|
||||||
|
|
||||||
|
_app = App(config)
|
||||||
|
|
||||||
|
self.assertEqual(mock_client.call_args[0][0].chroma_server_host, host)
|
||||||
|
self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, port)
|
||||||
|
|
||||||
|
|
||||||
|
class TestChromaDbHostsNone(unittest.TestCase):
|
||||||
|
@patch("embedchain.vectordb.chroma_db.chromadb.Client")
|
||||||
|
def test_init_with_host_and_port(self, mock_client):
|
||||||
|
"""
|
||||||
|
Test if the `App` instance is initialized without default hosts and ports.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_app = App()
|
||||||
|
|
||||||
|
self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
|
||||||
|
self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)
|
||||||
|
|
||||||
|
|
||||||
|
class TestChromaDbHostsLoglevel(unittest.TestCase):
|
||||||
|
@patch("embedchain.vectordb.chroma_db.chromadb.Client")
|
||||||
|
def test_init_with_host_and_port(self, mock_client):
|
||||||
|
"""
|
||||||
|
Test if the `App` instance is initialized without a config that does not contain default hosts and ports.
|
||||||
|
"""
|
||||||
|
config = InitConfig(log_level="DEBUG")
|
||||||
|
|
||||||
|
_app = App(config)
|
||||||
|
|
||||||
|
self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
|
||||||
|
self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)
|
||||||
Reference in New Issue
Block a user