[Feature] Add support for custom streaming callback (#971)
This commit is contained in:
@@ -10,7 +10,7 @@ from embedchain.embedchain import EmbedChain
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.embedder.openai import OpenAIEmbedder
|
||||
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.llm.openai import OpenAILlm
|
||||
from embedchain.utils import validate_yaml_config
|
||||
|
||||
@@ -3,8 +3,8 @@ from typing import Any
|
||||
from embedchain import Pipeline as App
|
||||
from embedchain.config import AddConfig, BaseLlmConfig, PipelineConfig
|
||||
from embedchain.embedder.openai import OpenAIEmbedder
|
||||
from embedchain.helper.json_serializable import (JSONSerializable,
|
||||
register_deserializable)
|
||||
from embedchain.helpers.json_serializable import (JSONSerializable,
|
||||
register_deserializable)
|
||||
from embedchain.llm.openai import OpenAILlm
|
||||
from embedchain.vectordb.chroma import ChromaDB
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import argparse
|
||||
import logging
|
||||
import os
|
||||
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
from .base import BaseBot
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
from .base import BaseBot
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import signal
|
||||
import sys
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
from .base import BaseBot
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import logging
|
||||
import signal
|
||||
import sys
|
||||
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
from .base import BaseBot
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import hashlib
|
||||
|
||||
from embedchain.helper.json_serializable import JSONSerializable
|
||||
from embedchain.helpers.json_serializable import JSONSerializable
|
||||
from embedchain.models.data_type import DataType
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.add_config import ChunkerConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -3,7 +3,7 @@ from importlib import import_module
|
||||
from typing import Callable, Optional
|
||||
|
||||
from embedchain.config.base_config import BaseConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
from .base_app_config import BaseAppConfig
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.base_config import BaseConfig
|
||||
from embedchain.helper.json_serializable import JSONSerializable
|
||||
from embedchain.helpers.json_serializable import JSONSerializable
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from embedchain.helper.json_serializable import JSONSerializable
|
||||
from embedchain.helpers.json_serializable import JSONSerializable
|
||||
|
||||
|
||||
class BaseConfig(JSONSerializable):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import re
|
||||
from string import Template
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from embedchain.config.base_config import BaseConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
DEFAULT_PROMPT = """
|
||||
Use the following pieces of context to answer the query at the end.
|
||||
@@ -68,6 +68,7 @@ class BaseLlmConfig(BaseConfig):
|
||||
system_prompt: Optional[str] = None,
|
||||
where: Dict[str, Any] = None,
|
||||
query_type: Optional[str] = None,
|
||||
callbacks: Optional[List] = None,
|
||||
):
|
||||
"""
|
||||
Initializes a configuration class instance for the LLM.
|
||||
@@ -98,6 +99,8 @@ class BaseLlmConfig(BaseConfig):
|
||||
:type system_prompt: Optional[str], optional
|
||||
:param where: A dictionary of key-value pairs to filter the database results., defaults to None
|
||||
:type where: Dict[str, Any], optional
|
||||
:param callbacks: Langchain callback functions to use, defaults to None
|
||||
:type callbacks: Optional[List], optional
|
||||
:raises ValueError: If the template is not valid as template should
|
||||
contain $context and $query (and optionally $history)
|
||||
:raises ValueError: Stream is not boolean
|
||||
@@ -113,6 +116,7 @@ class BaseLlmConfig(BaseConfig):
|
||||
self.deployment_name = deployment_name
|
||||
self.system_prompt = system_prompt
|
||||
self.query_type = query_type
|
||||
self.callbacks = callbacks
|
||||
|
||||
if type(template) is str:
|
||||
template = Template(template)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
from .apps.base_app_config import BaseAppConfig
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Any, Dict
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config import AddConfig
|
||||
from embedchain.config.add_config import ChunkerConfig, LoaderConfig
|
||||
from embedchain.helper.json_serializable import JSONSerializable
|
||||
from embedchain.helpers.json_serializable import JSONSerializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.models.data_type import DataType
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from embedchain.config.apps.base_app_config import BaseAppConfig
|
||||
from embedchain.constants import SQLITE_PATH
|
||||
from embedchain.data_formatter import DataFormatter
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.helper.json_serializable import JSONSerializable
|
||||
from embedchain.helpers.json_serializable import JSONSerializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.models.data_type import (DataType, DirectDataType,
|
||||
|
||||
@@ -3,12 +3,12 @@ from typing import Any, Callable, Optional
|
||||
from embedchain.config.embedder.base import BaseEmbedderConfig
|
||||
|
||||
try:
|
||||
from chromadb.api.types import Embeddings, Embeddable, EmbeddingFunction
|
||||
from chromadb.api.types import Embeddable, EmbeddingFunction, Embeddings
|
||||
except RuntimeError:
|
||||
from embedchain.utils import use_pysqlite3
|
||||
|
||||
use_pysqlite3()
|
||||
from chromadb.api.types import Embeddings, Embeddable, EmbeddingFunction
|
||||
from chromadb.api.types import Embeddable, EmbeddingFunction, Embeddings
|
||||
|
||||
|
||||
class EmbeddingFunc(EmbeddingFunction):
|
||||
|
||||
0
embedchain/helpers/__init__.py
Normal file
0
embedchain/helpers/__init__.py
Normal file
73
embedchain/helpers/callbacks.py
Normal file
73
embedchain/helpers/callbacks.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import queue
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
STOP_ITEM = "[END]"
|
||||
"""
|
||||
This is a special item that is used to signal the end of the stream.
|
||||
"""
|
||||
|
||||
|
||||
class StreamingStdOutCallbackHandlerYield(StreamingStdOutCallbackHandler):
|
||||
"""
|
||||
This is a callback handler that yields the tokens as they are generated.
|
||||
For a usage example, see the :func:`generate` function below.
|
||||
"""
|
||||
|
||||
q: queue.Queue
|
||||
"""
|
||||
The queue to write the tokens to as they are generated.
|
||||
"""
|
||||
|
||||
def __init__(self, q: queue.Queue) -> None:
|
||||
"""
|
||||
Initialize the callback handler.
|
||||
q: The queue to write the tokens to as they are generated.
|
||||
"""
|
||||
super().__init__()
|
||||
self.q = q
|
||||
|
||||
def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
with self.q.mutex:
|
||||
self.q.queue.clear()
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
self.q.put(token)
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running."""
|
||||
self.q.put(STOP_ITEM)
|
||||
|
||||
def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
|
||||
"""Run when LLM errors."""
|
||||
self.q.put("%s: %s" % (type(error).__name__, str(error)))
|
||||
self.q.put(STOP_ITEM)
|
||||
|
||||
|
||||
def generate(rq: queue.Queue):
|
||||
"""
|
||||
This is a generator that yields the items in the queue until it reaches the stop item.
|
||||
|
||||
Usage example:
|
||||
```
|
||||
def askQuestion(callback_fn: StreamingStdOutCallbackHandlerYield):
|
||||
llm = OpenAI(streaming=True, callbacks=[callback_fn])
|
||||
return llm(prompt="Write a poem about a tree.")
|
||||
|
||||
@app.route("/", methods=["GET"])
|
||||
def generate_output():
|
||||
q = Queue()
|
||||
callback_fn = StreamingStdOutCallbackHandlerYield(q)
|
||||
threading.Thread(target=askQuestion, args=(callback_fn,)).start()
|
||||
return Response(generate(q), mimetype="text/event-stream")
|
||||
```
|
||||
"""
|
||||
while True:
|
||||
result: str = rq.get()
|
||||
if result == STOP_ITEM or result is None:
|
||||
break
|
||||
yield result
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from embedchain.config import BaseLlmConfig
|
||||
from embedchain.config.llm.base import (DEFAULT_PROMPT,
|
||||
DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
|
||||
DOCS_SITE_PROMPT_TEMPLATE)
|
||||
from embedchain.helper.json_serializable import JSONSerializable
|
||||
from embedchain.helpers.json_serializable import JSONSerializable
|
||||
from embedchain.memory.base import ECChatMemory
|
||||
from embedchain.memory.message import ChatMessage
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Optional
|
||||
from langchain.llms import Cohere
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Optional
|
||||
from langchain.llms import HuggingFaceHub
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from langchain.chat_models import JinaChat
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Optional
|
||||
from langchain.llms import Replicate
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.chat_models import ChatOpenAI
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
|
||||
@@ -34,7 +34,8 @@ class OpenAILlm(BaseLlm):
|
||||
from langchain.callbacks.streaming_stdout import \
|
||||
StreamingStdOutCallbackHandler
|
||||
|
||||
chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()])
|
||||
callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
|
||||
chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks)
|
||||
else:
|
||||
chat = ChatOpenAI(**kwargs)
|
||||
return chat(messages).content
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config import BaseLlmConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from embedchain.helper.json_serializable import JSONSerializable
|
||||
from embedchain.helpers.json_serializable import JSONSerializable
|
||||
|
||||
|
||||
class BaseLoader(JSONSerializable):
|
||||
|
||||
@@ -12,7 +12,7 @@ except ImportError:
|
||||
) from None
|
||||
|
||||
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ except ImportError:
|
||||
raise ImportError(
|
||||
'Docx file requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
|
||||
) from None
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import hashlib
|
||||
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import hashlib
|
||||
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import hashlib
|
||||
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ except ImportError:
|
||||
) from None
|
||||
|
||||
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.utils import clean_string
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ except ImportError:
|
||||
raise ImportError(
|
||||
'PDF File requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
|
||||
) from None
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.utils import clean_string
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ except ImportError:
|
||||
'Sitemap requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
|
||||
) from None
|
||||
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.loaders.web_page import WebPageLoader
|
||||
from embedchain.utils import is_readable
|
||||
|
||||
@@ -4,7 +4,7 @@ import time
|
||||
|
||||
import requests
|
||||
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.utils import is_readable
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import hashlib
|
||||
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.utils import clean_string
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ except ImportError:
|
||||
'Webpage requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
|
||||
) from None
|
||||
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.utils import clean_string
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ except ImportError:
|
||||
raise ImportError(
|
||||
'XML file requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
|
||||
) from None
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.utils import clean_string
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ except ImportError:
|
||||
raise ImportError(
|
||||
'YouTube video requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
|
||||
) from None
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.utils import clean_string
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from embedchain.helper.json_serializable import JSONSerializable
|
||||
from embedchain.helpers.json_serializable import JSONSerializable
|
||||
|
||||
|
||||
class BaseMessage(JSONSerializable):
|
||||
|
||||
@@ -15,7 +15,7 @@ from embedchain.embedchain import EmbedChain
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.embedder.openai import OpenAIEmbedder
|
||||
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.llm.base import BaseLlm
|
||||
from embedchain.llm.openai import OpenAILlm
|
||||
from embedchain.telemetry.posthog import AnonymousTelemetry
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from embedchain.config.vectordb.base import BaseVectorDbConfig
|
||||
from embedchain.embedder.base import BaseEmbedder
|
||||
from embedchain.helper.json_serializable import JSONSerializable
|
||||
from embedchain.helpers.json_serializable import JSONSerializable
|
||||
|
||||
|
||||
class BaseVectorDB(JSONSerializable):
|
||||
|
||||
@@ -6,7 +6,7 @@ from langchain.docstore.document import Document
|
||||
from tqdm import tqdm
|
||||
|
||||
from embedchain.config import ChromaDbConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
try:
|
||||
|
||||
@@ -10,7 +10,7 @@ except ImportError:
|
||||
) from None
|
||||
|
||||
from embedchain.config import ElasticsearchDBConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain.vectorstores import OpenSearchVectorSearch
|
||||
|
||||
from embedchain.config import OpenSearchDBConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ except ImportError:
|
||||
) from None
|
||||
|
||||
from embedchain.config.vectordb.pinecone import PineconeDBConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ except ImportError:
|
||||
) from None
|
||||
|
||||
from embedchain.config.vectordb.weaviate import WeaviateDBConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from embedchain.config import ZillizDBConfig
|
||||
from embedchain.helper.json_serializable import register_deserializable
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
from embedchain.vectordb.base import BaseVectorDB
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "embedchain"
|
||||
version = "0.1.18"
|
||||
version = "0.1.19"
|
||||
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
|
||||
authors = [
|
||||
"Taranjeet Singh <taranjeet@embedchain.ai>",
|
||||
|
||||
@@ -4,8 +4,8 @@ from string import Template
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig, BaseLlmConfig
|
||||
from embedchain.helper.json_serializable import (JSONSerializable,
|
||||
register_deserializable)
|
||||
from embedchain.helpers.json_serializable import (JSONSerializable,
|
||||
register_deserializable)
|
||||
|
||||
|
||||
class TestJsonSerializable(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user