[BREAKING CHANGE] moved dry run into query and chat (#329)
Co-authored-by: Aaishik Dutta <aaishikdutta@Aaishiks-MacBook-Pro.local>
This commit is contained in:
@@ -1,52 +0,0 @@
|
||||
import os
|
||||
import unittest
|
||||
from string import Template
|
||||
from unittest.mock import patch
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.embedchain import QueryConfig
|
||||
|
||||
|
||||
class TestApp(unittest.TestCase):
|
||||
os.environ["OPENAI_API_KEY"] = "test_key"
|
||||
|
||||
def setUp(self):
|
||||
self.app = App()
|
||||
|
||||
@patch("logging.info")
|
||||
def test_query_logs_same_prompt_as_dry_run(self, mock_logging_info):
|
||||
"""
|
||||
Test that the 'query' method logs the same prompt as the 'dry_run' method.
|
||||
This is the only way I found to test the prompt in query, that's not returned.
|
||||
"""
|
||||
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
|
||||
mock_retrieve.return_value = ["Test context"]
|
||||
input_query = "Test query"
|
||||
config = QueryConfig(
|
||||
number_documents=3,
|
||||
template=Template("Question: $query, context: $context, history: $history"),
|
||||
history=["Past context 1", "Past context 2"],
|
||||
)
|
||||
|
||||
with patch.object(self.app, "get_answer_from_llm"):
|
||||
self.app.dry_run(input_query, config)
|
||||
self.app.query(input_query, config)
|
||||
|
||||
# Access the log messages captured during the execution
|
||||
logged_messages = [call[0][0] for call in mock_logging_info.call_args_list]
|
||||
|
||||
# Extract the prompts from the log messages
|
||||
dry_run_prompt = self.extract_prompt(logged_messages[0])
|
||||
query_prompt = self.extract_prompt(logged_messages[1])
|
||||
|
||||
# Perform assertions on the prompts
|
||||
self.assertEqual(dry_run_prompt, query_prompt)
|
||||
|
||||
def extract_prompt(self, log_message):
|
||||
"""
|
||||
Extracts the prompt value from the log message.
|
||||
Adjust this method based on the log message format in your implementation.
|
||||
"""
|
||||
# Modify this logic based on your log message format
|
||||
prefix = "Prompt: "
|
||||
return log_message.split(prefix, 1)[1]
|
||||
Reference in New Issue
Block a user