[BREAKING CHANGE] moved dry run into query and chat (#329)

Co-authored-by: Aaishik Dutta <aaishikdutta@Aaishiks-MacBook-Pro.local>
aaishikdutta
2023-07-20 11:55:41 +05:30
committed by GitHub
parent 6b61b7e9c1
commit 4bb06147c1
4 changed files with 28 additions and 81 deletions

View File

@@ -34,6 +34,11 @@ print(naval_chat_bot.chat("what did the author say about happiness?"))
# answer: The author, Naval Ravikant, believes that happiness is a choice you make and a skill you develop. He compares the mind to the body, stating that just as the body can be molded and changed, so can the mind. He emphasizes the importance of being present in the moment and not getting caught up in regrets of the past or worries about the future. By being present and grateful for where you are, you can experience true happiness.
```
#### Dry Run
Dry Run is an option in the `query` and `chat` methods that skips sending the constructed prompt to the LLM, so no tokens are spent. It's used for [testing](/advanced/testing#dry-run).
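A minimal usage sketch, reusing the `naval_chat_bot` instance from the example above:
```python
# dry_run=True returns the constructed prompt instead of sending it to the LLM
prompt = naval_chat_bot.query("what did the author say about happiness?", dry_run=True)
print(prompt)

# the same flag works for chat
prompt = naval_chat_bot.chat("what did the author say about happiness?", dry_run=True)
```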
### Stream Response
- You can add a config to your query method to stream responses like ChatGPT does. You will need a downstream handler to render each chunk in your desired format (see the sketch below). Supports both the OpenAI model and OpenSourceApp. 📊
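A minimal streaming sketch; `QueryConfig(stream=True)` is an assumption here (the import path is taken from the test file further below) and may differ in your version:
```python
from embedchain import App
from embedchain.embedchain import QueryConfig

naval_chat_bot = App()
# add your sources first, e.g. naval_chat_bot.add(...)
resp = naval_chat_bot.query("what did the author say about happiness?", QueryConfig(stream=True))
for chunk in resp:
    print(chunk, end="", flush=True)  # render each chunk as it arrives
```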
@@ -52,10 +57,6 @@ for chunk in resp:
### Other Methods
#### Dry Run
Dry run has all the options that `query` has; it just doesn't send the prompt to the LLM, to save money. It's used for [testing](/advanced/testing#dry-run).
#### Reset
Resets the database and deletes all embeddings. Irreversible. Requires reinitialization afterwards.
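A minimal sketch, continuing the `naval_chat_bot` example; the method name `reset()` is assumed from the heading:
```python
naval_chat_bot.reset()   # wipes the database and all embeddings; this cannot be undone
naval_chat_bot = App()   # reinitialize (and re-add your sources) before querying again
```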

View File

@@ -8,12 +8,12 @@ title: '🧪 Testing'
Before you consume valuable tokens, you should make sure that the embedding you have done works and that it's retrieving the correct documents from the database.
For this you can use the `dry_run` method.
For this you can use the `dry_run` option in your `query` or `chat` method.
Following the example above, add this to your script:
```python
print(naval_chat_bot.dry_run('Can you tell me who Naval Ravikant is?'))
print(naval_chat_bot.query('Can you tell me who Naval Ravikant is?', dry_run=True))
'''
Use the following pieces of context to answer the query at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

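To go a step further and assert on the result, the returned prompt can be checked directly; a sketch, continuing the example above:
```python
prompt = naval_chat_bot.query('Can you tell me who Naval Ravikant is?', dry_run=True)
assert 'Naval Ravikant' in prompt  # the query text is embedded in the constructed prompt
print(prompt)                      # inspect the retrieved context by eye
```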
View File

@@ -213,7 +213,7 @@ class EmbedChain:
logging.info(f"Access search to get answers for {input_query}")
return search.run(input_query)
def query(self, input_query, config: QueryConfig = None):
def query(self, input_query, config: QueryConfig = None, dry_run=False):
"""
Queries the vector database based on the given input query.
Gets relevant doc based on the query and then passes it to an
@@ -222,6 +222,12 @@ class EmbedChain:
:param input_query: The query to use.
:param config: Optional. The `QueryConfig` instance to use as
configuration options.
:param dry_run: Optional. A dry run does everything except send the resulting prompt to
the LLM. The purpose is to test the prompt, not the response.
You can use it to test your prompt, including the context provided
by the vector database's doc retrieval.
The only thing the dry run does not consider is the cut-off due to
the `max_tokens` parameter.
:return: The answer to the query.
"""
if config is None:
@@ -236,6 +242,9 @@ class EmbedChain:
prompt = self.generate_prompt(input_query, contexts, config, **k)
logging.info(f"Prompt: {prompt}")
if dry_run:
return prompt
answer = self.get_answer_from_llm(prompt, config)
if isinstance(answer, str):
@@ -251,7 +260,7 @@ class EmbedChain:
yield chunk
logging.info(f"Answer: {streamed_answer}")
def chat(self, input_query, config: ChatConfig = None):
def chat(self, input_query, config: ChatConfig = None, dry_run=False):
"""
Queries the vector database on the given input query.
Gets relevant doc based on the query and then passes it to an
@@ -261,6 +270,12 @@ class EmbedChain:
:param input_query: The query to use.
:param config: Optional. The `ChatConfig` instance to use as
configuration options.
:param dry_run: Optional. A dry run does everything except send the resulting prompt to
the LLM. The purpose is to test the prompt, not the response.
You can use it to test your prompt, including the context provided
by the vector database's doc retrieval.
The only thing the dry run does not consider is the cut-off due to
the `max_tokens` parameter.
:return: The answer to the query.
"""
if config is None:
@@ -281,6 +296,10 @@ class EmbedChain:
prompt = self.generate_prompt(input_query, contexts, config, **k)
logging.info(f"Prompt: {prompt}")
if dry_run:
return prompt
answer = self.get_answer_from_llm(prompt, config)
memory.chat_memory.add_user_message(input_query)
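Note that the early return happens before the memory calls below, so a dry run is never recorded in the chat history; a sketch (assuming an `App` instance named `app` with data already added):
```python
prompt = app.chat("Can you tell me who Naval Ravikant is?", dry_run=True)
print(prompt)  # the prompt is returned and the chat memory is left untouched
```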
@@ -301,27 +320,6 @@ class EmbedChain:
memory.chat_memory.add_ai_message(streamed_answer)
logging.info(f"Answer: {streamed_answer}")
def dry_run(self, input_query, config: QueryConfig = None):
"""
A dry run does everything except send the resulting prompt to
the LLM. The purpose is to test the prompt, not the response.
You can use it to test your prompt, including the context provided
by the vector database's doc retrieval.
The only thing the dry run does not consider is the cut-off due to
the `max_tokens` parameter.
:param input_query: The query to use.
:param config: Optional. The `QueryConfig` instance to use as
configuration options.
:return: The prompt that would be sent to the LLM
"""
if config is None:
config = QueryConfig()
contexts = self.retrieve_from_database(input_query, config)
prompt = self.generate_prompt(input_query, contexts, config)
logging.info(f"Prompt: {prompt}")
return prompt
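For anyone upgrading across this breaking change, the removed method maps onto the new keyword argument like this (`app` stands in for your existing instance):
```python
# before (removed in this commit)
prompt = app.dry_run("Can you tell me who Naval Ravikant is?")

# after
prompt = app.query("Can you tell me who Naval Ravikant is?", dry_run=True)
# or, inside a chat session
prompt = app.chat("Can you tell me who Naval Ravikant is?", dry_run=True)
```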
def count(self):
"""
Count the number of embeddings.

View File

@@ -1,52 +0,0 @@
import os
import unittest
from string import Template
from unittest.mock import patch
from embedchain import App
from embedchain.embedchain import QueryConfig
class TestApp(unittest.TestCase):
os.environ["OPENAI_API_KEY"] = "test_key"
def setUp(self):
self.app = App()
@patch("logging.info")
def test_query_logs_same_prompt_as_dry_run(self, mock_logging_info):
"""
Test that the 'query' method logs the same prompt as the 'dry_run' method.
This is the only way I found to test the prompt in `query`, since it is not returned.
"""
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
input_query = "Test query"
config = QueryConfig(
number_documents=3,
template=Template("Question: $query, context: $context, history: $history"),
history=["Past context 1", "Past context 2"],
)
with patch.object(self.app, "get_answer_from_llm"):
self.app.dry_run(input_query, config)
self.app.query(input_query, config)
# Access the log messages captured during the execution
logged_messages = [call[0][0] for call in mock_logging_info.call_args_list]
# Extract the prompts from the log messages
dry_run_prompt = self.extract_prompt(logged_messages[0])
query_prompt = self.extract_prompt(logged_messages[1])
# Perform assertions on the prompts
self.assertEqual(dry_run_prompt, query_prompt)
def extract_prompt(self, log_message):
"""
Extracts the prompt value from the log message.
Adjust this method based on the log message format in your implementation.
"""
# Modify this logic based on your log message format
prefix = "Prompt: "
return log_message.split(prefix, 1)[1]
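Since the old prompt-comparison test is removed along with the method, one possible replacement is to assert on the prompt that `query(..., dry_run=True)` now returns. A sketch that mirrors the mocking style of the removed test; it assumes the patched internals behave as they did there:
```python
import os
import unittest
from string import Template
from unittest.mock import patch

from embedchain import App
from embedchain.embedchain import QueryConfig


class TestDryRun(unittest.TestCase):
    os.environ["OPENAI_API_KEY"] = "test_key"

    def setUp(self):
        self.app = App()

    def test_query_dry_run_returns_prompt(self):
        """`query` with dry_run=True should return the constructed prompt and never call the LLM."""
        config = QueryConfig(
            template=Template("Question: $query, context: $context"),
        )
        with patch.object(self.app, "retrieve_from_database") as mock_retrieve, patch.object(
            self.app, "get_answer_from_llm"
        ) as mock_llm:
            mock_retrieve.return_value = ["Test context"]
            prompt = self.app.query("Test query", config, dry_run=True)

        # the prompt contains both the query and the retrieved context
        self.assertIn("Test query", prompt)
        self.assertIn("Test context", prompt)
        # and the LLM was never called
        mock_llm.assert_not_called()
```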