diff --git a/examples/chat-pdf/app.py b/examples/chat-pdf/app.py index b30e99e8..353d508d 100644 --- a/examples/chat-pdf/app.py +++ b/examples/chat-pdf/app.py @@ -12,8 +12,7 @@ from embedchain.helpers.callbacks import (StreamingStdOutCallbackHandlerYield, generate) -@st.cache_resource -def embedchain_bot(): +def embedchain_bot(db_path, api_key): return App.from_config( config={ "llm": { @@ -24,31 +23,43 @@ def embedchain_bot(): "max_tokens": 1000, "top_p": 1, "stream": True, + "api_key": api_key, }, }, "vectordb": { "provider": "chroma", - "config": {"collection_name": "chat-pdf", "dir": "db", "allow_reset": True}, + "config": {"collection_name": "chat-pdf", "dir": db_path, "allow_reset": True}, }, + "embedder": {"provider": "openai", "config": {"api_key": api_key}}, "chunker": {"chunk_size": 2000, "chunk_overlap": 0, "length_function": "len"}, } ) -@st.cache_data -def update_openai_key(): - os.environ["OPENAI_API_KEY"] = st.session_state.chatbot_api_key +def get_db_path(): + tmpdirname = tempfile.mkdtemp() + return tmpdirname + + +def get_ec_app(api_key): + if "app" in st.session_state: + print("Found app in session state") + app = st.session_state.app + else: + print("Creating app") + db_path = get_db_path() + app = embedchain_bot(db_path, api_key) + st.session_state.app = app + return app with st.sidebar: - openai_access_token = st.text_input( - "OpenAI API Key", value=os.environ.get("OPENAI_API_KEY"), key="chatbot_api_key", type="password" - ) # noqa: E501 + openai_access_token = st.text_input("OpenAI API Key", key="api_key", type="password") "WE DO NOT STORE YOUR OPENAI KEY." "Just paste your OpenAI API key here and we'll use it to power the chatbot. [Get your OpenAI API key](https://platform.openai.com/api-keys)" # noqa: E501 - if openai_access_token: - update_openai_key() + if st.session_state.api_key: + app = get_ec_app(st.session_state.api_key) pdf_files = st.file_uploader("Upload your PDF files", accept_multiple_files=True, type="pdf") add_pdf_files = st.session_state.get("add_pdf_files", []) @@ -57,10 +68,9 @@ with st.sidebar: if file_name in add_pdf_files: continue try: - if not os.environ.get("OPENAI_API_KEY"): + if not st.session_state.api_key: st.error("Please enter your OpenAI API Key") st.stop() - app = embedchain_bot() temp_file_name = None with tempfile.NamedTemporaryFile(mode="wb", delete=False, prefix=file_name, suffix=".pdf") as f: f.write(pdf_file.getvalue()) @@ -97,11 +107,12 @@ for message in st.session_state.messages: st.markdown(message["content"]) if prompt := st.chat_input("Ask me anything!"): - if not os.environ.get("OPENAI_API_KEY"): + if not st.session_state.api_key: st.error("Please enter your OpenAI API Key", icon="🤖") st.stop() - app = embedchain_bot() + app = get_ec_app(st.session_state.api_key) + with st.chat_message("user"): st.session_state.messages.append({"role": "user", "content": prompt}) st.markdown(prompt) @@ -146,5 +157,5 @@ if prompt := st.chat_input("Ask me anything!"): full_response += f"- {source}\n" msg_placeholder.markdown(full_response) - print("Answer: ", answer) - st.session_state.messages.append({"role": "assistant", "content": answer}) + print("Answer: ", full_response) + st.session_state.messages.append({"role": "assistant", "content": full_response})