[BREAKING CHANGE] moved dry run into query and chat (#329)
Co-authored-by: Aaishik Dutta <aaishikdutta@Aaishiks-MacBook-Pro.local>
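
In practice, the breaking change is that the standalone `dry_run()` method is removed and its behaviour moves behind a `dry_run` flag on `query` and `chat`. A minimal before/after sketch (the `app` instance and the added source are illustrative assumptions, not part of this commit):

```python
from embedchain import App

app = App()  # assumes OPENAI_API_KEY is set in the environment
app.add("web_page", "https://nav.al/feedback")  # example source, assumed for illustration

# Before this commit: a dedicated method returned the constructed prompt.
# prompt = app.dry_run("Can you tell me who Naval Ravikant is?")

# After this commit: pass dry_run=True to query() or chat() instead.
prompt = app.query("Can you tell me who Naval Ravikant is?", dry_run=True)
print(prompt)  # the prompt that would have been sent to the LLM, no tokens spent
```
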
@@ -34,6 +34,11 @@ print(naval_chat_bot.chat("what did the author say about happiness?"))
 # answer: The author, Naval Ravikant, believes that happiness is a choice you make and a skill you develop. He compares the mind to the body, stating that just as the body can be molded and changed, so can the mind. He emphasizes the importance of being present in the moment and not getting caught up in regrets of the past or worries about the future. By being present and grateful for where you are, you can experience true happiness.
 ```

+#### Dry Run
+
+Dry Run is an option in the `query` and `chat` methods that allows the user to not send their constructed prompt to the LLM, to save money. It's used for [testing](/advanced/testing#dry-run).
+
+
 ### Stream Response

 - You can add config to your query method to stream responses like ChatGPT does. You would require a downstream handler to render the chunk in your desired format. Supports both OpenAI model and OpenSourceApp. 📊

@@ -52,10 +57,6 @@ for chunk in resp:

 ### Other Methods

-#### Dry Run
-
-Dry run has all the options that `query` has, it just doesn't send the prompt to the LLM, to save money. It's used for [testing](/advanced/testing#dry-run).
-
 #### Reset

 Resets the database and deletes all embeddings. Irreversible. Requires reinitialization afterwards.

@@ -8,12 +8,12 @@ title: '🧪 Testing'

 Before you consume valuable tokens, you should make sure that the embedding you have done works and that it's receiving the correct document from the database.

-For this you can use the `dry_run` method.
+For this you can use the `dry_run` option in your `query` or `chat` method.

 Following the example above, add this to your script:

 ```python
-print(naval_chat_bot.dry_run('Can you tell me who Naval Ravikant is?'))
+print(naval_chat_bot.query('Can you tell me who Naval Ravikant is?', dry_run=True))

 '''
 Use the following pieces of context to answer the query at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

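The docs change above covers `query`; since the same flag exists on `chat`, the prompt it builds can be checked the same way. A small sketch, reusing the `naval_chat_bot` app from the example and assuming it has already been populated:

```python
# Dry-run the chat method: returns the constructed prompt instead of an answer.
prompt = naval_chat_bot.chat("what did the author say about happiness?", dry_run=True)
print(prompt)

# Rough sanity check that the question made it into the prompt.
assert "happiness" in prompt.lower()
```
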
@@ -213,7 +213,7 @@ class EmbedChain:
         logging.info(f"Access search to get answers for {input_query}")
         return search.run(input_query)

-    def query(self, input_query, config: QueryConfig = None):
+    def query(self, input_query, config: QueryConfig = None, dry_run=False):
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an

@@ -222,6 +222,12 @@ class EmbedChain:
         :param input_query: The query to use.
         :param config: Optional. The `QueryConfig` instance to use as
             configuration options.
+        :param dry_run: Optional. A dry run does everything except send the resulting prompt to
+            the LLM. The purpose is to test the prompt, not the response.
+            You can use it to test your prompt, including the context provided
+            by the vector database's doc retrieval.
+            The only thing the dry run does not consider is the cut-off due to
+            the `max_tokens` parameter.
         :return: The answer to the query.
         """
         if config is None:

@@ -236,6 +242,9 @@ class EmbedChain:
         prompt = self.generate_prompt(input_query, contexts, config, **k)
         logging.info(f"Prompt: {prompt}")

+        if dry_run:
+            return prompt
+
         answer = self.get_answer_from_llm(prompt, config)

         if isinstance(answer, str):

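Because the flag makes `query` return the constructed prompt directly, the prompt can now be inspected in code rather than scraped from the logs. A sketch, assuming `app` is an initialized embedchain `App` with documents already added:

```python
# Everything up to (but not including) the LLM call runs: retrieval,
# prompt construction, and logging. The return value is the prompt string.
prompt = app.query("Test query", dry_run=True)

print(type(prompt))  # <class 'str'>
assert "Test query" in prompt  # the template substitutes the query text into the prompt
```
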
@@ -251,7 +260,7 @@ class EmbedChain:
             yield chunk
         logging.info(f"Answer: {streamed_answer}")

-    def chat(self, input_query, config: ChatConfig = None):
+    def chat(self, input_query, config: ChatConfig = None, dry_run=False):
         """
         Queries the vector database on the given input query.
         Gets relevant doc based on the query and then passes it to an

@@ -261,6 +270,12 @@ class EmbedChain:
         :param input_query: The query to use.
         :param config: Optional. The `ChatConfig` instance to use as
             configuration options.
+        :param dry_run: Optional. A dry run does everything except send the resulting prompt to
+            the LLM. The purpose is to test the prompt, not the response.
+            You can use it to test your prompt, including the context provided
+            by the vector database's doc retrieval.
+            The only thing the dry run does not consider is the cut-off due to
+            the `max_tokens` parameter.
         :return: The answer to the query.
         """
         if config is None:

@@ -281,6 +296,10 @@ class EmbedChain:

         prompt = self.generate_prompt(input_query, contexts, config, **k)
         logging.info(f"Prompt: {prompt}")
+
+        if dry_run:
+            return prompt
+
         answer = self.get_answer_from_llm(prompt, config)

         memory.chat_memory.add_user_message(input_query)

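One detail worth noting in the hunk above: in `chat`, the early return happens before `memory.chat_memory.add_user_message(...)`, so a dry run adds nothing to the conversation history. A rough illustration (assuming retrieval is deterministic for an identical query):

```python
# Two consecutive dry runs of chat() with the same question should yield the
# same prompt, because no history is recorded on a dry run, so the template's
# history placeholder stays unchanged between the two calls.
first = naval_chat_bot.chat("what did the author say about happiness?", dry_run=True)
second = naval_chat_bot.chat("what did the author say about happiness?", dry_run=True)
assert first == second
```
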
@@ -301,27 +320,6 @@ class EmbedChain:
         memory.chat_memory.add_ai_message(streamed_answer)
         logging.info(f"Answer: {streamed_answer}")

-    def dry_run(self, input_query, config: QueryConfig = None):
-        """
-        A dry run does everything except send the resulting prompt to
-        the LLM. The purpose is to test the prompt, not the response.
-        You can use it to test your prompt, including the context provided
-        by the vector database's doc retrieval.
-        The only thing the dry run does not consider is the cut-off due to
-        the `max_tokens` parameter.
-
-        :param input_query: The query to use.
-        :param config: Optional. The `QueryConfig` instance to use as
-            configuration options.
-        :return: The prompt that would be sent to the LLM
-        """
-        if config is None:
-            config = QueryConfig()
-        contexts = self.retrieve_from_database(input_query, config)
-        prompt = self.generate_prompt(input_query, contexts, config)
-        logging.info(f"Prompt: {prompt}")
-        return prompt
-
     def count(self):
         """
         Count the number of embeddings.

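Callers that passed a `QueryConfig` to the removed method keep the same shape; only the entry point changes. A minimal migration sketch (the app and question are placeholders):

```python
from embedchain import App
from embedchain.embedchain import QueryConfig

app = App()  # assumes OPENAI_API_KEY is set in the environment
config = QueryConfig()

# Old: prompt = app.dry_run("Who is Naval Ravikant?", config)
# New: the config argument is unchanged; dry_run becomes a keyword flag on query()/chat().
prompt = app.query("Who is Naval Ravikant?", config, dry_run=True)
```
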
@@ -1,52 +0,0 @@
-import os
-import unittest
-from string import Template
-from unittest.mock import patch
-
-from embedchain import App
-from embedchain.embedchain import QueryConfig
-
-
-class TestApp(unittest.TestCase):
-    os.environ["OPENAI_API_KEY"] = "test_key"
-
-    def setUp(self):
-        self.app = App()
-
-    @patch("logging.info")
-    def test_query_logs_same_prompt_as_dry_run(self, mock_logging_info):
-        """
-        Test that the 'query' method logs the same prompt as the 'dry_run' method.
-        This is the only way I found to test the prompt in query, that's not returned.
-        """
-        with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
-            mock_retrieve.return_value = ["Test context"]
-            input_query = "Test query"
-            config = QueryConfig(
-                number_documents=3,
-                template=Template("Question: $query, context: $context, history: $history"),
-                history=["Past context 1", "Past context 2"],
-            )
-
-            with patch.object(self.app, "get_answer_from_llm"):
-                self.app.dry_run(input_query, config)
-                self.app.query(input_query, config)
-
-            # Access the log messages captured during the execution
-            logged_messages = [call[0][0] for call in mock_logging_info.call_args_list]
-
-            # Extract the prompts from the log messages
-            dry_run_prompt = self.extract_prompt(logged_messages[0])
-            query_prompt = self.extract_prompt(logged_messages[1])

-            # Perform assertions on the prompts
-            self.assertEqual(dry_run_prompt, query_prompt)
-
-    def extract_prompt(self, log_message):
-        """
-        Extracts the prompt value from the log message.
-        Adjust this method based on the log message format in your implementation.
-        """
-        # Modify this logic based on your log message format
-        prefix = "Prompt: "
-        return log_message.split(prefix, 1)[1]

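With the flag returning the prompt directly, the deleted test's log-scraping indirection is no longer necessary. A possible replacement test, sketched under the assumption that `retrieve_from_database`, `generate_prompt`, and `QueryConfig` keep the interfaces shown above:

```python
import os
import unittest
from string import Template
from unittest.mock import patch

from embedchain import App
from embedchain.embedchain import QueryConfig


class TestDryRunFlag(unittest.TestCase):
    os.environ["OPENAI_API_KEY"] = "test_key"

    def setUp(self):
        self.app = App()

    def test_query_dry_run_returns_prompt(self):
        """A dry run should return the constructed prompt without calling the LLM."""
        with patch.object(self.app, "retrieve_from_database") as mock_retrieve, \
                patch.object(self.app, "get_answer_from_llm") as mock_llm:
            mock_retrieve.return_value = ["Test context"]
            config = QueryConfig(
                template=Template("Question: $query, context: $context, history: $history"),
                history=["Past context 1"],
            )

            prompt = self.app.query("Test query", config, dry_run=True)

            # The prompt comes back as a plain string and the LLM is never called.
            self.assertIn("Test query", prompt)
            self.assertIn("Test context", prompt)
            mock_llm.assert_not_called()
```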