From 4388f6bfc27e5512ca39f7665f74870afc78c693 Mon Sep 17 00:00:00 2001
From: Deshraj Yadav
Date: Fri, 25 Aug 2023 21:59:39 -0700
Subject: [PATCH] [bug-fix] fix issue related to bot memory when using
 multiple bots at the same time (#486)

---
 embedchain/embedchain.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index 730429b6..b63edd33 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -26,8 +26,6 @@ load_dotenv()
 ABS_PATH = os.getcwd()
 DB_DIR = os.path.join(ABS_PATH, "db")
 
-memory = ConversationBufferMemory()
-
 
 class EmbedChain:
     def __init__(self, config: BaseAppConfig):
@@ -44,6 +42,7 @@ class EmbedChain:
         self.user_asks = []
         self.is_docs_site_instance = False
         self.online = False
+        self.memory = ConversationBufferMemory()
 
         # Send anonymous telemetry
         self.s_id = self.config.id if self.config.id else str(uuid.uuid4())
@@ -362,8 +361,7 @@ class EmbedChain:
             k["web_search_result"] = self.access_search_and_get_results(input_query)
 
         contexts = self.retrieve_from_database(input_query, config)
-        global memory
-        chat_history = memory.load_memory_variables({})["history"]
+        chat_history = self.memory.load_memory_variables({})["history"]
 
         if chat_history:
             config.set_history(chat_history)
@@ -376,14 +374,14 @@ class EmbedChain:
 
         answer = self.get_answer_from_llm(prompt, config)
 
-        memory.chat_memory.add_user_message(input_query)
+        self.memory.chat_memory.add_user_message(input_query)
 
         # Send anonymous telemetry
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("chat",))
         thread_telemetry.start()
 
         if isinstance(answer, str):
-            memory.chat_memory.add_ai_message(answer)
+            self.memory.chat_memory.add_ai_message(answer)
             logging.info(f"Answer: {answer}")
             return answer
         else:
@@ -395,7 +393,7 @@ class EmbedChain:
             for chunk in answer:
                 streamed_answer = streamed_answer + chunk
                 yield chunk
-            memory.chat_memory.add_ai_message(streamed_answer)
+            self.memory.chat_memory.add_ai_message(streamed_answer)
             logging.info(f"Answer: {streamed_answer}")
 
     def set_collection(self, collection_name):
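
The root cause fixed here is the module-level `memory = ConversationBufferMemory()`: every `EmbedChain` instance in the same process read from and wrote to one shared buffer, so two bots running side by side leaked each other's chat history. Moving the buffer into `__init__` gives each bot its own history. Below is a minimal sketch (not part of the patch) of the behavior after the fix, assuming the `langchain` version in use exposes `ConversationBufferMemory` with the same `chat_memory` / `load_memory_variables` API seen above; the `Bot` class and the sample messages are hypothetical, for illustration only.

```python
from langchain.memory import ConversationBufferMemory


class Bot:
    """Hypothetical stand-in for EmbedChain: memory is created per instance."""

    def __init__(self, name: str):
        self.name = name
        # Instance-scoped buffer, mirroring the patched EmbedChain.__init__
        self.memory = ConversationBufferMemory()

    def chat(self, user_message: str, answer: str) -> str:
        # Record the turn the same way EmbedChain.chat() does after the fix
        self.memory.chat_memory.add_user_message(user_message)
        self.memory.chat_memory.add_ai_message(answer)
        return self.memory.load_memory_variables({})["history"]


bot_a = Bot("bot_a")
bot_b = Bot("bot_b")

bot_a.chat("What is specific knowledge?", "Knowledge you cannot be trained for.")
history_b = bot_b.chat("How do startups win?", "By making something people want.")

# With the old module-level buffer, bot_b's history would also contain
# bot_a's turns; with per-instance memory it only holds its own.
assert "specific knowledge" not in history_b
print(history_b)
```

With the old global buffer, the `assert` above would fail because `load_memory_variables` would return both bots' turns interleaved; the patch scopes the buffer to the instance so each bot's prompt only ever sees its own conversation.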