From 4335fff153aa4b220cc2376e5819cddb7b70bf4d Mon Sep 17 00:00:00 2001
From: aaishikdutta <107566376+aaishikdutta@users.noreply.github.com>
Date: Fri, 14 Jul 2023 12:04:15 +0530
Subject: [PATCH] bug: Fix/stream logging (#262)

---
 embedchain/embedchain.py | 28 ++++++++++++++++++++++------
 1 file changed, 22 insertions(+), 6 deletions(-)

diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index f0b783a7..fafc2aa0 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -53,7 +53,9 @@ class EmbedChain:
         data_formatter = DataFormatter(data_type, config)
 
         self.user_asks.append([data_type, url, metadata])
-        self.load_and_embed(data_formatter.loader, data_formatter.chunker, url, metadata)
+        self.load_and_embed(
+            data_formatter.loader, data_formatter.chunker, url, metadata
+        )
 
     def add_local(self, data_type, content, metadata=None, config: AddConfig = None):
         """
@@ -117,10 +119,12 @@ class EmbedChain:
 
         chunks_before_addition = self.count()
 
-        # Add metadata to each document 
+        # Add metadata to each document
         metadatas_with_metadata = [meta or metadata for meta in metadatas]
 
-        self.collection.add(documents=documents, metadatas=list(metadatas_with_metadata), ids=ids)
+        self.collection.add(
+            documents=documents, metadatas=list(metadatas_with_metadata), ids=ids
+        )
         print(
             (
                 f"Successfully saved {src}. New chunks count: "
@@ -210,9 +214,21 @@ class EmbedChain:
         contexts = self.retrieve_from_database(input_query, config)
         prompt = self.generate_prompt(input_query, contexts, config)
         logging.info(f"Prompt: {prompt}")
+
         answer = self.get_answer_from_llm(prompt, config)
-        logging.info(f"Answer: {answer}")
-        return answer
+
+        if isinstance(answer, str):
+            logging.info(f"Answer: {answer}")
+            return answer
+        else:
+            return self._stream_query_response(answer)
+
+    def _stream_query_response(self, answer):
+        streamed_answer = ""
+        for chunk in answer:
+            streamed_answer = streamed_answer + chunk
+            yield chunk
+        logging.info(f"Answer: {streamed_answer}")
 
     def chat(self, input_query, config: ChatConfig = None):
         """
@@ -254,7 +270,7 @@ class EmbedChain:
     def _stream_chat_response(self, answer):
         streamed_answer = ""
         for chunk in answer:
-            streamed_answer.join(chunk)
+            streamed_answer = streamed_answer + chunk
             yield chunk
         memory.chat_memory.add_ai_message(streamed_answer)
         logging.info(f"Answer: {streamed_answer}")