[Feature] Add support for custom streaming callback (#971)

This commit is contained in:
Deshraj Yadav
2023-11-22 01:06:33 -08:00
committed by GitHub
parent 798d3fcc5a
commit f6b80e01a1
82 changed files with 162 additions and 84 deletions

View File

@@ -3,7 +3,7 @@ import os
from typing import Optional
from embedchain.config import BaseLlmConfig
from embedchain.helper.json_serializable import register_deserializable
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm

View File

@@ -2,7 +2,7 @@ import logging
from typing import Optional
from embedchain.config import BaseLlmConfig
from embedchain.helper.json_serializable import register_deserializable
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm

View File

@@ -7,7 +7,7 @@ from embedchain.config import BaseLlmConfig
from embedchain.config.llm.base import (DEFAULT_PROMPT,
DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
DOCS_SITE_PROMPT_TEMPLATE)
from embedchain.helper.json_serializable import JSONSerializable
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.memory.base import ECChatMemory
from embedchain.memory.message import ChatMessage

View File

@@ -5,7 +5,7 @@ from typing import Optional
from langchain.llms import Cohere
from embedchain.config import BaseLlmConfig
from embedchain.helper.json_serializable import register_deserializable
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm

View File

@@ -4,7 +4,7 @@ from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from embedchain.config import BaseLlmConfig
from embedchain.helper.json_serializable import register_deserializable
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm

View File

@@ -5,7 +5,7 @@ from typing import Optional
from langchain.llms import HuggingFaceHub
from embedchain.config import BaseLlmConfig
from embedchain.helper.json_serializable import register_deserializable
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm

View File

@@ -5,7 +5,7 @@ from langchain.chat_models import JinaChat
from langchain.schema import HumanMessage, SystemMessage
from embedchain.config import BaseLlmConfig
from embedchain.helper.json_serializable import register_deserializable
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm

View File

@@ -5,7 +5,7 @@ from typing import Optional
from langchain.llms import Replicate
from embedchain.config import BaseLlmConfig
from embedchain.helper.json_serializable import register_deserializable
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm

View File

@@ -4,7 +4,7 @@ from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, SystemMessage
from embedchain.config import BaseLlmConfig
from embedchain.helper.json_serializable import register_deserializable
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm
@@ -34,7 +34,8 @@ class OpenAILlm(BaseLlm):
from langchain.callbacks.streaming_stdout import \
StreamingStdOutCallbackHandler
chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks)
else:
chat = ChatOpenAI(**kwargs)
return chat(messages).content

View File

@@ -3,7 +3,7 @@ import logging
from typing import Optional
from embedchain.config import BaseLlmConfig
from embedchain.helper.json_serializable import register_deserializable
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm