refactor: classes and configs (#528)
This commit is contained in:
@@ -1,99 +0,0 @@
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig, ChatConfig
|
||||
|
||||
|
||||
class TestApp(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["OPENAI_API_KEY"] = "test_key"
|
||||
self.app = App(config=AppConfig(collect_metrics=False))
|
||||
|
||||
@patch.object(App, "retrieve_from_database", return_value=["Test context"])
|
||||
@patch.object(App, "get_answer_from_llm", return_value="Test answer")
|
||||
def test_chat_with_memory(self, mock_get_answer, mock_retrieve):
|
||||
"""
|
||||
This test checks the functionality of the 'chat' method in the App class with respect to the chat history
|
||||
memory.
|
||||
The 'chat' method is called twice. The first call initializes the chat history memory.
|
||||
The second call is expected to use the chat history from the first call.
|
||||
|
||||
Key assumptions tested:
|
||||
called with correct arguments, adding the correct chat history.
|
||||
- After the first call, 'memory.chat_memory.add_user_message' and 'memory.chat_memory.add_ai_message' are
|
||||
- During the second call, the 'chat' method uses the chat history from the first call.
|
||||
|
||||
The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database', 'get_answer_from_llm' and
|
||||
'memory' methods.
|
||||
"""
|
||||
app = App()
|
||||
first_answer = app.chat("Test query 1")
|
||||
self.assertEqual(first_answer, "Test answer")
|
||||
self.assertEqual(len(app.memory.chat_memory.messages), 2)
|
||||
second_answer = app.chat("Test query 2")
|
||||
self.assertEqual(second_answer, "Test answer")
|
||||
self.assertEqual(len(app.memory.chat_memory.messages), 4)
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_chat_with_where_in_params(self):
|
||||
"""
|
||||
This test checks the functionality of the 'chat' method in the App class.
|
||||
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
|
||||
a where filter and 'get_llm_model_answer' returns an expected answer string.
|
||||
|
||||
The 'chat' method is expected to call 'retrieve_from_database' with the where filter and
|
||||
'get_llm_model_answer' methods appropriately and return the right answer.
|
||||
|
||||
Key assumptions tested:
|
||||
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
|
||||
QueryConfig.
|
||||
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
|
||||
- 'chat' method returns the value it received from 'get_llm_model_answer'.
|
||||
|
||||
The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database' and
|
||||
'get_llm_model_answer' methods.
|
||||
"""
|
||||
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||
mock_retrieve.return_value = ["Test context"]
|
||||
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
answer = self.app.chat("Test chat", where={"attribute": "value"})
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
self.assertEqual(mock_retrieve.call_args[0][0], "Test chat")
|
||||
self.assertEqual(mock_retrieve.call_args[0][2], {"attribute": "value"})
|
||||
mock_answer.assert_called_once()
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_chat_with_where_in_chat_config(self):
|
||||
"""
|
||||
This test checks the functionality of the 'chat' method in the App class.
|
||||
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
|
||||
a where filter and 'get_llm_model_answer' returns an expected answer string.
|
||||
|
||||
The 'chat' method is expected to call 'retrieve_from_database' with the where filter specified
|
||||
in the QueryConfig and 'get_llm_model_answer' methods appropriately and return the right answer.
|
||||
|
||||
Key assumptions tested:
|
||||
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
|
||||
QueryConfig.
|
||||
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
|
||||
- 'chat' method returns the value it received from 'get_llm_model_answer'.
|
||||
|
||||
The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database' and
|
||||
'get_llm_model_answer' methods.
|
||||
"""
|
||||
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||
mock_retrieve.return_value = ["Test context"]
|
||||
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
chatConfig = ChatConfig(where={"attribute": "value"})
|
||||
answer = self.app.chat("Test chat", chatConfig)
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
self.assertEqual(mock_retrieve.call_args[0][0], "Test chat")
|
||||
self.assertEqual(mock_retrieve.call_args[0][1].where, {"attribute": "value"})
|
||||
self.assertIsInstance(mock_retrieve.call_args[0][1], ChatConfig)
|
||||
mock_answer.assert_called_once()
|
||||
@@ -3,8 +3,7 @@ import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig, CustomAppConfig
|
||||
from embedchain.models import EmbeddingFunctions, Providers
|
||||
from embedchain.config import AppConfig, ChromaDbConfig
|
||||
|
||||
|
||||
class TestChromaDbHostsLoglevel(unittest.TestCase):
|
||||
@@ -13,8 +12,8 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
|
||||
@patch("chromadb.api.models.Collection.Collection.add")
|
||||
@patch("chromadb.api.models.Collection.Collection.get")
|
||||
@patch("embedchain.embedchain.EmbedChain.retrieve_from_database")
|
||||
@patch("embedchain.embedchain.EmbedChain.get_answer_from_llm")
|
||||
@patch("embedchain.embedchain.EmbedChain.get_llm_model_answer")
|
||||
@patch("embedchain.llm.base_llm.BaseLlm.get_answer_from_llm")
|
||||
@patch("embedchain.llm.base_llm.BaseLlm.get_llm_model_answer")
|
||||
def test_whole_app(
|
||||
self,
|
||||
_mock_get,
|
||||
@@ -43,17 +42,14 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
|
||||
"""
|
||||
Test if the `App` instance is correctly reconstructed after a reset.
|
||||
"""
|
||||
app = App(
|
||||
CustomAppConfig(
|
||||
provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True}
|
||||
)
|
||||
)
|
||||
config = AppConfig(log_level="DEBUG", collect_metrics=False)
|
||||
app = App(config=config, chromadb_config=ChromaDbConfig(chroma_settings={"allow_reset": True}))
|
||||
app.reset()
|
||||
|
||||
# Make sure the client is still healthy
|
||||
app.db.client.heartbeat()
|
||||
# Make sure the collection exists, and can be added to
|
||||
app.collection.add(
|
||||
app.db.collection.add(
|
||||
embeddings=[[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2]],
|
||||
metadatas=[
|
||||
{"chapter": "3", "verse": "16"},
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
import unittest
|
||||
from string import Template
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig, QueryConfig
|
||||
|
||||
|
||||
class TestGeneratePrompt(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.app = App(config=AppConfig(collect_metrics=False))
|
||||
|
||||
def test_generate_prompt_with_template(self):
|
||||
"""
|
||||
Tests that the generate_prompt method correctly formats the prompt using
|
||||
a custom template provided in the QueryConfig instance.
|
||||
|
||||
This test sets up a scenario with an input query and a list of contexts,
|
||||
and a custom template, and then calls generate_prompt. It checks that the
|
||||
returned prompt correctly incorporates all the contexts and the query into
|
||||
the format specified by the template.
|
||||
"""
|
||||
# Setup
|
||||
input_query = "Test query"
|
||||
contexts = ["Context 1", "Context 2", "Context 3"]
|
||||
template = "You are a bot. Context: ${context} - Query: ${query} - Helpful answer:"
|
||||
config = QueryConfig(template=Template(template))
|
||||
|
||||
# Execute
|
||||
result = self.app.generate_prompt(input_query, contexts, config)
|
||||
|
||||
# Assert
|
||||
expected_result = (
|
||||
"You are a bot. Context: Context 1 | Context 2 | Context 3 - Query: Test query - Helpful answer:"
|
||||
)
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
def test_generate_prompt_with_contexts_list(self):
|
||||
"""
|
||||
Tests that the generate_prompt method correctly handles a list of contexts.
|
||||
|
||||
This test sets up a scenario with an input query and a list of contexts,
|
||||
and then calls generate_prompt. It checks that the returned prompt
|
||||
correctly includes all the contexts and the query.
|
||||
"""
|
||||
# Setup
|
||||
input_query = "Test query"
|
||||
contexts = ["Context 1", "Context 2", "Context 3"]
|
||||
config = QueryConfig()
|
||||
|
||||
# Execute
|
||||
result = self.app.generate_prompt(input_query, contexts, config)
|
||||
|
||||
# Assert
|
||||
expected_result = config.template.substitute(context="Context 1 | Context 2 | Context 3", query=input_query)
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
def test_generate_prompt_with_history(self):
|
||||
"""
|
||||
Test the 'generate_prompt' method with QueryConfig containing a history attribute.
|
||||
"""
|
||||
config = QueryConfig(history=["Past context 1", "Past context 2"])
|
||||
config.template = Template("Context: $context | Query: $query | History: $history")
|
||||
prompt = self.app.generate_prompt("Test query", ["Test context"], config)
|
||||
|
||||
expected_prompt = "Context: Test context | Query: Test query | History: ['Past context 1', 'Past context 2']"
|
||||
self.assertEqual(prompt, expected_prompt)
|
||||
@@ -1,137 +0,0 @@
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig, QueryConfig
|
||||
|
||||
|
||||
class TestApp(unittest.TestCase):
|
||||
os.environ["OPENAI_API_KEY"] = "test_key"
|
||||
|
||||
def setUp(self):
|
||||
self.app = App(config=AppConfig(collect_metrics=False))
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_query(self):
|
||||
"""
|
||||
This test checks the functionality of the 'query' method in the App class.
|
||||
It simulates a scenario where the 'retrieve_from_database' method returns a context list and
|
||||
'get_llm_model_answer' returns an expected answer string.
|
||||
|
||||
The 'query' method is expected to call 'retrieve_from_database' and 'get_llm_model_answer' methods
|
||||
appropriately and return the right answer.
|
||||
|
||||
Key assumptions tested:
|
||||
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
|
||||
QueryConfig.
|
||||
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
|
||||
- 'query' method returns the value it received from 'get_llm_model_answer'.
|
||||
|
||||
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
|
||||
'get_llm_model_answer' methods.
|
||||
"""
|
||||
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||
mock_retrieve.return_value = ["Test context"]
|
||||
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
answer = self.app.query("Test query")
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
|
||||
self.assertIsInstance(mock_retrieve.call_args[0][1], QueryConfig)
|
||||
mock_answer.assert_called_once()
|
||||
|
||||
@patch("openai.ChatCompletion.create")
|
||||
def test_query_config_app_passing(self, mock_create):
|
||||
mock_create.return_value = {"choices": [{"message": {"content": "response"}}]} # Mock response
|
||||
|
||||
config = AppConfig()
|
||||
chat_config = QueryConfig(system_prompt="Test system prompt")
|
||||
app = App(config=config)
|
||||
|
||||
app.get_llm_model_answer("Test query", chat_config)
|
||||
|
||||
# Test system_prompt: Check that the 'create' method was called with the correct 'messages' argument
|
||||
messages_arg = mock_create.call_args.kwargs["messages"]
|
||||
self.assertEqual(messages_arg[0]["role"], "system")
|
||||
self.assertEqual(messages_arg[0]["content"], "Test system prompt")
|
||||
|
||||
# TODO: Add tests for other config variables
|
||||
|
||||
@patch("openai.ChatCompletion.create")
|
||||
def test_app_passing(self, mock_create):
|
||||
mock_create.return_value = {"choices": [{"message": {"content": "response"}}]} # Mock response
|
||||
|
||||
config = AppConfig()
|
||||
chat_config = QueryConfig()
|
||||
app = App(config=config, system_prompt="Test system prompt")
|
||||
|
||||
app.get_llm_model_answer("Test query", chat_config)
|
||||
|
||||
# Test system_prompt: Check that the 'create' method was called with the correct 'messages' argument
|
||||
messages_arg = mock_create.call_args.kwargs["messages"]
|
||||
self.assertEqual(messages_arg[0]["role"], "system")
|
||||
self.assertEqual(messages_arg[0]["content"], "Test system prompt")
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_query_with_where_in_params(self):
|
||||
"""
|
||||
This test checks the functionality of the 'query' method in the App class.
|
||||
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
|
||||
a where filter and 'get_llm_model_answer' returns an expected answer string.
|
||||
|
||||
The 'query' method is expected to call 'retrieve_from_database' with the where filter and
|
||||
'get_llm_model_answer' methods appropriately and return the right answer.
|
||||
|
||||
Key assumptions tested:
|
||||
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
|
||||
QueryConfig.
|
||||
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
|
||||
- 'query' method returns the value it received from 'get_llm_model_answer'.
|
||||
|
||||
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
|
||||
'get_llm_model_answer' methods.
|
||||
"""
|
||||
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||
mock_retrieve.return_value = ["Test context"]
|
||||
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
answer = self.app.query("Test query", where={"attribute": "value"})
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
|
||||
self.assertEqual(mock_retrieve.call_args[0][2], {"attribute": "value"})
|
||||
mock_answer.assert_called_once()
|
||||
|
||||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock)
|
||||
def test_query_with_where_in_query_config(self):
|
||||
"""
|
||||
This test checks the functionality of the 'query' method in the App class.
|
||||
It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
|
||||
a where filter and 'get_llm_model_answer' returns an expected answer string.
|
||||
|
||||
The 'query' method is expected to call 'retrieve_from_database' with the where filter and
|
||||
'get_llm_model_answer' methods appropriately and return the right answer.
|
||||
|
||||
Key assumptions tested:
|
||||
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
|
||||
QueryConfig.
|
||||
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
|
||||
- 'query' method returns the value it received from 'get_llm_model_answer'.
|
||||
|
||||
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
|
||||
'get_llm_model_answer' methods.
|
||||
"""
|
||||
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||
mock_retrieve.return_value = ["Test context"]
|
||||
with patch.object(self.app, "get_llm_model_answer") as mock_answer:
|
||||
mock_answer.return_value = "Test answer"
|
||||
queryConfig = QueryConfig(where={"attribute": "value"})
|
||||
answer = self.app.query("Test query", queryConfig)
|
||||
|
||||
self.assertEqual(answer, "Test answer")
|
||||
self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
|
||||
self.assertEqual(mock_retrieve.call_args[0][1].where, {"attribute": "value"})
|
||||
self.assertIsInstance(mock_retrieve.call_args[0][1], QueryConfig)
|
||||
mock_answer.assert_called_once()
|
||||
Reference in New Issue
Block a user