refactor: classes and configs (#528)

This commit is contained in:
cachho
2023-09-05 10:12:58 +02:00
committed by GitHub
parent 387b042a49
commit 344e7470f6
50 changed files with 1221 additions and 997 deletions

125
tests/llm/test_chat.py Normal file
View File

@@ -0,0 +1,125 @@
import os
import unittest
from unittest.mock import patch, MagicMock
from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig
from embedchain.llm.base_llm import BaseLlm
class TestApp(unittest.TestCase):
    def setUp(self):
        os.environ["OPENAI_API_KEY"] = "test_key"
        self.app = App(config=AppConfig(collect_metrics=False))

    @patch.object(App, "retrieve_from_database", return_value=["Test context"])
    @patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
    def test_chat_with_memory(self, mock_get_answer, mock_retrieve):
        """
        Check that the chat history memory grows across successive 'chat' calls.

        The first call initializes the memory; the second call must reuse and
        extend it.  Each round trip is expected to add two entries (one user
        message, one AI message) to both the langchain chat memory and the
        newline-separated history string.  'retrieve_from_database' and
        'get_answer_from_llm' are mocked so only memory handling is exercised.
        """
        chat_app = App(config=AppConfig(collect_metrics=False))

        self.assertEqual(chat_app.chat("Test query 1"), "Test answer")
        self.assertEqual(len(chat_app.llm.memory.chat_memory.messages), 2)
        self.assertEqual(len(chat_app.llm.history.splitlines()), 2)

        self.assertEqual(chat_app.chat("Test query 2"), "Test answer")
        self.assertEqual(len(chat_app.llm.memory.chat_memory.messages), 4)
        self.assertEqual(len(chat_app.llm.history.splitlines()), 4)

    @patch.object(App, "retrieve_from_database", return_value=["Test context"])
    @patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
    def test_template_replacement(self, mock_get_answer, mock_retrieve):
        """
        Check that when the default template lacks a history placeholder, a
        history-capable template is swapped in, and that a dry run leaves the
        stored history unchanged.
        """
        chat_app = App(config=AppConfig(collect_metrics=False))

        self.assertEqual(chat_app.chat("Test query 1"), "Test answer")
        self.assertEqual(len(chat_app.llm.history.splitlines()), 2)
        history_before = chat_app.llm.history

        dry_run_output = chat_app.chat("Test query 2", dry_run=True)
        self.assertIn("History:", dry_run_output)
        # A dry run must not mutate the conversation history.
        self.assertEqual(history_before, chat_app.llm.history)
        self.assertEqual(len(chat_app.llm.history.splitlines()), 2)

    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
    def test_chat_with_where_in_params(self):
        """
        Check that a 'where' filter passed as a keyword argument to 'chat' is
        forwarded to 'retrieve_from_database', and that the mocked LLM answer
        is returned unchanged.
        """
        with patch.object(self.app, "retrieve_from_database") as mock_retrieve, patch.object(
            self.app.llm, "get_llm_model_answer"
        ) as mock_answer:
            mock_retrieve.return_value = ["Test context"]
            mock_answer.return_value = "Test answer"

            result = self.app.chat("Test query", where={"attribute": "value"})

        self.assertEqual(result, "Test answer")
        _args, kwargs = mock_retrieve.call_args
        self.assertEqual(kwargs.get("input_query"), "Test query")
        self.assertEqual(kwargs.get("where"), {"attribute": "value"})
        mock_answer.assert_called_once()

    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
    def test_chat_with_where_in_chat_config(self):
        """
        Check that a 'where' filter carried inside a BaseLlmConfig passed to
        'chat' reaches the database 'query' call, and that the mocked LLM
        answer is returned unchanged.
        """
        with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer, patch.object(
            self.app.db, "query"
        ) as mock_database_query:
            mock_answer.return_value = "Test answer"
            mock_database_query.return_value = ["Test context"]

            llm_config = BaseLlmConfig(where={"attribute": "value"})
            result = self.app.chat("Test query", llm_config)

        self.assertEqual(result, "Test answer")
        _args, kwargs = mock_database_query.call_args
        self.assertEqual(kwargs.get("input_query"), "Test query")
        self.assertEqual(kwargs.get("where"), {"attribute": "value"})
        mock_answer.assert_called_once()

tests/llm/test_generate_prompt.py Normal file
View File

@@ -0,0 +1,70 @@
import unittest
from string import Template
from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig
class TestGeneratePrompt(unittest.TestCase):
    def setUp(self):
        self.app = App(config=AppConfig(collect_metrics=False))

    def test_generate_prompt_with_template(self):
        """
        A custom template supplied through BaseLlmConfig must drive prompt
        construction: the contexts (joined by ' | ') and the query are
        substituted into the template's placeholders.
        """
        query = "Test query"
        docs = ["Context 1", "Context 2", "Context 3"]
        custom = Template("You are a bot. Context: ${context} - Query: ${query} - Helpful answer:")
        self.app.llm.config = BaseLlmConfig(template=custom)

        prompt = self.app.llm.generate_prompt(query, docs)

        self.assertEqual(
            prompt,
            "You are a bot. Context: Context 1 | Context 2 | Context 3 - Query: Test query - Helpful answer:",
        )

    def test_generate_prompt_with_contexts_list(self):
        """
        With the default config, the generated prompt must match what the
        default template itself produces for the joined contexts and query.
        """
        query = "Test query"
        docs = ["Context 1", "Context 2", "Context 3"]
        default_config = BaseLlmConfig()
        self.app.llm.config = default_config

        prompt = self.app.llm.generate_prompt(query, docs)

        expected = default_config.template.substitute(
            context="Context 1 | Context 2 | Context 3", query=query
        )
        self.assertEqual(prompt, expected)

    def test_generate_prompt_with_history(self):
        """
        A template containing a $history placeholder must receive the history
        previously stored on the LLM via 'set_history'.
        """
        history_config = BaseLlmConfig()
        history_config.template = Template("Context: $context | Query: $query | History: $history")
        self.app.llm.config = history_config
        self.app.llm.set_history(["Past context 1", "Past context 2"])

        prompt = self.app.llm.generate_prompt("Test query", ["Test context"])

        self.assertEqual(
            prompt,
            "Context: Test context | Query: Test query | History: ['Past context 1', 'Past context 2']",
        )

147
tests/llm/test_query.py Normal file
View File

@@ -0,0 +1,147 @@
import os
import unittest
from unittest.mock import MagicMock, patch
from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig
class TestApp(unittest.TestCase):
    # Runs once at class-definition time so the key exists before any App
    # instance is constructed in setUp.
    os.environ["OPENAI_API_KEY"] = "test_key"

    def setUp(self):
        self.app = App(config=AppConfig(collect_metrics=False))

    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
    def test_query(self):
        """
        'query' must call 'retrieve_from_database' exactly once with the input
        query and return the value produced by 'get_llm_model_answer'.  Both
        collaborators are mocked to isolate the dispatch logic.
        """
        with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
            mock_retrieve.return_value = ["Test context"]
            with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
                mock_answer.return_value = "Test answer"
                _answer = self.app.query(input_query="Test query")

                # Ensure retrieve_from_database was called with the right query.
                mock_retrieve.assert_called_once()
                args, kwargs = mock_retrieve.call_args
                self.assertEqual(kwargs.get("input_query"), "Test query")
                mock_answer.assert_called_once()

    @patch("openai.ChatCompletion.create")
    def test_query_config_app_passing(self, mock_create):
        """
        A system prompt set through a BaseLlmConfig passed to App must be sent
        as the leading 'system' message to the OpenAI chat API, followed by
        the user query.

        BUG FIX: the original role checks used
        ``assertTrue(messages_arg[i].get("role"), "system")``, which only
        tests truthiness (the second argument is the failure *message*, not
        an expected value).  They are now genuine equality assertions.
        """
        mock_create.return_value = {"choices": [{"message": {"content": "response"}}]}  # Mock response
        config = AppConfig(collect_metrics=False)
        chat_config = BaseLlmConfig(system_prompt="Test system prompt")
        app = App(config=config, llm_config=chat_config)
        app.llm.get_llm_model_answer("Test query")

        # Check that 'create' was called with the correct 'messages' argument.
        messages_arg = mock_create.call_args.kwargs["messages"]
        self.assertEqual(messages_arg[0].get("role"), "system")
        self.assertEqual(messages_arg[0].get("content"), "Test system prompt")
        self.assertEqual(messages_arg[1].get("role"), "user")
        self.assertEqual(messages_arg[1].get("content"), "Test query")
        # TODO: Add tests for other config variables

    @patch("openai.ChatCompletion.create")
    def test_app_passing(self, mock_create):
        """
        A system prompt passed directly to the App constructor must land on
        the LLM config and be forwarded as the 'system' message.

        BUG FIX: same assertTrue-with-message misuse as in
        test_query_config_app_passing; replaced with assertEqual.
        """
        mock_create.return_value = {"choices": [{"message": {"content": "response"}}]}  # Mock response
        config = AppConfig(collect_metrics=False)
        chat_config = BaseLlmConfig()
        app = App(config=config, llm_config=chat_config, system_prompt="Test system prompt")
        self.assertEqual(app.llm.config.system_prompt, "Test system prompt")
        app.llm.get_llm_model_answer("Test query")

        # Check that 'create' was called with the correct 'messages' argument.
        messages_arg = mock_create.call_args.kwargs["messages"]
        self.assertEqual(messages_arg[0].get("role"), "system")
        self.assertEqual(messages_arg[0].get("content"), "Test system prompt")

    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
    def test_query_with_where_in_params(self):
        """
        A 'where' filter passed as a keyword argument to 'query' must be
        forwarded to 'retrieve_from_database', and the mocked LLM answer must
        be returned unchanged.
        """
        with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
            mock_retrieve.return_value = ["Test context"]
            with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
                mock_answer.return_value = "Test answer"
                answer = self.app.query("Test query", where={"attribute": "value"})

                self.assertEqual(answer, "Test answer")
                _args, kwargs = mock_retrieve.call_args
                self.assertEqual(kwargs.get("input_query"), "Test query")
                self.assertEqual(kwargs.get("where"), {"attribute": "value"})
                mock_answer.assert_called_once()

    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
    def test_query_with_where_in_query_config(self):
        """
        A 'where' filter carried inside a BaseLlmConfig passed to 'query' must
        reach the database 'query' call, and the mocked LLM answer must be
        returned unchanged.
        """
        with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
            mock_answer.return_value = "Test answer"
            with patch.object(self.app.db, "query") as mock_database_query:
                mock_database_query.return_value = ["Test context"]
                queryConfig = BaseLlmConfig(where={"attribute": "value"})
                answer = self.app.query("Test query", queryConfig)

                self.assertEqual(answer, "Test answer")
                _args, kwargs = mock_database_query.call_args
                self.assertEqual(kwargs.get("input_query"), "Test query")
                self.assertEqual(kwargs.get("where"), {"attribute": "value"})
                mock_answer.assert_called_once()