diff --git a/README.md b/README.md
index 196968fd..5696fb75 100644
--- a/README.md
+++ b/README.md
@@ -204,6 +204,19 @@ from embedchain import PersonApp as ECPApp
 print(naval_chat_bot.query("What unique capacity does Naval argue humans possess when it comes to understanding explanations or concepts?"))
 # answer: Naval argues that humans possess the unique capacity to understand explanations or concepts to the maximum extent possible in this physical reality.
 ```
+### Stream Response
+
+- You can pass a config to your query method to stream responses the way ChatGPT does. You will need a downstream handler to render the chunks in your desired format.
+
+- To use this, instantiate `App` with an `InitConfig` instance, passing `stream_response=True`. The following example iterates through the chunks and prints them as they appear.
+```python
+naval_chat_bot = App(InitConfig(stream_response=True))
+resp = naval_chat_bot.query("What unique capacity does Naval argue humans possess when it comes to understanding explanations or concepts?")
+
+for chunk in resp:
+    print(chunk, end="", flush=True)
+# answer: Naval argues that humans possess the unique capacity to understand explanations or concepts to the maximum extent possible in this physical reality.
+```
 
 ### Chat Interface
 
diff --git a/embedchain/config/InitConfig.py b/embedchain/config/InitConfig.py
index 9bcaeb90..769ee8fe 100644
--- a/embedchain/config/InitConfig.py
+++ b/embedchain/config/InitConfig.py
@@ -6,7 +6,7 @@ class InitConfig(BaseConfig):
     """
     Config to initialize an embedchain `App` instance.
     """
-    def __init__(self, ef=None, db=None):
+    def __init__(self, ef=None, db=None, stream_response=False):
         """
         :param ef: Optional. Embedding function to use.
         :param db: Optional. (Vector) database to use for embeddings.
@@ -27,6 +27,10 @@ class InitConfig(BaseConfig):
             self.db = ChromaDB(ef=self.ef)
         else:
             self.db = db
+
+        if not isinstance(stream_response, bool):
+            raise ValueError("`stream_response` should be a boolean")
+        self.stream_response = stream_response
 
         return
 
diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py
index c944ecfe..1c4b7c0d 100644
--- a/embedchain/embedchain.py
+++ b/embedchain/embedchain.py
@@ -164,8 +164,8 @@ class EmbedChain:
         :param context: Similar documents to the query used as context.
         :return: The answer.
""" - answer = self.get_llm_model_answer(prompt) - return answer + + return self.get_llm_model_answer(prompt) def query(self, input_query, config: QueryConfig = None): """ @@ -226,8 +226,20 @@ class EmbedChain: ) answer = self.get_answer_from_llm(prompt) memory.chat_memory.add_user_message(input_query) - memory.chat_memory.add_ai_message(answer) - return answer + if isinstance(answer, str): + memory.chat_memory.add_ai_message(answer) + return answer + else: + #this is a streamed response and needs to be handled differently + return self._stream_chat_response(answer) + + def _stream_chat_response(self, answer): + streamed_answer = "" + for chunk in answer: + streamed_answer.join(chunk) + yield chunk + memory.chat_memory.add_ai_message(streamed_answer) + def dry_run(self, input_query, config: QueryConfig = None): """ @@ -284,6 +296,13 @@ class App(EmbedChain): super().__init__(config) def get_llm_model_answer(self, prompt): + stream_response = self.config.stream_response + if stream_response: + return self._stream_llm_model_response(prompt) + else: + return self._get_llm_model_response(prompt) + + def _get_llm_model_response(self, prompt, stream_response = False): messages = [] messages.append({ "role": "user", "content": prompt @@ -294,8 +313,24 @@ class App(EmbedChain): temperature=0, max_tokens=1000, top_p=1, + stream=stream_response ) - return response["choices"][0]["message"]["content"] + + if stream_response: + # This contains the entire completions object. Needs to be sanitised + return response + else: + return response["choices"][0]["message"]["content"] + + def _stream_llm_model_response(self, prompt): + """ + This is a generator for streaming response from the OpenAI completions API + """ + response = self._get_llm_model_response(prompt, True) + for line in response: + chunk = line['choices'][0].get('delta', {}).get('content', '') + yield chunk + class OpenSourceApp(EmbedChain):