feat: Use streaming setup at Query level (#214)
@@ -4,5 +4,11 @@ class ChatConfig(QueryConfig):
     """
     Config for the `chat` method, inherits from `QueryConfig`.
     """
-    def __init__(self):
-        pass
+    def __init__(self, stream: bool = False):
+        """
+        Initializes the QueryConfig instance.
+
+        :param stream: Optional. Control if response is streamed back to the user
+        :raises ValueError: If the template is not valid as template should contain $context and $query
+        """
+        super().__init__(stream=stream)
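With this change, ChatConfig no longer has an empty constructor; it accepts the stream flag and forwards it to QueryConfig. A minimal usage sketch, assuming the config classes are importable from embedchain.config (the import path is not shown in this diff):

    # Hypothetical usage; the import path is an assumption, not shown in this diff.
    from embedchain.config import ChatConfig

    default_config = ChatConfig()               # stream defaults to False
    streaming_config = ChatConfig(stream=True)  # forwarded to QueryConfig.__init__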
@@ -6,7 +6,7 @@ class InitConfig(BaseConfig):
     """
     Config to initialize an embedchain `App` instance.
     """
-    def __init__(self, ef=None, db=None, stream_response=False):
+    def __init__(self, ef=None, db=None):
         """
         :param ef: Optional. Embedding function to use.
         :param db: Optional. (Vector) database to use for embeddings.
@@ -27,10 +27,6 @@ class InitConfig(BaseConfig):
             self.db = ChromaDB(ef=self.ef)
         else:
             self.db = db
 
-        if not isinstance(stream_response, bool):
-            raise ValueError("`stream_respone` should be bool")
-        self.stream_response = stream_response
-
         return
 
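Removing stream_response means streaming is no longer decided when the app is constructed; InitConfig only wires the embedding function and the vector database (ChromaDB by default, per the hunk above). A sketch, assuming App accepts the config object the way its __init__ hunk below suggests:

    # Hypothetical sketch; import paths and the App(config) signature are assumptions.
    from embedchain import App
    from embedchain.config import InitConfig

    app = App(InitConfig(ef=None, db=None))
    # Streaming is now requested per query()/chat() call, not here.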
@@ -22,11 +22,12 @@ class QueryConfig(BaseConfig):
     """
     Config for the `query` method.
     """
-    def __init__(self, template: Template = None):
+    def __init__(self, template: Template = None, stream: bool = False):
         """
         Initializes the QueryConfig instance.
 
         :param template: Optional. The `Template` instance to use as a template for prompt.
+        :param stream: Optional. Control if response is streamed back to the user
         :raises ValueError: If the template is not valid as template should contain $context and $query
         """
         if template is None:
@@ -35,3 +36,7 @@ class QueryConfig(BaseConfig):
                 and re.search(context_re, template.template)):
             raise ValueError("`template` should have `query` and `context` keys")
         self.template = template
+
+        if not isinstance(stream, bool):
+            raise ValueError("`stream` should be bool")
+        self.stream = stream
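The new flag is validated like the template: anything other than a bool raises a ValueError. A short sketch (import path assumed):

    # Hypothetical sketch of the validation added above.
    from embedchain.config import QueryConfig

    QueryConfig(stream=True)       # ok: self.stream = True
    QueryConfig()                  # ok: defaults to stream=False

    try:
        QueryConfig(stream="yes")  # not a bool
    except ValueError as err:
        print(err)                 # `stream` should be bool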
@@ -155,7 +155,7 @@ class EmbedChain:
         prompt = template.substitute(context = context, query = input_query)
         return prompt
 
-    def get_answer_from_llm(self, prompt):
+    def get_answer_from_llm(self, prompt, config: ChatConfig):
         """
         Gets an answer based on the given query and context by passing it
         to an LLM.
@@ -165,7 +165,7 @@ class EmbedChain:
         :return: The answer.
         """
 
-        return self.get_llm_model_answer(prompt)
+        return self.get_llm_model_answer(prompt, config)
 
     def query(self, input_query, config: QueryConfig = None):
         """
@@ -181,7 +181,7 @@ class EmbedChain:
             config = QueryConfig()
         context = self.retrieve_from_database(input_query)
         prompt = self.generate_prompt(input_query, context, config.template)
-        answer = self.get_answer_from_llm(prompt)
+        answer = self.get_answer_from_llm(prompt, config)
         return answer
 
     def generate_chat_prompt(self, input_query, context, chat_history=''):
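Since query() now hands its config through to get_answer_from_llm(), the caller chooses per request whether the answer comes back as a string or as a stream. A usage sketch; data ingestion is elided and the import paths are assumptions:

    # Hypothetical end-to-end sketch of query-level streaming.
    from embedchain import App
    from embedchain.config import QueryConfig

    app = App()
    # ... ingest documents with app.add(...) before querying ...

    answer = app.query("What is embedchain?")  # plain string
    stream = app.query("What is embedchain?", QueryConfig(stream=True))
    for chunk in stream:                       # generator of text chunks
        print(chunk, end="", flush=True)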
@@ -224,7 +224,7 @@ class EmbedChain:
             context,
             chat_history=chat_history,
         )
-        answer = self.get_answer_from_llm(prompt)
+        answer = self.get_answer_from_llm(prompt, config)
         memory.chat_memory.add_user_message(input_query)
         if isinstance(answer, str):
             memory.chat_memory.add_ai_message(answer)
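The chat path gets the same wiring, so a ChatConfig with stream=True yields chunks as they arrive, while the isinstance(answer, str) guard above keeps generator answers out of the chat memory. A sketch, assuming chat() accepts a config argument the same way query() does (its signature is not part of this diff):

    # Hypothetical sketch; chat()'s signature is assumed from the query() change.
    from embedchain import App
    from embedchain.config import ChatConfig

    app = App()
    for chunk in app.chat("Summarize what you know so far.", ChatConfig(stream=True)):
        print(chunk, end="", flush=True)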
@@ -295,14 +295,8 @@ class App(EmbedChain):
             config = InitConfig()
         super().__init__(config)
 
-    def get_llm_model_answer(self, prompt):
-        stream_response = self.config.stream_response
-        if stream_response:
-            return self._stream_llm_model_response(prompt)
-        else:
-            return self._get_llm_model_response(prompt)
+    def get_llm_model_answer(self, prompt, config: ChatConfig):
 
-    def _get_llm_model_response(self, prompt, stream_response = False):
         messages = []
         messages.append({
             "role": "user", "content": prompt
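The body that remains in the merged get_llm_model_answer (mostly outside this hunk) builds the messages list and passes config.stream straight into the OpenAI completions call. A hedged sketch of that call shape against the legacy openai<1.0 SDK; the model name and everything outside the visible keyword arguments are assumptions:

    # Hypothetical sketch of where config.stream ends up (openai<1.0 SDK).
    import openai

    def llm_call_sketch(prompt: str, stream: bool = False):
        messages = [{"role": "user", "content": prompt}]
        return openai.ChatCompletion.create(
            model="gpt-3.5-turbo",  # assumed; the model is not shown in this diff
            messages=messages,
            temperature=0,
            max_tokens=1000,
            top_p=1,
            stream=stream,          # receives config.stream
        )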
@@ -313,20 +307,18 @@ class App(EmbedChain):
             temperature=0,
             max_tokens=1000,
             top_p=1,
-            stream=stream_response
+            stream=config.stream
         )
 
-        if stream_response:
-            # This contains the entire completions object. Needs to be sanitised
-            return response
+        if config.stream:
+            return self._stream_llm_model_response(response)
         else:
             return response["choices"][0]["message"]["content"]
 
-    def _stream_llm_model_response(self, prompt):
+    def _stream_llm_model_response(self, response):
         """
         This is a generator for streaming response from the OpenAI completions API
         """
-        response = self._get_llm_model_response(prompt, True)
         for line in response:
             chunk = line['choices'][0].get('delta', {}).get('content', '')
             yield chunk
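The generator itself stays thin: it walks an already-created streaming response and yields only the text deltas. A self-contained sketch of the same pattern against a faked OpenAI-style chunk stream (no API call needed):

    # Standalone illustration of the chunk-extraction pattern used above; the fake
    # response mimics the shape of OpenAI ChatCompletion streaming chunks.
    def stream_llm_model_response(response):
        for line in response:
            chunk = line['choices'][0].get('delta', {}).get('content', '')
            yield chunk

    fake_response = [
        {'choices': [{'delta': {'role': 'assistant'}}]},  # first chunk carries no content
        {'choices': [{'delta': {'content': 'Hello'}}]},
        {'choices': [{'delta': {'content': ' world'}}]},
        {'choices': [{'delta': {}}]},                     # final chunk
    ]

    print(''.join(stream_llm_model_response(fake_response)))  # -> Hello world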