Files
t6_mem0/embedchain/helpers/callbacks.py
2024-02-27 15:10:41 -08:00

74 lines
2.3 KiB
Python

import queue
from typing import Any, Union
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.schema import LLMResult
# Sentinel value shared by the producer and consumer sides of the token queue:
# the callback handler puts it on the queue from on_llm_end / on_llm_error, and
# generate() stops iterating when it reads it back.
# NOTE(review): the sentinel is an ordinary string compared by equality, so a
# model token that is exactly "[END]" would also terminate the stream.
STOP_ITEM = "[END]"
"""
This is a special item that is used to signal the end of the stream.
"""
class StreamingStdOutCallbackHandlerYield(StreamingStdOutCallbackHandler):
    """
    Callback handler that pushes LLM tokens onto a queue as they are generated.

    Queue protocol:
      * on start, anything still sitting in the queue is discarded;
      * every new token is put on the queue unchanged;
      * when the run ends, ``STOP_ITEM`` is put on the queue;
      * when the run fails, a ``"<ExceptionName>: <message>"`` item is put on the
        queue, followed by ``STOP_ITEM``.

    For a usage example, see the :func:`generate` function below.
    """

    # The queue to write the tokens to as they are generated.
    q: queue.Queue

    def __init__(self, q: queue.Queue) -> None:
        """
        Initialize the callback handler.

        q: The queue to write the tokens to as they are generated.
        """
        super().__init__()
        self.q = q

    def on_llm_start(self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any) -> None:
        """Run when LLM starts running. Drops any items left over in the queue."""
        # queue.Queue has no public clear(); empty its underlying container while
        # holding the queue's own lock so concurrent put()/get() calls stay consistent.
        # NOTE(review): this bypasses the counter behind Queue.join()/task_done();
        # confirm that no consumer of this queue relies on them.
        with self.q.mutex:
            self.q.queue.clear()

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        self.q.put(token)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running. Signals the end of the stream."""
        self.q.put(STOP_ITEM)

    def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
        """Run when LLM errors. Queues a description of the error, then ends the stream."""
        # The error is delivered in-band as a regular item so the consumer sees
        # why the stream stopped; the stop item that follows ends the stream.
        self.q.put(f"{type(error).__name__}: {error!s}")
        self.q.put(STOP_ITEM)
def generate(rq: queue.Queue):
    """
    Yield items taken from ``rq`` until the end of the stream is reached.

    Each call to ``rq.get()`` blocks until an item is available. Iteration ends,
    without yielding that item, as soon as the item equals ``STOP_ITEM`` or is
    ``None``.

    Usage example:
    ```
    def askQuestion(callback_fn: StreamingStdOutCallbackHandlerYield):
        llm = OpenAI(streaming=True, callbacks=[callback_fn])
        return llm.invoke(prompt="Write a poem about a tree.")

    @app.route("/", methods=["GET"])
    def generate_output():
        q = Queue()
        callback_fn = StreamingStdOutCallbackHandlerYield(q)
        threading.Thread(target=askQuestion, args=(callback_fn,)).start()
        return Response(generate(q), mimetype="text/event-stream")
    ```
    """
    # Fetch-and-test in one expression: keep going while the freshly fetched
    # item is neither the stop item nor None.
    while not ((item := rq.get()) == STOP_ITEM or item is None):
        yield item