[Feature] Add support for custom streaming callback (#971)

This commit is contained in:
Deshraj Yadav
2023-11-22 01:06:33 -08:00
committed by GitHub
parent 798d3fcc5a
commit f6b80e01a1
82 changed files with 162 additions and 84 deletions

View File

@@ -10,7 +10,7 @@ from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm from embedchain.llm.base import BaseLlm
from embedchain.llm.openai import OpenAILlm from embedchain.llm.openai import OpenAILlm
from embedchain.utils import validate_yaml_config from embedchain.utils import validate_yaml_config

View File

@@ -3,8 +3,8 @@ from typing import Any
from embedchain import Pipeline as App from embedchain import Pipeline as App
from embedchain.config import AddConfig, BaseLlmConfig, PipelineConfig from embedchain.config import AddConfig, BaseLlmConfig, PipelineConfig
from embedchain.embedder.openai import OpenAIEmbedder from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.helper.json_serializable import (JSONSerializable, from embedchain.helpers.json_serializable import (JSONSerializable,
register_deserializable) register_deserializable)
from embedchain.llm.openai import OpenAILlm from embedchain.llm.openai import OpenAILlm
from embedchain.vectordb.chroma import ChromaDB from embedchain.vectordb.chroma import ChromaDB

View File

@@ -2,7 +2,7 @@ import argparse
import logging import logging
import os import os
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from .base import BaseBot from .base import BaseBot

View File

@@ -3,7 +3,7 @@ import logging
import os import os
from typing import List, Optional from typing import List, Optional
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from .base import BaseBot from .base import BaseBot

View File

@@ -5,7 +5,7 @@ import signal
import sys import sys
from embedchain import App from embedchain import App
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from .base import BaseBot from .base import BaseBot

View File

@@ -4,7 +4,7 @@ import logging
import signal import signal
import sys import sys
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from .base import BaseBot from .base import BaseBot

View File

@@ -1,6 +1,6 @@
import hashlib import hashlib
from embedchain.helper.json_serializable import JSONSerializable from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.models.data_type import DataType from embedchain.models.data_type import DataType

View File

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -4,7 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.add_config import ChunkerConfig from embedchain.config.add_config import ChunkerConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -3,7 +3,7 @@ from importlib import import_module
from typing import Callable, Optional from typing import Callable, Optional
from embedchain.config.base_config import BaseConfig from embedchain.config.base_config import BaseConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from .base_app_config import BaseAppConfig from .base_app_config import BaseAppConfig

View File

@@ -2,7 +2,7 @@ import logging
from typing import Optional from typing import Optional
from embedchain.config.base_config import BaseConfig from embedchain.config.base_config import BaseConfig
from embedchain.helper.json_serializable import JSONSerializable from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.vectordb.base import BaseVectorDB from embedchain.vectordb.base import BaseVectorDB

View File

@@ -1,6 +1,6 @@
from typing import Any, Dict from typing import Any, Dict
from embedchain.helper.json_serializable import JSONSerializable from embedchain.helpers.json_serializable import JSONSerializable
class BaseConfig(JSONSerializable): class BaseConfig(JSONSerializable):

View File

@@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -1,9 +1,9 @@
import re import re
from string import Template from string import Template
from typing import Any, Dict, Optional from typing import Any, Dict, List, Optional
from embedchain.config.base_config import BaseConfig from embedchain.config.base_config import BaseConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
DEFAULT_PROMPT = """ DEFAULT_PROMPT = """
Use the following pieces of context to answer the query at the end. Use the following pieces of context to answer the query at the end.
@@ -68,6 +68,7 @@ class BaseLlmConfig(BaseConfig):
system_prompt: Optional[str] = None, system_prompt: Optional[str] = None,
where: Dict[str, Any] = None, where: Dict[str, Any] = None,
query_type: Optional[str] = None, query_type: Optional[str] = None,
callbacks: Optional[List] = None,
): ):
""" """
Initializes a configuration class instance for the LLM. Initializes a configuration class instance for the LLM.
@@ -98,6 +99,8 @@ class BaseLlmConfig(BaseConfig):
:type system_prompt: Optional[str], optional :type system_prompt: Optional[str], optional
:param where: A dictionary of key-value pairs to filter the database results., defaults to None :param where: A dictionary of key-value pairs to filter the database results., defaults to None
:type where: Dict[str, Any], optional :type where: Dict[str, Any], optional
:param callbacks: Langchain callback functions to use, defaults to None
:type callbacks: Optional[List], optional
:raises ValueError: If the template is not valid as template should :raises ValueError: If the template is not valid as template should
contain $context and $query (and optionally $history) contain $context and $query (and optionally $history)
:raises ValueError: Stream is not boolean :raises ValueError: Stream is not boolean
@@ -113,6 +116,7 @@ class BaseLlmConfig(BaseConfig):
self.deployment_name = deployment_name self.deployment_name = deployment_name
self.system_prompt = system_prompt self.system_prompt = system_prompt
self.query_type = query_type self.query_type = query_type
self.callbacks = callbacks
if type(template) is str: if type(template) is str:
template = Template(template) template = Template(template)

View File

@@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from .apps.base_app_config import BaseAppConfig from .apps.base_app_config import BaseAppConfig

View File

@@ -1,7 +1,7 @@
from typing import Optional from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -2,7 +2,7 @@ import os
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -1,7 +1,7 @@
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -1,7 +1,7 @@
from typing import Dict, Optional from typing import Dict, Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -1,7 +1,7 @@
from typing import Dict, Optional from typing import Dict, Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -1,7 +1,7 @@
from typing import Dict, Optional from typing import Dict, Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -2,7 +2,7 @@ import os
from typing import Optional from typing import Optional
from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
@register_deserializable @register_deserializable

View File

@@ -4,7 +4,7 @@ from typing import Any, Dict
from embedchain.chunkers.base_chunker import BaseChunker from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config import AddConfig from embedchain.config import AddConfig
from embedchain.config.add_config import ChunkerConfig, LoaderConfig from embedchain.config.add_config import ChunkerConfig, LoaderConfig
from embedchain.helper.json_serializable import JSONSerializable from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.base_loader import BaseLoader
from embedchain.models.data_type import DataType from embedchain.models.data_type import DataType

View File

@@ -13,7 +13,7 @@ from embedchain.config.apps.base_app_config import BaseAppConfig
from embedchain.constants import SQLITE_PATH from embedchain.constants import SQLITE_PATH
from embedchain.data_formatter import DataFormatter from embedchain.data_formatter import DataFormatter
from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.base import BaseEmbedder
from embedchain.helper.json_serializable import JSONSerializable from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.llm.base import BaseLlm from embedchain.llm.base import BaseLlm
from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.base_loader import BaseLoader
from embedchain.models.data_type import (DataType, DirectDataType, from embedchain.models.data_type import (DataType, DirectDataType,

View File

@@ -3,12 +3,12 @@ from typing import Any, Callable, Optional
from embedchain.config.embedder.base import BaseEmbedderConfig from embedchain.config.embedder.base import BaseEmbedderConfig
try: try:
from chromadb.api.types import Embeddings, Embeddable, EmbeddingFunction from chromadb.api.types import Embeddable, EmbeddingFunction, Embeddings
except RuntimeError: except RuntimeError:
from embedchain.utils import use_pysqlite3 from embedchain.utils import use_pysqlite3
use_pysqlite3() use_pysqlite3()
from chromadb.api.types import Embeddings, Embeddable, EmbeddingFunction from chromadb.api.types import Embeddable, EmbeddingFunction, Embeddings
class EmbeddingFunc(EmbeddingFunction): class EmbeddingFunc(EmbeddingFunction):

View File

View File

@@ -0,0 +1,73 @@
import queue
from typing import Any, Dict, List, Union
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.schema import LLMResult
STOP_ITEM = "[END]"
"""
This is a special item that is used to signal the end of the stream.
"""
class StreamingStdOutCallbackHandlerYield(StreamingStdOutCallbackHandler):
"""
This is a callback handler that yields the tokens as they are generated.
For a usage example, see the :func:`generate` function below.
"""
q: queue.Queue
"""
The queue to write the tokens to as they are generated.
"""
def __init__(self, q: queue.Queue) -> None:
"""
Initialize the callback handler.
q: The queue to write the tokens to as they are generated.
"""
super().__init__()
self.q = q
def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
"""Run when LLM starts running."""
with self.q.mutex:
self.q.queue.clear()
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
self.q.put(token)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
self.q.put(STOP_ITEM)
def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
"""Run when LLM errors."""
self.q.put("%s: %s" % (type(error).__name__, str(error)))
self.q.put(STOP_ITEM)
def generate(rq: queue.Queue):
"""
This is a generator that yields the items in the queue until it reaches the stop item.
Usage example:
```
def askQuestion(callback_fn: StreamingStdOutCallbackHandlerYield):
llm = OpenAI(streaming=True, callbacks=[callback_fn])
return llm(prompt="Write a poem about a tree.")
@app.route("/", methods=["GET"])
def generate_output():
q = Queue()
callback_fn = StreamingStdOutCallbackHandlerYield(q)
threading.Thread(target=askQuestion, args=(callback_fn,)).start()
return Response(generate(q), mimetype="text/event-stream")
```
"""
while True:
result: str = rq.get()
if result == STOP_ITEM or result is None:
break
yield result

View File

@@ -3,7 +3,7 @@ import os
from typing import Optional from typing import Optional
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm from embedchain.llm.base import BaseLlm

View File

@@ -2,7 +2,7 @@ import logging
from typing import Optional from typing import Optional
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm from embedchain.llm.base import BaseLlm

View File

@@ -7,7 +7,7 @@ from embedchain.config import BaseLlmConfig
from embedchain.config.llm.base import (DEFAULT_PROMPT, from embedchain.config.llm.base import (DEFAULT_PROMPT,
DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
DOCS_SITE_PROMPT_TEMPLATE) DOCS_SITE_PROMPT_TEMPLATE)
from embedchain.helper.json_serializable import JSONSerializable from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.memory.base import ECChatMemory from embedchain.memory.base import ECChatMemory
from embedchain.memory.message import ChatMessage from embedchain.memory.message import ChatMessage

View File

@@ -5,7 +5,7 @@ from typing import Optional
from langchain.llms import Cohere from langchain.llms import Cohere
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm from embedchain.llm.base import BaseLlm

View File

@@ -4,7 +4,7 @@ from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm from embedchain.llm.base import BaseLlm

View File

@@ -5,7 +5,7 @@ from typing import Optional
from langchain.llms import HuggingFaceHub from langchain.llms import HuggingFaceHub
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm from embedchain.llm.base import BaseLlm

View File

@@ -5,7 +5,7 @@ from langchain.chat_models import JinaChat
from langchain.schema import HumanMessage, SystemMessage from langchain.schema import HumanMessage, SystemMessage
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm from embedchain.llm.base import BaseLlm

View File

@@ -5,7 +5,7 @@ from typing import Optional
from langchain.llms import Replicate from langchain.llms import Replicate
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm from embedchain.llm.base import BaseLlm

View File

@@ -4,7 +4,7 @@ from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, SystemMessage from langchain.schema import HumanMessage, SystemMessage
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm from embedchain.llm.base import BaseLlm
@@ -34,7 +34,8 @@ class OpenAILlm(BaseLlm):
from langchain.callbacks.streaming_stdout import \ from langchain.callbacks.streaming_stdout import \
StreamingStdOutCallbackHandler StreamingStdOutCallbackHandler
chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=[StreamingStdOutCallbackHandler()]) callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks)
else: else:
chat = ChatOpenAI(**kwargs) chat = ChatOpenAI(**kwargs)
return chat(messages).content return chat(messages).content

View File

@@ -3,7 +3,7 @@ import logging
from typing import Optional from typing import Optional
from embedchain.config import BaseLlmConfig from embedchain.config import BaseLlmConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm from embedchain.llm.base import BaseLlm

View File

@@ -1,4 +1,4 @@
from embedchain.helper.json_serializable import JSONSerializable from embedchain.helpers.json_serializable import JSONSerializable
class BaseLoader(JSONSerializable): class BaseLoader(JSONSerializable):

View File

@@ -12,7 +12,7 @@ except ImportError:
) from None ) from None
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.base_loader import BaseLoader

View File

@@ -6,7 +6,7 @@ except ImportError:
raise ImportError( raise ImportError(
'Docx file requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`' 'Docx file requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
) from None ) from None
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.base_loader import BaseLoader

View File

@@ -1,6 +1,6 @@
import hashlib import hashlib
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.base_loader import BaseLoader

View File

@@ -1,6 +1,6 @@
import hashlib import hashlib
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.base_loader import BaseLoader

View File

@@ -1,6 +1,6 @@
import hashlib import hashlib
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.base_loader import BaseLoader

View File

@@ -10,7 +10,7 @@ except ImportError:
) from None ) from None
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils import clean_string from embedchain.utils import clean_string

View File

@@ -6,7 +6,7 @@ except ImportError:
raise ImportError( raise ImportError(
'PDF File requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`' 'PDF File requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
) from None ) from None
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils import clean_string from embedchain.utils import clean_string

View File

@@ -13,7 +13,7 @@ except ImportError:
'Sitemap requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`' 'Sitemap requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
) from None ) from None
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.base_loader import BaseLoader
from embedchain.loaders.web_page import WebPageLoader from embedchain.loaders.web_page import WebPageLoader
from embedchain.utils import is_readable from embedchain.utils import is_readable

View File

@@ -4,7 +4,7 @@ import time
import requests import requests
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils import is_readable from embedchain.utils import is_readable

View File

@@ -1,6 +1,6 @@
import hashlib import hashlib
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils import clean_string from embedchain.utils import clean_string

View File

@@ -10,7 +10,7 @@ except ImportError:
'Webpage requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`' 'Webpage requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
) from None ) from None
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils import clean_string from embedchain.utils import clean_string

View File

@@ -6,7 +6,7 @@ except ImportError:
raise ImportError( raise ImportError(
'XML file requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`' 'XML file requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
) from None ) from None
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils import clean_string from embedchain.utils import clean_string

View File

@@ -6,7 +6,7 @@ except ImportError:
raise ImportError( raise ImportError(
'YouTube video requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`' 'YouTube video requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
) from None ) from None
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils import clean_string from embedchain.utils import clean_string

View File

@@ -1,7 +1,7 @@
import logging import logging
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from embedchain.helper.json_serializable import JSONSerializable from embedchain.helpers.json_serializable import JSONSerializable
class BaseMessage(JSONSerializable): class BaseMessage(JSONSerializable):

View File

@@ -15,7 +15,7 @@ from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.base import BaseEmbedder
from embedchain.embedder.openai import OpenAIEmbedder from embedchain.embedder.openai import OpenAIEmbedder
from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.llm.base import BaseLlm from embedchain.llm.base import BaseLlm
from embedchain.llm.openai import OpenAILlm from embedchain.llm.openai import OpenAILlm
from embedchain.telemetry.posthog import AnonymousTelemetry from embedchain.telemetry.posthog import AnonymousTelemetry

View File

@@ -1,6 +1,6 @@
from embedchain.config.vectordb.base import BaseVectorDbConfig from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.embedder.base import BaseEmbedder from embedchain.embedder.base import BaseEmbedder
from embedchain.helper.json_serializable import JSONSerializable from embedchain.helpers.json_serializable import JSONSerializable
class BaseVectorDB(JSONSerializable): class BaseVectorDB(JSONSerializable):

View File

@@ -6,7 +6,7 @@ from langchain.docstore.document import Document
from tqdm import tqdm from tqdm import tqdm
from embedchain.config import ChromaDbConfig from embedchain.config import ChromaDbConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB from embedchain.vectordb.base import BaseVectorDB
try: try:

View File

@@ -10,7 +10,7 @@ except ImportError:
) from None ) from None
from embedchain.config import ElasticsearchDBConfig from embedchain.config import ElasticsearchDBConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB from embedchain.vectordb.base import BaseVectorDB

View File

@@ -16,7 +16,7 @@ from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import OpenSearchVectorSearch from langchain.vectorstores import OpenSearchVectorSearch
from embedchain.config import OpenSearchDBConfig from embedchain.config import OpenSearchDBConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB from embedchain.vectordb.base import BaseVectorDB

View File

@@ -9,7 +9,7 @@ except ImportError:
) from None ) from None
from embedchain.config.vectordb.pinecone import PineconeDBConfig from embedchain.config.vectordb.pinecone import PineconeDBConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB from embedchain.vectordb.base import BaseVectorDB

View File

@@ -10,7 +10,7 @@ except ImportError:
) from None ) from None
from embedchain.config.vectordb.weaviate import WeaviateDBConfig from embedchain.config.vectordb.weaviate import WeaviateDBConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB from embedchain.vectordb.base import BaseVectorDB

View File

@@ -2,7 +2,7 @@ import logging
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
from embedchain.config import ZillizDBConfig from embedchain.config import ZillizDBConfig
from embedchain.helper.json_serializable import register_deserializable from embedchain.helpers.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB from embedchain.vectordb.base import BaseVectorDB
try: try:

View File

@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "embedchain" name = "embedchain"
version = "0.1.18" version = "0.1.19"
description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data" description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
authors = [ authors = [
"Taranjeet Singh <taranjeet@embedchain.ai>", "Taranjeet Singh <taranjeet@embedchain.ai>",

View File

@@ -4,8 +4,8 @@ from string import Template
from embedchain import App from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig from embedchain.config import AppConfig, BaseLlmConfig
from embedchain.helper.json_serializable import (JSONSerializable, from embedchain.helpers.json_serializable import (JSONSerializable,
register_deserializable) register_deserializable)
class TestJsonSerializable(unittest.TestCase): class TestJsonSerializable(unittest.TestCase):