feat: Use streaming setup at Query level (#214)

Author: aaishikdutta
Date: 2023-07-10 23:07:19 +05:30
Committed by: GitHub
Parent: 8674297d1a
Commit: c597b1939d
5 changed files with 41 additions and 38 deletions
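
Note on the change: the stream flag is threaded from the per-request query/chat config down to the OpenAI call, so callers opt into streaming per query instead of via an app-level setting. A minimal usage sketch of what that enables (not part of this diff; it assumes QueryConfig exposes a stream flag, which the rest of this commit's changes appear to add, and that the import path and add() signature match the package layout at the time):

# Hypothetical usage sketch, not taken from the diff.
from embedchain import App
from embedchain.config import QueryConfig  # assumed import path

app = App()
app.add("web_page", "https://example.com/article")  # index a source first

# With stream enabled, query() is expected to return a generator of text
# chunks rather than a single string.
for chunk in app.query("Summarize the article", QueryConfig(stream=True)):
    print(chunk, end="", flush=True)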

@@ -155,7 +155,7 @@ class EmbedChain:
         prompt = template.substitute(context = context, query = input_query)
         return prompt
 
-    def get_answer_from_llm(self, prompt):
+    def get_answer_from_llm(self, prompt, config: ChatConfig):
         """
         Gets an answer based on the given query and context by passing it
         to an LLM.
@@ -165,7 +165,7 @@ class EmbedChain:
         :return: The answer.
         """
-        return self.get_llm_model_answer(prompt)
+        return self.get_llm_model_answer(prompt, config)
 
     def query(self, input_query, config: QueryConfig = None):
         """
@@ -181,7 +181,7 @@ class EmbedChain:
             config = QueryConfig()
         context = self.retrieve_from_database(input_query)
         prompt = self.generate_prompt(input_query, context, config.template)
-        answer = self.get_answer_from_llm(prompt)
+        answer = self.get_answer_from_llm(prompt, config)
         return answer
 
     def generate_chat_prompt(self, input_query, context, chat_history=''):
@@ -224,7 +224,7 @@ class EmbedChain:
             context,
             chat_history=chat_history,
         )
-        answer = self.get_answer_from_llm(prompt)
+        answer = self.get_answer_from_llm(prompt, config)
         memory.chat_memory.add_user_message(input_query)
         if isinstance(answer, str):
             memory.chat_memory.add_ai_message(answer)
@@ -295,14 +295,8 @@ class App(EmbedChain):
             config = InitConfig()
         super().__init__(config)
 
-    def get_llm_model_answer(self, prompt):
-        stream_response = self.config.stream_response
-        if stream_response:
-            return self._stream_llm_model_response(prompt)
-        else:
-            return self._get_llm_model_response(prompt)
-
-    def _get_llm_model_response(self, prompt, stream_response = False):
+    def get_llm_model_answer(self, prompt, config: ChatConfig):
         messages = []
         messages.append({
             "role": "user", "content": prompt
@@ -313,20 +307,18 @@ class App(EmbedChain):
             temperature=0,
             max_tokens=1000,
             top_p=1,
-            stream=stream_response
+            stream=config.stream
         )
-        if stream_response:
-            # This contains the entire completions object. Needs to be sanitised
-            return response
+        if config.stream:
+            return self._stream_llm_model_response(response)
         else:
             return response["choices"][0]["message"]["content"]
 
-    def _stream_llm_model_response(self, prompt):
+    def _stream_llm_model_response(self, response):
         """
         This is a generator for streaming response from the OpenAI completions API
         """
-        response = self._get_llm_model_response(prompt, True)
         for line in response:
             chunk = line['choices'][0].get('delta', {}).get('content', '')
             yield chunk
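
For reference, the rewritten _stream_llm_model_response generator relies on the chunk shape of what appears to be the pre-1.0 openai-python ChatCompletion streaming API: with stream=True each chunk carries choices[0]['delta'], which typically holds only the role in the first chunk and no content in the final chunk, hence the defensive .get() chain. A standalone sketch of the same consumption pattern (illustrative only; the model name is an assumption, not taken from the diff):

# Illustrative sketch using the pre-1.0 openai-python ChatCompletion API.
import openai

response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",  # assumed model name for illustration
    messages=[{"role": "user", "content": "Say hello"}],
    stream=True,
)
for line in response:
    # Role-only and finish chunks have no 'content' key in the delta,
    # so fall back to an empty string instead of raising KeyError.
    chunk = line["choices"][0].get("delta", {}).get("content", "")
    print(chunk, end="", flush=True)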