Files
t6_mem0/tests/embedchain/test_dryrun.py
2023-07-15 17:28:51 -07:00

53 lines
1.9 KiB
Python

import os
import unittest
from string import Template
from unittest.mock import patch
from embedchain import App
from embedchain.embedchain import QueryConfig
class TestApp(unittest.TestCase):
    """Check that ``App.query`` and ``App.dry_run`` log the same prompt."""

    # Assigned in the class body, so it runs once when the class is defined
    # and is not restored afterwards.
    # NOTE(review): presumably ``App()`` needs a key to be set -- confirm.
    os.environ["OPENAI_API_KEY"] = "test_key"

    # Marker that precedes the prompt text in the log messages inspected below.
    PROMPT_PREFIX = "Prompt: "

    def setUp(self):
        """Create a fresh ``App`` instance for every test."""
        self.app = App()

    @patch("logging.info")
    def test_query_logs_same_prompt_as_dry_run(self, mock_logging_info):
        """
        Test that the 'query' method logs the same prompt as the 'dry_run' method.

        The prompt used by 'query' is not returned, so the calls made to
        ``logging.info`` are captured and the prompts are read back from them.
        """
        input_query = "Test query"
        config = QueryConfig(
            number_documents=3,
            template=Template("Question: $query, context: $context, history: $history"),
            history=["Past context 1", "Past context 2"],
        )

        with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
            mock_retrieve.return_value = ["Test context"]
            with patch.object(self.app, "get_answer_from_llm"):
                self.app.dry_run(input_query, config)
                # Remember where the dry_run log calls end so that each
                # captured message can be attributed to the method that
                # produced it, instead of relying on fixed list positions.
                dry_run_call_count = mock_logging_info.call_count
                self.app.query(input_query, config)

        recorded_calls = mock_logging_info.call_args_list
        dry_run_prompts = self._logged_prompts(recorded_calls[:dry_run_call_count])
        query_prompts = self._logged_prompts(recorded_calls[dry_run_call_count:])

        # Fail with a readable message rather than an IndexError when one of
        # the methods did not log a prompt at all.
        self.assertTrue(dry_run_prompts, "dry_run did not log a prompt via logging.info")
        self.assertTrue(query_prompts, "query did not log a prompt via logging.info")
        self.assertEqual(dry_run_prompts[0], query_prompts[0])

    def _logged_prompts(self, recorded_calls):
        """
        Collect the prompts found in a sequence of recorded ``logging.info`` calls.

        Calls without a positional string message, or whose message does not
        contain ``PROMPT_PREFIX``, are ignored.

        :param recorded_calls: entries taken from ``mock.call_args_list``.
        :return: list of extracted prompts, in call order.
        """
        prompts = []
        for recorded_call in recorded_calls:
            # Index 0 of a recorded call is its tuple of positional arguments.
            positional_args = recorded_call[0]
            if not positional_args:
                continue
            message = positional_args[0]
            if isinstance(message, str) and self.PROMPT_PREFIX in message:
                prompts.append(self.extract_prompt(message))
        return prompts

    def extract_prompt(self, log_message):
        """
        Extracts the prompt value from the log message.
        Adjust this method based on the log message format in your implementation.

        :param log_message: a logged string containing ``PROMPT_PREFIX``.
        :return: everything that follows the first occurrence of the prefix.
        :raises IndexError: if the prefix does not occur in ``log_message``.
        """
        _, separator, prompt = log_message.partition(self.PROMPT_PREFIX)
        if not separator:
            # IndexError is kept for backward compatibility with the previous
            # ``split(prefix, 1)[1]`` implementation; only the message is new.
            raise IndexError(
                f"log message does not contain {self.PROMPT_PREFIX!r}: {log_message!r}"
            )
        return prompt