refactor: classes and configs (#528)

This commit is contained in:
cachho
2023-09-05 10:12:58 +02:00
committed by GitHub
parent 387b042a49
commit 344e7470f6
50 changed files with 1221 additions and 997 deletions

View File

@@ -1,99 +0,0 @@
import os
import unittest
from unittest.mock import MagicMock, patch
from embedchain import App
from embedchain.config import AppConfig, ChatConfig
class TestApp(unittest.TestCase):
def setUp(self):
os.environ["OPENAI_API_KEY"] = "test_key"
self.app = App(config=AppConfig(collect_metrics=False))
@patch.object(App, "retrieve_from_database", return_value=["Test context"])
@patch.object(App, "get_answer_from_llm", return_value="Test answer")
def test_chat_with_memory(self, mock_get_answer, mock_retrieve):
"""
This test checks the functionality of the 'chat' method in the App class with respect to the chat history
memory.
The 'chat' method is called twice. The first call initializes the chat history memory.
The second call is expected to use the chat history from the first call.
Key assumptions tested:
- After the first call, 'memory.chat_memory.add_user_message' and 'memory.chat_memory.add_ai_message' are
called with correct arguments, adding the correct chat history.
- During the second call, the 'chat' method uses the chat history from the first call.
The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database', 'get_answer_from_llm' and
'memory' methods.
"""
app = App()
first_answer = app.chat("Test query 1")
self.assertEqual(first_answer, "Test answer")
self.assertEqual(len(app.memory.chat_memory.messages), 2)
second_answer = app.chat("Test query 2")
self.assertEqual(second_answer, "Test answer")
self.assertEqual(len(app.memory.chat_memory.messages), 4)
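Note: the two assertions above count messages per chat turn: one user message and one AI message are appended each time, so the history holds 2 messages after the first call and 4 after the second. A minimal standalone sketch of that behaviour, assuming the memory attribute behaves like langchain's ConversationBufferMemory (an assumption based on the attribute names in the docstring; the concrete class is not shown in this diff):
# Illustration only, not part of the commit: history grows by two entries per turn.
from langchain.memory import ConversationBufferMemory

memory = ConversationBufferMemory()
memory.chat_memory.add_user_message("Test query 1")  # first turn, user side
memory.chat_memory.add_ai_message("Test answer")     # first turn, AI side
assert len(memory.chat_memory.messages) == 2

memory.chat_memory.add_user_message("Test query 2")  # second turn
memory.chat_memory.add_ai_message("Test answer")
assert len(memory.chat_memory.messages) == 4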
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_chat_with_where_in_params(self):
"""
This test checks the functionality of the 'chat' method in the App class.
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
a where filter and 'get_llm_model_answer' returns an expected answer string.
The 'chat' method is expected to call 'retrieve_from_database' with the where filter, call
'get_llm_model_answer' appropriately, and return the right answer.
Key assumptions tested:
- 'retrieve_from_database' method is called exactly once with the query "Test chat" and the where
filter {"attribute": "value"}.
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
- 'chat' method returns the value it received from 'get_llm_model_answer'.
The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database' and
'get_llm_model_answer' methods.
"""
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
answer = self.app.chat("Test chat", where={"attribute": "value"})
self.assertEqual(answer, "Test answer")
self.assertEqual(mock_retrieve.call_args[0][0], "Test chat")
self.assertEqual(mock_retrieve.call_args[0][2], {"attribute": "value"})
mock_answer.assert_called_once()
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_chat_with_where_in_chat_config(self):
"""
This test checks the functionality of the 'chat' method in the App class.
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
a where filter and 'get_llm_model_answer' returns an expected answer string.
The 'chat' method is expected to call 'retrieve_from_database' with the where filter specified
in the ChatConfig, call 'get_llm_model_answer' appropriately, and return the right answer.
Key assumptions tested:
- 'retrieve_from_database' method is called exactly once with the query "Test chat" and an instance of
ChatConfig whose where filter is {"attribute": "value"}.
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
- 'chat' method returns the value it received from 'get_llm_model_answer'.
The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database' and
'get_llm_model_answer' methods.
"""
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
chatConfig = ChatConfig(where={"attribute": "value"})
answer = self.app.chat("Test chat", chatConfig)
self.assertEqual(answer, "Test answer")
self.assertEqual(mock_retrieve.call_args[0][0], "Test chat")
self.assertEqual(mock_retrieve.call_args[0][1].where, {"attribute": "value"})
self.assertIsInstance(mock_retrieve.call_args[0][1], ChatConfig)
mock_answer.assert_called_once()
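Note: the call_args assertions in both chat tests imply a positional layout of (input_query, config, where) for 'retrieve_from_database': index 0 is the query string, index 1 the config object, index 2 the optional where filter. A minimal sketch of the dispatch the tests appear to rely on (hypothetical, inferred only from these assertions, not the actual implementation):
# Hypothetical shape of the call the assertions check; names are illustrative.
def chat(self, input_query, config=None, where=None):
    config = config or ChatConfig()
    contexts = self.retrieve_from_database(input_query, config, where)  # call_args[0][0..2]
    # contexts would feed prompt generation before the LLM call
    return self.get_llm_model_answer(input_query, config)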

View File

@@ -3,8 +3,7 @@ import unittest
from unittest.mock import patch
from embedchain import App
from embedchain.config import AppConfig, CustomAppConfig
from embedchain.models import EmbeddingFunctions, Providers
from embedchain.config import AppConfig, ChromaDbConfig
class TestChromaDbHostsLoglevel(unittest.TestCase):
@@ -13,8 +12,8 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
@patch("chromadb.api.models.Collection.Collection.add")
@patch("chromadb.api.models.Collection.Collection.get")
@patch("embedchain.embedchain.EmbedChain.retrieve_from_database")
@patch("embedchain.embedchain.EmbedChain.get_answer_from_llm")
@patch("embedchain.embedchain.EmbedChain.get_llm_model_answer")
@patch("embedchain.llm.base_llm.BaseLlm.get_answer_from_llm")
@patch("embedchain.llm.base_llm.BaseLlm.get_llm_model_answer")
def test_whole_app(
self,
_mock_get,
@@ -43,17 +42,14 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
"""
Test if the `App` instance is correctly reconstructed after a reset.
"""
app = App(
CustomAppConfig(
provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True}
)
)
config = AppConfig(log_level="DEBUG", collect_metrics=False)
app = App(config=config, chromadb_config=ChromaDbConfig(chroma_settings={"allow_reset": True}))
app.reset()
# Make sure the client is still healthy
app.db.client.heartbeat()
# Make sure the collection exists, and can be added to
app.collection.add(
app.db.collection.add(
embeddings=[[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2]],
metadatas=[
{"chapter": "3", "verse": "16"},

View File

@@ -1,66 +0,0 @@
import unittest
from string import Template
from embedchain import App
from embedchain.config import AppConfig, QueryConfig
class TestGeneratePrompt(unittest.TestCase):
def setUp(self):
self.app = App(config=AppConfig(collect_metrics=False))
def test_generate_prompt_with_template(self):
"""
Tests that the generate_prompt method correctly formats the prompt using
a custom template provided in the QueryConfig instance.
This test sets up a scenario with an input query and a list of contexts,
and a custom template, and then calls generate_prompt. It checks that the
returned prompt correctly incorporates all the contexts and the query into
the format specified by the template.
"""
# Setup
input_query = "Test query"
contexts = ["Context 1", "Context 2", "Context 3"]
template = "You are a bot. Context: ${context} - Query: ${query} - Helpful answer:"
config = QueryConfig(template=Template(template))
# Execute
result = self.app.generate_prompt(input_query, contexts, config)
# Assert
expected_result = (
"You are a bot. Context: Context 1 | Context 2 | Context 3 - Query: Test query - Helpful answer:"
)
self.assertEqual(result, expected_result)
def test_generate_prompt_with_contexts_list(self):
"""
Tests that the generate_prompt method correctly handles a list of contexts.
This test sets up a scenario with an input query and a list of contexts,
and then calls generate_prompt. It checks that the returned prompt
correctly includes all the contexts and the query.
"""
# Setup
input_query = "Test query"
contexts = ["Context 1", "Context 2", "Context 3"]
config = QueryConfig()
# Execute
result = self.app.generate_prompt(input_query, contexts, config)
# Assert
expected_result = config.template.substitute(context="Context 1 | Context 2 | Context 3", query=input_query)
self.assertEqual(result, expected_result)
def test_generate_prompt_with_history(self):
"""
Test the 'generate_prompt' method with QueryConfig containing a history attribute.
"""
config = QueryConfig(history=["Past context 1", "Past context 2"])
config.template = Template("Context: $context | Query: $query | History: $history")
prompt = self.app.generate_prompt("Test query", ["Test context"], config)
expected_prompt = "Context: Test context | Query: Test query | History: ['Past context 1', 'Past context 2']"
self.assertEqual(prompt, expected_prompt)
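Note: the expected strings in these tests follow directly from string.Template substitution, with the context list joined by " | ". A minimal standard-library illustration of how the expected_result values are produced:
from string import Template

template = Template("You are a bot. Context: ${context} - Query: ${query} - Helpful answer:")
prompt = template.substitute(
    context=" | ".join(["Context 1", "Context 2", "Context 3"]),
    query="Test query",
)
# prompt == "You are a bot. Context: Context 1 | Context 2 | Context 3 - Query: Test query - Helpful answer:"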

View File

@@ -1,137 +0,0 @@
import os
import unittest
from unittest.mock import MagicMock, patch
from embedchain import App
from embedchain.config import AppConfig, QueryConfig
class TestApp(unittest.TestCase):
os.environ["OPENAI_API_KEY"] = "test_key"
def setUp(self):
self.app = App(config=AppConfig(collect_metrics=False))
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query(self):
"""
This test checks the functionality of the 'query' method in the App class.
It simulates a scenario where the 'retrieve_from_database' method returns a context list and
'get_llm_model_answer' returns an expected answer string.
The 'query' method is expected to call 'retrieve_from_database' and 'get_llm_model_answer' methods
appropriately and return the right answer.
Key assumptions tested:
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
QueryConfig.
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
- 'query' method returns the value it received from 'get_llm_model_answer'.
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
'get_llm_model_answer' methods.
"""
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
answer = self.app.query("Test query")
self.assertEqual(answer, "Test answer")
self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
self.assertIsInstance(mock_retrieve.call_args[0][1], QueryConfig)
mock_answer.assert_called_once()
@patch("openai.ChatCompletion.create")
def test_query_config_app_passing(self, mock_create):
mock_create.return_value = {"choices": [{"message": {"content": "response"}}]} # Mock response
config = AppConfig()
chat_config = QueryConfig(system_prompt="Test system prompt")
app = App(config=config)
app.get_llm_model_answer("Test query", chat_config)
# Test system_prompt: Check that the 'create' method was called with the correct 'messages' argument
messages_arg = mock_create.call_args.kwargs["messages"]
self.assertEqual(messages_arg[0]["role"], "system")
self.assertEqual(messages_arg[0]["content"], "Test system prompt")
# TODO: Add tests for other config variables
@patch("openai.ChatCompletion.create")
def test_app_passing(self, mock_create):
mock_create.return_value = {"choices": [{"message": {"content": "response"}}]} # Mock response
config = AppConfig()
chat_config = QueryConfig()
app = App(config=config, system_prompt="Test system prompt")
app.get_llm_model_answer("Test query", chat_config)
# Test system_prompt: Check that the 'create' method was called with the correct 'messages' argument
messages_arg = mock_create.call_args.kwargs["messages"]
self.assertEqual(messages_arg[0]["role"], "system")
self.assertEqual(messages_arg[0]["content"], "Test system prompt")
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query_with_where_in_params(self):
"""
This test checks the functionality of the 'query' method in the App class.
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
a where filter and 'get_llm_model_answer' returns an expected answer string.
The 'query' method is expected to call 'retrieve_from_database' with the where filter, call
'get_llm_model_answer' appropriately, and return the right answer.
Key assumptions tested:
- 'retrieve_from_database' method is called exactly once with the query "Test query" and the where
filter {"attribute": "value"}.
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
- 'query' method returns the value it received from 'get_llm_model_answer'.
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
'get_llm_model_answer' methods.
"""
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
answer = self.app.query("Test query", where={"attribute": "value"})
self.assertEqual(answer, "Test answer")
self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
self.assertEqual(mock_retrieve.call_args[0][2], {"attribute": "value"})
mock_answer.assert_called_once()
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
def test_query_with_where_in_query_config(self):
"""
This test checks the functionality of the 'query' method in the App class.
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
a where filter and 'get_llm_model_answer' returns an expected answer string.
The 'query' method is expected to call 'retrieve_from_database' with the where filter specified
in the QueryConfig, call 'get_llm_model_answer' appropriately, and return the right answer.
Key assumptions tested:
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
QueryConfig.
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
- 'query' method returns the value it received from 'get_llm_model_answer'.
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
'get_llm_model_answer' methods.
"""
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
mock_answer.return_value = "Test answer"
queryConfig = QueryConfig(where={"attribute": "value"})
answer = self.app.query("Test query", queryConfig)
self.assertEqual(answer, "Test answer")
self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
self.assertEqual(mock_retrieve.call_args[0][1].where, {"attribute": "value"})
self.assertIsInstance(mock_retrieve.call_args[0][1], QueryConfig)
mock_answer.assert_called_once()
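Note: these deleted tests patch the LLM methods directly on App/EmbedChain; the modified test earlier in this commit shows the same seams moving to the new llm module. The migration visible in the decorator lines of this diff, with a hypothetical usage of the new target (the test body is illustrative only):
# Patch targets before this commit:
#   embedchain.embedchain.EmbedChain.get_answer_from_llm
#   embedchain.embedchain.EmbedChain.get_llm_model_answer
# Patch targets after this commit:
#   embedchain.llm.base_llm.BaseLlm.get_answer_from_llm
#   embedchain.llm.base_llm.BaseLlm.get_llm_model_answer
from unittest.mock import patch

with patch("embedchain.llm.base_llm.BaseLlm.get_llm_model_answer", return_value="Test answer") as mock_answer:
    pass  # exercise App.query() / App.chat() here and assert on mock_answer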