feat: add multi-document answers (#63)
This commit is contained in:
@@ -133,43 +133,44 @@ class EmbedChain:
|
||||
def get_llm_model_answer(self, prompt):
|
||||
raise NotImplementedError
|
||||
|
||||
def retrieve_from_database(self, input_query):
|
||||
def retrieve_from_database(self, input_query, config: QueryConfig):
|
||||
"""
|
||||
Queries the vector database based on the given input query.
|
||||
Gets relevant doc based on the query
|
||||
|
||||
:param input_query: The query to use.
|
||||
:param config: The query configuration.
|
||||
:return: The content of the document that matched your query.
|
||||
"""
|
||||
result = self.collection.query(
|
||||
query_texts=[
|
||||
input_query,
|
||||
],
|
||||
n_results=1,
|
||||
n_results=config.number_documents,
|
||||
)
|
||||
result_formatted = self._format_result(result)
|
||||
if result_formatted:
|
||||
content = result_formatted[0][0].page_content
|
||||
else:
|
||||
content = ""
|
||||
return content
|
||||
results_formatted = self._format_result(result)
|
||||
contents = [result[0].page_content for result in results_formatted]
|
||||
return contents
|
||||
|
||||
def generate_prompt(self, input_query, context, config: QueryConfig):
|
||||
def generate_prompt(self, input_query, contexts, config: QueryConfig):
|
||||
"""
|
||||
Generates a prompt based on the given query and context, ready to be
|
||||
passed to an LLM
|
||||
|
||||
:param input_query: The query to use.
|
||||
:param context: Similar documents to the query used as context.
|
||||
:param contexts: List of similar documents to the query used as context.
|
||||
:param config: Optional. The `QueryConfig` instance to use as
|
||||
configuration options.
|
||||
:return: The prompt
|
||||
"""
|
||||
context_string = (" | ").join(contexts)
|
||||
if not config.history:
|
||||
prompt = config.template.substitute(context=context, query=input_query)
|
||||
prompt = config.template.substitute(
|
||||
context=context_string, query=input_query
|
||||
)
|
||||
else:
|
||||
prompt = config.template.substitute(
|
||||
context=context, query=input_query, history=config.history
|
||||
context=context_string, query=input_query, history=config.history
|
||||
)
|
||||
return prompt
|
||||
|
||||
@@ -198,8 +199,8 @@ class EmbedChain:
|
||||
"""
|
||||
if config is None:
|
||||
config = QueryConfig()
|
||||
context = self.retrieve_from_database(input_query)
|
||||
prompt = self.generate_prompt(input_query, context, config)
|
||||
contexts = self.retrieve_from_database(input_query, config)
|
||||
prompt = self.generate_prompt(input_query, contexts, config)
|
||||
logging.info(f"Prompt: {prompt}")
|
||||
answer = self.get_answer_from_llm(prompt, config)
|
||||
logging.info(f"Answer: {answer}")
|
||||
@@ -217,16 +218,18 @@ class EmbedChain:
|
||||
configuration options.
|
||||
:return: The answer to the query.
|
||||
"""
|
||||
context = self.retrieve_from_database(input_query)
|
||||
if config is None:
|
||||
config = ChatConfig()
|
||||
|
||||
contexts = self.retrieve_from_database(input_query, config)
|
||||
|
||||
global memory
|
||||
chat_history = memory.load_memory_variables({})["history"]
|
||||
|
||||
if config is None:
|
||||
config = ChatConfig()
|
||||
if chat_history:
|
||||
config.set_history(chat_history)
|
||||
|
||||
prompt = self.generate_prompt(input_query, context, config)
|
||||
prompt = self.generate_prompt(input_query, contexts, config)
|
||||
logging.info(f"Prompt: {prompt}")
|
||||
answer = self.get_answer_from_llm(prompt, config)
|
||||
|
||||
@@ -264,8 +267,8 @@ class EmbedChain:
|
||||
"""
|
||||
if config is None:
|
||||
config = QueryConfig()
|
||||
context = self.retrieve_from_database(input_query)
|
||||
prompt = self.generate_prompt(input_query, context, config)
|
||||
contexts = self.retrieve_from_database(input_query, config)
|
||||
prompt = self.generate_prompt(input_query, contexts, config)
|
||||
logging.info(f"Prompt: {prompt}")
|
||||
return prompt
|
||||
|
||||
|
||||
Reference in New Issue
Block a user