feat: Add browse the internet or online functionality. (#291)
This commit is contained in:
@@ -9,8 +9,7 @@ from langchain.docstore.document import Document
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
|
||||
from embedchain.config import AddConfig, ChatConfig, InitConfig, QueryConfig
|
||||
from embedchain.config.QueryConfig import (CODE_DOCS_PAGE_PROMPT_TEMPLATE,
|
||||
DEFAULT_PROMPT)
|
||||
from embedchain.config.QueryConfig import CODE_DOCS_PAGE_PROMPT_TEMPLATE, DEFAULT_PROMPT
|
||||
from embedchain.data_formatter import DataFormatter
|
||||
|
||||
gpt4all_model = None
|
||||
@@ -37,6 +36,7 @@ class EmbedChain:
|
||||
self.collection = self.config.db.collection
|
||||
self.user_asks = []
|
||||
self.is_code_docs_instance = False
|
||||
self.online = False
|
||||
|
||||
def add(self, data_type, url, metadata=None, config: AddConfig = None):
|
||||
"""
|
||||
@@ -163,7 +163,10 @@ class EmbedChain:
|
||||
contents = [result[0].page_content for result in results_formatted]
|
||||
return contents
|
||||
|
||||
def generate_prompt(self, input_query, contexts, config: QueryConfig):
|
||||
def _append_search_and_context(self, context, web_search_result):
|
||||
return f"{context}\nWeb Search Result: {web_search_result}"
|
||||
|
||||
def generate_prompt(self, input_query, contexts, config: QueryConfig, **kwargs):
|
||||
"""
|
||||
Generates a prompt based on the given query and context, ready to be
|
||||
passed to an LLM
|
||||
@@ -175,6 +178,9 @@ class EmbedChain:
|
||||
:return: The prompt
|
||||
"""
|
||||
context_string = (" | ").join(contexts)
|
||||
web_search_result = kwargs.get("web_search_result", "")
|
||||
if web_search_result:
|
||||
context_string = self._append_search_and_context(context_string, web_search_result)
|
||||
if not config.history:
|
||||
prompt = config.template.substitute(context=context_string, query=input_query)
|
||||
else:
|
||||
@@ -193,6 +199,12 @@ class EmbedChain:
|
||||
|
||||
return self.get_llm_model_answer(prompt, config)
|
||||
|
||||
def access_search_and_get_results(self, input_query):
|
||||
from langchain.tools import DuckDuckGoSearchRun
|
||||
search = DuckDuckGoSearchRun()
|
||||
logging.info(f"Access search to get answers for {input_query}")
|
||||
return search.run(input_query)
|
||||
|
||||
def query(self, input_query, config: QueryConfig = None):
|
||||
"""
|
||||
Queries the vector database based on the given input query.
|
||||
@@ -209,8 +221,11 @@ class EmbedChain:
|
||||
if self.is_code_docs_instance:
|
||||
config.template = CODE_DOCS_PAGE_PROMPT_TEMPLATE
|
||||
config.number_documents = 5
|
||||
k = {}
|
||||
if self.online:
|
||||
k["web_search_result"] = self.access_search_and_get_results(input_query)
|
||||
contexts = self.retrieve_from_database(input_query, config)
|
||||
prompt = self.generate_prompt(input_query, contexts, config)
|
||||
prompt = self.generate_prompt(input_query, contexts, config, **k)
|
||||
logging.info(f"Prompt: {prompt}")
|
||||
|
||||
answer = self.get_answer_from_llm(prompt, config)
|
||||
@@ -245,7 +260,10 @@ class EmbedChain:
|
||||
if self.is_code_docs_instance:
|
||||
config.template = CODE_DOCS_PAGE_PROMPT_TEMPLATE
|
||||
config.number_documents = 5
|
||||
contexts = self.retrieve_from_database(input_query, config)
|
||||
k = {}
|
||||
if self.online:
|
||||
k["web_search_result"] = self.access_search_and_get_results(input_query)
|
||||
contexts = self.retrieve_from_database(input_query, config, **k)
|
||||
|
||||
global memory
|
||||
chat_history = memory.load_memory_variables({})["history"]
|
||||
@@ -253,7 +271,7 @@ class EmbedChain:
|
||||
if chat_history:
|
||||
config.set_history(chat_history)
|
||||
|
||||
prompt = self.generate_prompt(input_query, contexts, config)
|
||||
prompt = self.generate_prompt(input_query, contexts, config, **k)
|
||||
logging.info(f"Prompt: {prompt}")
|
||||
answer = self.get_answer_from_llm(prompt, config)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user