[Feature] Add support for custom streaming callback (#971)
This commit is contained in:
@@ -3,7 +3,7 @@ import os
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from embedchain.config import BaseLlmConfig
|
||||
from embedchain.config.llm.base import (DEFAULT_PROMPT,
|
||||
DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
|
||||
DOCS_SITE_PROMPT_TEMPLATE)
|
||||
from embedchain.helper.json_serializable import JSONSerializable
|
||||
from embedchain.helpers.json_serializable import JSONSerializable
|
||||
from embedchain.memory.base import ECChatMemory
|
||||
from embedchain.memory.message import ChatMessage
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Optional
|
||||
from langchain.llms import Cohere
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Optional
|
||||
from langchain.llms import HuggingFaceHub
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from langchain.chat_models import JinaChat
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Optional
|
||||
from langchain.llms import Replicate
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.chat_models import ChatOpenAI
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
|
||||
@@ -34,7 +34,8 @@ class OpenAILlm(BaseLlm):
|
||||
from langchain.callbacks.streaming_stdout import \
|
||||
StreamingStdOutCallbackHandler
|
||||
|
||||
chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
|
||||
callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
|
||||
chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks)
|
||||
else:
|
||||
chat = ChatOpenAI(**kwargs)
|
||||
return chat(messages).content
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user