53 lines
1.9 KiB
Python
53 lines
1.9 KiB
Python
import os
|
|
import unittest
|
|
from string import Template
|
|
from unittest.mock import patch
|
|
|
|
from embedchain import App
|
|
from embedchain.embedchain import QueryConfig
|
|
|
|
|
|
class TestApp(unittest.TestCase):
    """Tests for how ``App`` builds and logs the prompt sent to the LLM."""

    # Runs once, when the class body executes at import time, so that
    # ``App()`` in ``setUp`` has an API key available.
    # NOTE(review): the value stays in ``os.environ`` for the rest of the test
    # process; scoping it with ``patch.dict(os.environ, ...)`` would avoid the
    # leak if no other test relies on it.
    os.environ["OPENAI_API_KEY"] = "test_key"

    # Marker that precedes the prompt text in the log messages parsed below.
    _PROMPT_PREFIX = "Prompt: "

    def setUp(self):
        self.app = App()

    @patch("logging.info")
    def test_query_logs_same_prompt_as_dry_run(self, mock_logging_info):
        """
        Test that the 'query' method logs the same prompt as the 'dry_run' method.

        ``query`` does not return its prompt, so the prompt is recovered from
        the ``logging.info`` calls each method makes. The mock is reset
        between the two calls so every captured prompt is attributed to the
        method that actually logged it, regardless of any other info messages
        either method emits.
        """
        input_query = "Test query"
        config = QueryConfig(
            number_documents=3,
            template=Template("Question: $query, context: $context, history: $history"),
            history=["Past context 1", "Past context 2"],
        )

        with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
            mock_retrieve.return_value = ["Test context"]
            with patch.object(self.app, "get_answer_from_llm"):
                self.app.dry_run(input_query, config)
                dry_run_prompts = self._logged_prompts(mock_logging_info)

                # Forget dry_run's log calls so query's are read in isolation.
                mock_logging_info.reset_mock()
                self.app.query(input_query, config)
                query_prompts = self._logged_prompts(mock_logging_info)

        # Fail with a clear message instead of an IndexError if nothing was logged.
        self.assertTrue(dry_run_prompts, "dry_run did not log a prompt")
        self.assertTrue(query_prompts, "query did not log a prompt")
        self.assertEqual(dry_run_prompts[0], query_prompts[0])

    def _logged_prompts(self, mock_logging_info):
        """
        Return the prompts found in the calls recorded on ``mock_logging_info``.

        Each call's first positional argument is taken as the message. If more
        positional arguments were passed, they are %-interpolated into it (as
        ``logging`` itself does). This way a lazily formatted
        ``logging.info("Prompt: %s", prompt)`` yields the real prompt rather
        than the literal ``%s``, which would make any two prompts compare
        equal. Messages that do not contain the prompt prefix are skipped.
        """
        prompts = []
        for args, _kwargs in mock_logging_info.call_args_list:
            if not args:
                continue
            message = str(args[0])
            if len(args) > 1:
                message = message % args[1:]
            if self._PROMPT_PREFIX in message:
                prompts.append(self.extract_prompt(message))
        return prompts

    def extract_prompt(self, log_message):
        """
        Return the text that follows the first ``"Prompt: "`` in ``log_message``.

        Fails the test with a descriptive message (rather than an opaque
        ``IndexError``) when the prefix is not present.
        """
        _, separator, prompt = log_message.partition(self._PROMPT_PREFIX)
        if not separator:
            self.fail(
                "%r not found in log message: %r" % (self._PROMPT_PREFIX, log_message)
            )
        return prompt
|