feat: add multi-document answers (#63)

Author: cachho
Date: 2023-07-11 14:50:07 +02:00
Committed by: GitHub
Parent: 9409e5605a
Commit: 40dc28406d
3 changed files with 32 additions and 21 deletions


@@ -133,43 +133,44 @@ class EmbedChain:
     def get_llm_model_answer(self, prompt):
         raise NotImplementedError

-    def retrieve_from_database(self, input_query):
+    def retrieve_from_database(self, input_query, config: QueryConfig):
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query

         :param input_query: The query to use.
+        :param config: The query configuration.
         :return: The content of the document that matched your query.
         """
         result = self.collection.query(
             query_texts=[
                 input_query,
             ],
-            n_results=1,
+            n_results=config.number_documents,
         )
-        result_formatted = self._format_result(result)
-        if result_formatted:
-            content = result_formatted[0][0].page_content
-        else:
-            content = ""
-        return content
+        results_formatted = self._format_result(result)
+        contents = [result[0].page_content for result in results_formatted]
+        return contents

-    def generate_prompt(self, input_query, context, config: QueryConfig):
+    def generate_prompt(self, input_query, contexts, config: QueryConfig):
         """
         Generates a prompt based on the given query and context, ready to be
         passed to an LLM

         :param input_query: The query to use.
-        :param context: Similar documents to the query used as context.
+        :param contexts: List of similar documents to the query used as context.
         :param config: Optional. The `QueryConfig` instance to use as
         configuration options.
         :return: The prompt
         """
+        context_string = (" | ").join(contexts)
         if not config.history:
-            prompt = config.template.substitute(context=context, query=input_query)
+            prompt = config.template.substitute(
+                context=context_string, query=input_query
+            )
         else:
             prompt = config.template.substitute(
-                context=context, query=input_query, history=config.history
+                context=context_string, query=input_query, history=config.history
             )
         return prompt
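Read together, these two changes are the core of the multi-document feature: retrieval now returns a list of the top config.number_documents matches, and prompt generation joins that list into a single context string with " | ". Below is a minimal standalone sketch of the new flow, using a plain string.Template and the raw chromadb query result in place of the real QueryConfig and _format_result helpers; those stand-ins are assumptions for illustration, not the library's API.

from string import Template

# Hypothetical stand-in for config.template (not the real default template).
TEMPLATE = Template(
    "Use the following context to answer the query.\nContext: $context\nQuery: $query"
)

def retrieve_contexts(collection, input_query, number_documents):
    # chromadb returns the n closest documents instead of only the single best hit.
    result = collection.query(query_texts=[input_query], n_results=number_documents)
    # result["documents"] is a list of lists, one inner list per query text.
    return result["documents"][0]

def build_prompt(input_query, contexts):
    # Multiple documents are flattened into one context string, separated by " | ",
    # mirroring the generate_prompt change above.
    context_string = " | ".join(contexts)
    return TEMPLATE.substitute(context=context_string, query=input_query)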
@@ -198,8 +199,8 @@ class EmbedChain:
         """
         if config is None:
             config = QueryConfig()
-        context = self.retrieve_from_database(input_query)
-        prompt = self.generate_prompt(input_query, context, config)
+        contexts = self.retrieve_from_database(input_query, config)
+        prompt = self.generate_prompt(input_query, contexts, config)
         logging.info(f"Prompt: {prompt}")
         answer = self.get_answer_from_llm(prompt, config)
         logging.info(f"Answer: {answer}")
@@ -217,16 +218,18 @@ class EmbedChain:
         configuration options.
         :return: The answer to the query.
         """
-        context = self.retrieve_from_database(input_query)
+        if config is None:
+            config = ChatConfig()
+        contexts = self.retrieve_from_database(input_query, config)
         global memory
         chat_history = memory.load_memory_variables({})["history"]
-        if config is None:
-            config = ChatConfig()
         if chat_history:
             config.set_history(chat_history)
-        prompt = self.generate_prompt(input_query, context, config)
+        prompt = self.generate_prompt(input_query, contexts, config)
         logging.info(f"Prompt: {prompt}")
         answer = self.get_answer_from_llm(prompt, config)
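The chat path needed a small reordering in addition to the rename: the default ChatConfig must exist before retrieval runs, because retrieve_from_database now reads config.number_documents. Simplified, the corrected control flow looks roughly like this (a sketch of the ordering, not the full method).

def chat(self, input_query, config=None):
    # Build the default config first; retrieval now depends on it.
    if config is None:
        config = ChatConfig()
    contexts = self.retrieve_from_database(input_query, config)

    # Chat history is attached to the config after retrieval, as before.
    chat_history = memory.load_memory_variables({})["history"]
    if chat_history:
        config.set_history(chat_history)

    prompt = self.generate_prompt(input_query, contexts, config)
    return self.get_answer_from_llm(prompt, config)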
@@ -264,8 +267,8 @@ class EmbedChain:
         """
         if config is None:
             config = QueryConfig()
-        context = self.retrieve_from_database(input_query)
-        prompt = self.generate_prompt(input_query, context, config)
+        contexts = self.retrieve_from_database(input_query, config)
+        prompt = self.generate_prompt(input_query, contexts, config)
         logging.info(f"Prompt: {prompt}")
         return prompt
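The last hunk applies the same config-passing change to the path that only builds and returns the prompt without calling the LLM. Using the build_prompt sketch from earlier, the effect on the rendered prompt is easy to see.

contexts = ["Doc A says X.", "Doc B says Y.", "Doc C says Z."]
print(build_prompt("What do the docs say?", contexts))
# The context portion of the prompt now reads:
#   Doc A says X. | Doc B says Y. | Doc C says Z.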