diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index ae61418f..e1641454 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -9,8 +9,7 @@ from langchain.docstore.document import Document
 from langchain.memory import ConversationBufferMemory
 
 from embedchain.config import AddConfig, ChatConfig, InitConfig, QueryConfig
-from embedchain.config.QueryConfig import (CODE_DOCS_PAGE_PROMPT_TEMPLATE,
-                                           DEFAULT_PROMPT)
+from embedchain.config.QueryConfig import CODE_DOCS_PAGE_PROMPT_TEMPLATE, DEFAULT_PROMPT
 from embedchain.data_formatter import DataFormatter
 
 gpt4all_model = None
@@ -37,6 +36,7 @@ class EmbedChain:
         self.collection = self.config.db.collection
         self.user_asks = []
         self.is_code_docs_instance = False
+        self.online = False
 
     def add(self, data_type, url, metadata=None, config: AddConfig = None):
         """
@@ -163,7 +163,10 @@ class EmbedChain:
         contents = [result[0].page_content for result in results_formatted]
         return contents
 
-    def generate_prompt(self, input_query, contexts, config: QueryConfig):
+    def _append_search_and_context(self, context, web_search_result):
+        return f"{context}\nWeb Search Result: {web_search_result}"
+
+    def generate_prompt(self, input_query, contexts, config: QueryConfig, **kwargs):
         """
         Generates a prompt based on the given query and context, ready to be
         passed to an LLM
@@ -175,6 +178,9 @@ class EmbedChain:
         :return: The prompt
         """
         context_string = (" | ").join(contexts)
+        web_search_result = kwargs.get("web_search_result", "")
+        if web_search_result:
+            context_string = self._append_search_and_context(context_string, web_search_result)
         if not config.history:
             prompt = config.template.substitute(context=context_string, query=input_query)
         else:
@@ -193,6 +199,12 @@ class EmbedChain:
 
         return self.get_llm_model_answer(prompt, config)
 
+    def access_search_and_get_results(self, input_query):
+        from langchain.tools import DuckDuckGoSearchRun
+        search = DuckDuckGoSearchRun()
+        logging.info(f"Access search to get answers for {input_query}")
+        return search.run(input_query)
+
     def query(self, input_query, config: QueryConfig = None):
         """
         Queries the vector database based on the given input query.
@@ -209,8 +221,11 @@ class EmbedChain:
         if self.is_code_docs_instance:
             config.template = CODE_DOCS_PAGE_PROMPT_TEMPLATE
             config.number_documents = 5
+        k = {}
+        if self.online:
+            k["web_search_result"] = self.access_search_and_get_results(input_query)
         contexts = self.retrieve_from_database(input_query, config)
-        prompt = self.generate_prompt(input_query, contexts, config)
+        prompt = self.generate_prompt(input_query, contexts, config, **k)
         logging.info(f"Prompt: {prompt}")
         answer = self.get_answer_from_llm(prompt, config)
 
@@ -245,7 +260,10 @@ class EmbedChain:
         if self.is_code_docs_instance:
             config.template = CODE_DOCS_PAGE_PROMPT_TEMPLATE
             config.number_documents = 5
-        contexts = self.retrieve_from_database(input_query, config)
+        k = {}
+        if self.online:
+            k["web_search_result"] = self.access_search_and_get_results(input_query)
+        contexts = self.retrieve_from_database(input_query, config, **k)
 
         global memory
         chat_history = memory.load_memory_variables({})["history"]
@@ -253,7 +271,7 @@ class EmbedChain:
         if chat_history:
             config.set_history(chat_history)
 
-        prompt = self.generate_prompt(input_query, contexts, config)
+        prompt = self.generate_prompt(input_query, contexts, config, **k)
         logging.info(f"Prompt: {prompt}")
         answer = self.get_answer_from_llm(prompt, config)
 
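
For reviewers who want to poke at the new helper in isolation: `access_search_and_get_results` defers the `DuckDuckGoSearchRun` import so the tool is only loaded when `online` is actually set. The tool comes from langchain and needs the `duckduckgo-search` package installed; the quick standalone check below (the query string is purely illustrative) shows it returns a plain string of concatenated search snippets, which is the text that ends up appended to the context.

```python
# Requires: pip install duckduckgo-search (the langchain tool wraps it)
from langchain.tools import DuckDuckGoSearchRun

search = DuckDuckGoSearchRun()
result = search.run("embedchain python library")  # illustrative query

print(type(result))   # <class 'str'> - one blob of search snippets
print(result[:200])   # preview of the text appended as "Web Search Result: ..."
```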
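
A rough end-to-end sketch of how the new flag could be exercised, assuming the package's `App` class (an `EmbedChain` subclass) and an OpenAI key in the environment as the library already requires; the URL and queries are illustrative, and toggling the attribute directly is just one way to flip the flag, since this diff does not add a constructor option for it.

```python
import logging

from embedchain import App  # App subclasses EmbedChain, so it picks up the new flag

logging.basicConfig(level=logging.INFO)

app = App()
app.add("web_page", "https://example.com/docs")  # illustrative source to index

# online defaults to False, so this query behaves exactly as before the change.
print(app.query("How do I install the project described in the docs?"))

# With the flag set, query() calls access_search_and_get_results(), and the
# DuckDuckGo text reaches generate_prompt() as web_search_result, where
# _append_search_and_context() tacks it onto the retrieved context string.
app.online = True
print(app.query("What is the latest stable release?"))
```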