feat: add streaming support for OpenAI models (#202)

aaishikdutta authored on 2023-07-10 17:18:24 +05:30, committed by GitHub
parent 13bac72e25, commit 66c4d30c60
3 changed files with 58 additions and 6 deletions


@@ -164,8 +164,8 @@ class EmbedChain:
         :param context: Similar documents to the query used as context.
         :return: The answer.
         """
-        answer = self.get_llm_model_answer(prompt)
-        return answer
+        return self.get_llm_model_answer(prompt)
 
     def query(self, input_query, config: QueryConfig = None):
         """
@@ -226,8 +226,20 @@ class EmbedChain:
         )
         answer = self.get_answer_from_llm(prompt)
         memory.chat_memory.add_user_message(input_query)
-        memory.chat_memory.add_ai_message(answer)
-        return answer
+        if isinstance(answer, str):
+            memory.chat_memory.add_ai_message(answer)
+            return answer
+        else:
+            # This is a streamed response and needs to be handled differently.
+            return self._stream_chat_response(answer)
+
+    def _stream_chat_response(self, answer):
+        streamed_answer = ""
+        for chunk in answer:
+            streamed_answer += chunk
+            yield chunk
+        memory.chat_memory.add_ai_message(streamed_answer)
 
     def dry_run(self, input_query, config: QueryConfig = None):
         """
@@ -284,6 +296,13 @@ class App(EmbedChain):
         super().__init__(config)
 
     def get_llm_model_answer(self, prompt):
+        stream_response = self.config.stream_response
+        if stream_response:
+            return self._stream_llm_model_response(prompt)
+        else:
+            return self._get_llm_model_response(prompt)
+
+    def _get_llm_model_response(self, prompt, stream_response=False):
         messages = []
         messages.append({
             "role": "user", "content": prompt
@@ -294,8 +313,24 @@ class App(EmbedChain):
             temperature=0,
             max_tokens=1000,
             top_p=1,
+            stream=stream_response
         )
-        return response["choices"][0]["message"]["content"]
+        if stream_response:
+            # This contains the entire completions object. Needs to be sanitised.
+            return response
+        else:
+            return response["choices"][0]["message"]["content"]
+
+    def _stream_llm_model_response(self, prompt):
+        """
+        Generator for the streamed response from the OpenAI Chat Completions API.
+        """
+        response = self._get_llm_model_response(prompt, True)
+        for line in response:
+            chunk = line['choices'][0].get('delta', {}).get('content', '')
+            yield chunk
 
 
 class OpenSourceApp(EmbedChain):
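For reference, the streaming behaviour the two helpers above rely on comes from passing stream=True to the pre-1.0 openai Python client; a standalone sketch of that pattern follows (model name and prompt are illustrative, not taken from this commit):

import openai

response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",  # illustrative; the model used by App is set elsewhere in the codebase
    messages=[{"role": "user", "content": "Hello"}],
    temperature=0,
    max_tokens=1000,
    top_p=1,
    stream=True,
)

# With stream=True the call returns an iterator of chunk objects; each chunk
# carries an incremental "delta" instead of a complete message, which is why
# _stream_llm_model_response reads choices[0]["delta"]["content"].
for line in response:
    print(line["choices"][0].get("delta", {}).get("content", ""), end="", flush=True)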