🚀 An Embedchain app powered by OpenAI!'  # noqa: E501
+st.markdown(styled_caption, unsafe_allow_html=True)
+
+if "messages" not in st.session_state:
+    st.session_state.messages = [
+        {
+            "role": "assistant",
+            "content": """
+                Hi! I'm a chatbot powered by Embedchain, which can answer questions about your PDF documents.\n
+                Upload your PDF documents here and I'll answer your questions about them!
+            """,
+        }
+    ]
+
+for message in st.session_state.messages:
+    with st.chat_message(message["role"]):
+        st.markdown(message["content"])
+
+if prompt := st.chat_input("Ask me anything!"):
+    if not os.environ.get("OPENAI_API_KEY"):
+        st.error("Please enter your OpenAI API Key", icon="🤖")
+        st.stop()
+
+    app = embedchain_bot()
+    with st.chat_message("user"):
+        st.session_state.messages.append({"role": "user", "content": prompt})
+        st.markdown(prompt)
+
+    with st.chat_message("assistant"):
+        msg_placeholder = st.empty()
+        msg_placeholder.markdown("Thinking...")
+        full_response = ""
+
+        # app.chat() runs in a background thread; its streaming callback pushes
+        # answer tokens onto this queue for the main thread to render.
+        q = queue.Queue()
+
+        def app_response(result):
+            llm_config = app.llm.config.as_dict()
+            llm_config["callbacks"] = [StreamingStdOutCallbackHandlerYield(q=q)]
+            config = BaseLlmConfig(**llm_config)
+            answer, citations = app.chat(prompt, config=config, citations=True)
+            result["answer"] = answer
+            result["citations"] = citations
+
+        results = {}
+        thread = threading.Thread(target=app_response, args=(results,))
+        thread.start()
+
+        # Drain the queue as tokens arrive and update the placeholder live.
+        for answer_chunk in generate(q):
+            full_response += answer_chunk
+            msg_placeholder.markdown(full_response)
+
+        thread.join()
+        answer, citations = results["answer"], results["citations"]
+        if citations:
+            full_response += "\n\n**Sources**:\n"
+            sources = []
+            for citation in citations:
+                source = citation[1]
+                # Collapse "<name>.<suffix>.pdf" source paths to a plain "<name>.pdf" label.
+                pattern = re.compile(r"([^/]+)\.[^\.]+\.pdf$")
+                match = pattern.search(source)
+                if match:
+                    source = match.group(1) + ".pdf"
+                sources.append(source)
+            sources = list(set(sources))
+            for source in sources:
+                full_response += f"- {source}\n"
+
+        msg_placeholder.markdown(full_response)
+        print("Answer: ", answer)
+        st.session_state.messages.append({"role": "assistant", "content": answer})
diff --git a/examples/chat-pdf/embedchain.json b/examples/chat-pdf/embedchain.json
new file mode 100644
index 00000000..32dec293
--- /dev/null
+++ b/examples/chat-pdf/embedchain.json
@@ -0,0 +1,3 @@
+{
+    "provider": "streamlit.io"
+}
\ No newline at end of file
diff --git a/examples/chat-pdf/requirements.txt b/examples/chat-pdf/requirements.txt
new file mode 100644
index 00000000..b864076a
--- /dev/null
+++ b/examples/chat-pdf/requirements.txt
@@ -0,0 +1,2 @@
+streamlit==1.29.0
+embedchain
diff --git a/tests/llm/test_ollama.py b/tests/llm/test_ollama.py
index 394bbd41..34ab8238 100644
--- a/tests/llm/test_ollama.py
+++ b/tests/llm/test_ollama.py
@@ -9,6 +9,7 @@ def ollama_llm_config():
     config = BaseLlmConfig(model="llama2", temperature=0.7, top_p=0.8, stream=True, system_prompt=None)
     yield config
 
+
 def test_get_llm_model_answer(ollama_llm_config, mocker):
     mocker.patch("embedchain.llm.ollama.OllamaLlm._get_answer", return_value="Test answer")
 
@@ -33,6 +34,6 @@ def test_get_answer_mocked_ollama(ollama_llm_config, mocker):
         system=None,
         temperature=0.7,
         top_p=0.8,
-        callback_manager=mocker.ANY  # Use mocker.ANY to ignore the exact instance
+        callback_manager=mocker.ANY,  # Use mocker.ANY to ignore the exact instance
     )
     mock_instance.assert_called_once_with(prompt)
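A note on the streaming design in the chat handler above: app.chat() only returns once the full answer is ready, so the handler runs it in a worker thread whose streaming callback pushes tokens onto a queue.Queue, while the main Streamlit thread drains the queue and re-renders the placeholder after each chunk. The standalone sketch below shows that producer/consumer pattern in isolation; the names in it (slow_answer, stream, _DONE) are illustrative stand-ins, not embedchain APIs, which is what StreamingStdOutCallbackHandlerYield and generate(q) provide in the example app.

# Minimal, self-contained sketch of the thread + queue streaming pattern.
# All helper names here are illustrative, not embedchain APIs.
import queue
import threading

_DONE = object()  # sentinel marking the end of the stream


def slow_answer(question, q):
    """Stand-in for app.chat(): pushes answer tokens onto the queue as they appear."""
    for token in ["Embedchain ", "answers ", "from ", "your ", "PDFs."]:
        q.put(token)
    q.put(_DONE)


def stream(q):
    """Yield tokens until the sentinel arrives (the role generate(q) plays above)."""
    while (token := q.get()) is not _DONE:
        yield token


q = queue.Queue()
worker = threading.Thread(target=slow_answer, args=("What is Embedchain?", q))
worker.start()

answer = ""
for chunk in stream(q):
    answer += chunk  # the Streamlit app re-renders msg_placeholder at this point
    print(chunk, end="", flush=True)

worker.join()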