Feat/serialize deserialize (#508)
Co-authored-by: Taranjeet Singh <reachtotj@gmail.com>
This commit is contained in:
@@ -4,8 +4,10 @@ import openai
|
|||||||
|
|
||||||
from embedchain.config import AppConfig, ChatConfig
|
from embedchain.config import AppConfig, ChatConfig
|
||||||
from embedchain.embedchain import EmbedChain
|
from embedchain.embedchain import EmbedChain
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class App(EmbedChain):
|
class App(EmbedChain):
|
||||||
"""
|
"""
|
||||||
The EmbedChain app.
|
The EmbedChain app.
|
||||||
|
|||||||
@@ -5,9 +5,11 @@ from langchain.schema import BaseMessage
|
|||||||
|
|
||||||
from embedchain.config import ChatConfig, CustomAppConfig
|
from embedchain.config import ChatConfig, CustomAppConfig
|
||||||
from embedchain.embedchain import EmbedChain
|
from embedchain.embedchain import EmbedChain
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
from embedchain.models import Providers
|
from embedchain.models import Providers
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class CustomApp(EmbedChain):
|
class CustomApp(EmbedChain):
|
||||||
"""
|
"""
|
||||||
The custom EmbedChain app.
|
The custom EmbedChain app.
|
||||||
|
|||||||
@@ -3,10 +3,12 @@ from typing import Iterable, Union, Optional
|
|||||||
|
|
||||||
from embedchain.config import ChatConfig, OpenSourceAppConfig
|
from embedchain.config import ChatConfig, OpenSourceAppConfig
|
||||||
from embedchain.embedchain import EmbedChain
|
from embedchain.embedchain import EmbedChain
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
|
|
||||||
gpt4all_model = None
|
gpt4all_model = None
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class OpenSourceApp(EmbedChain):
|
class OpenSourceApp(EmbedChain):
|
||||||
"""
|
"""
|
||||||
The OpenSource app.
|
The OpenSource app.
|
||||||
|
|||||||
@@ -6,8 +6,10 @@ from embedchain.config import ChatConfig, QueryConfig
|
|||||||
from embedchain.config.apps.BaseAppConfig import BaseAppConfig
|
from embedchain.config.apps.BaseAppConfig import BaseAppConfig
|
||||||
from embedchain.config.QueryConfig import (DEFAULT_PROMPT,
|
from embedchain.config.QueryConfig import (DEFAULT_PROMPT,
|
||||||
DEFAULT_PROMPT_WITH_HISTORY)
|
DEFAULT_PROMPT_WITH_HISTORY)
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class EmbedChainPersonApp:
|
class EmbedChainPersonApp:
|
||||||
"""
|
"""
|
||||||
Base class to create a person bot.
|
Base class to create a person bot.
|
||||||
@@ -50,6 +52,7 @@ class EmbedChainPersonApp:
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class PersonApp(EmbedChainPersonApp, App):
|
class PersonApp(EmbedChainPersonApp, App):
|
||||||
"""
|
"""
|
||||||
The Person app.
|
The Person app.
|
||||||
@@ -65,6 +68,7 @@ class PersonApp(EmbedChainPersonApp, App):
|
|||||||
return super().chat(input_query, config, dry_run)
|
return super().chat(input_query, config, dry_run)
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp):
|
class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp):
|
||||||
"""
|
"""
|
||||||
The Person app.
|
The Person app.
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
from embedchain import CustomApp
|
from embedchain import CustomApp
|
||||||
from embedchain.config import AddConfig, CustomAppConfig, QueryConfig
|
from embedchain.config import AddConfig, CustomAppConfig, QueryConfig
|
||||||
|
from embedchain.helper_classes.json_serializable import (
|
||||||
|
JSONSerializable, register_deserializable)
|
||||||
from embedchain.models import EmbeddingFunctions, Providers
|
from embedchain.models import EmbeddingFunctions, Providers
|
||||||
|
|
||||||
|
|
||||||
class BaseBot:
|
@register_deserializable
|
||||||
|
class BaseBot(JSONSerializable):
|
||||||
def __init__(self, app_config=None):
|
def __init__(self, app_config=None):
|
||||||
if app_config is None:
|
if app_config is None:
|
||||||
app_config = CustomAppConfig(embedding_fn=EmbeddingFunctions.OPENAI, provider=Providers.OPENAI)
|
app_config = CustomAppConfig(embedding_fn=EmbeddingFunctions.OPENAI, provider=Providers.OPENAI)
|
||||||
|
|||||||
@@ -6,10 +6,12 @@ from typing import List, Optional
|
|||||||
from fastapi_poe import PoeBot, run
|
from fastapi_poe import PoeBot, run
|
||||||
|
|
||||||
from embedchain.config import QueryConfig
|
from embedchain.config import QueryConfig
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
|
|
||||||
from .base import BaseBot
|
from .base import BaseBot
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class EcPoeBot(BaseBot, PoeBot):
|
class EcPoeBot(BaseBot, PoeBot):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.history_length = 5
|
self.history_length = 5
|
||||||
|
|||||||
@@ -6,9 +6,12 @@ import sys
|
|||||||
from flask import Flask, request
|
from flask import Flask, request
|
||||||
from twilio.twiml.messaging_response import MessagingResponse
|
from twilio.twiml.messaging_response import MessagingResponse
|
||||||
|
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
|
|
||||||
from .base import BaseBot
|
from .base import BaseBot
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class WhatsAppBot(BaseBot):
|
class WhatsAppBot(BaseBot):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
|
from embedchain.helper_classes.json_serializable import JSONSerializable
|
||||||
from embedchain.models.data_type import DataType
|
from embedchain.models.data_type import DataType
|
||||||
|
|
||||||
|
|
||||||
class BaseChunker:
|
class BaseChunker(JSONSerializable):
|
||||||
def __init__(self, text_splitter):
|
def __init__(self, text_splitter):
|
||||||
"""Initialize the chunker."""
|
"""Initialize the chunker."""
|
||||||
self.text_splitter = text_splitter
|
self.text_splitter = text_splitter
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|||||||
|
|
||||||
from embedchain.chunkers.base_chunker import BaseChunker
|
from embedchain.chunkers.base_chunker import BaseChunker
|
||||||
from embedchain.config.AddConfig import ChunkerConfig
|
from embedchain.config.AddConfig import ChunkerConfig
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class DocsSiteChunker(BaseChunker):
|
class DocsSiteChunker(BaseChunker):
|
||||||
"""Chunker for code docs site."""
|
"""Chunker for code docs site."""
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|||||||
|
|
||||||
from embedchain.chunkers.base_chunker import BaseChunker
|
from embedchain.chunkers.base_chunker import BaseChunker
|
||||||
from embedchain.config.AddConfig import ChunkerConfig
|
from embedchain.config.AddConfig import ChunkerConfig
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class DocxFileChunker(BaseChunker):
|
class DocxFileChunker(BaseChunker):
|
||||||
"""Chunker for .docx file."""
|
"""Chunker for .docx file."""
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|||||||
|
|
||||||
from embedchain.chunkers.base_chunker import BaseChunker
|
from embedchain.chunkers.base_chunker import BaseChunker
|
||||||
from embedchain.config.AddConfig import ChunkerConfig
|
from embedchain.config.AddConfig import ChunkerConfig
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class NotionChunker(BaseChunker):
|
class NotionChunker(BaseChunker):
|
||||||
"""Chunker for notion."""
|
"""Chunker for notion."""
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|||||||
|
|
||||||
from embedchain.chunkers.base_chunker import BaseChunker
|
from embedchain.chunkers.base_chunker import BaseChunker
|
||||||
from embedchain.config.AddConfig import ChunkerConfig
|
from embedchain.config.AddConfig import ChunkerConfig
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class PdfFileChunker(BaseChunker):
|
class PdfFileChunker(BaseChunker):
|
||||||
"""Chunker for PDF file."""
|
"""Chunker for PDF file."""
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|||||||
|
|
||||||
from embedchain.chunkers.base_chunker import BaseChunker
|
from embedchain.chunkers.base_chunker import BaseChunker
|
||||||
from embedchain.config.AddConfig import ChunkerConfig
|
from embedchain.config.AddConfig import ChunkerConfig
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class QnaPairChunker(BaseChunker):
|
class QnaPairChunker(BaseChunker):
|
||||||
"""Chunker for QnA pair."""
|
"""Chunker for QnA pair."""
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|||||||
|
|
||||||
from embedchain.chunkers.base_chunker import BaseChunker
|
from embedchain.chunkers.base_chunker import BaseChunker
|
||||||
from embedchain.config.AddConfig import ChunkerConfig
|
from embedchain.config.AddConfig import ChunkerConfig
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class TextChunker(BaseChunker):
|
class TextChunker(BaseChunker):
|
||||||
"""Chunker for text."""
|
"""Chunker for text."""
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|||||||
|
|
||||||
from embedchain.chunkers.base_chunker import BaseChunker
|
from embedchain.chunkers.base_chunker import BaseChunker
|
||||||
from embedchain.config.AddConfig import ChunkerConfig
|
from embedchain.config.AddConfig import ChunkerConfig
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class WebPageChunker(BaseChunker):
|
class WebPageChunker(BaseChunker):
|
||||||
"""Chunker for web page."""
|
"""Chunker for web page."""
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|||||||
|
|
||||||
from embedchain.chunkers.base_chunker import BaseChunker
|
from embedchain.chunkers.base_chunker import BaseChunker
|
||||||
from embedchain.config.AddConfig import ChunkerConfig
|
from embedchain.config.AddConfig import ChunkerConfig
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class YoutubeVideoChunker(BaseChunker):
|
class YoutubeVideoChunker(BaseChunker):
|
||||||
"""Chunker for Youtube video."""
|
"""Chunker for Youtube video."""
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
from embedchain.config.BaseConfig import BaseConfig
|
from embedchain.config.BaseConfig import BaseConfig
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class ChunkerConfig(BaseConfig):
|
class ChunkerConfig(BaseConfig):
|
||||||
"""
|
"""
|
||||||
Config for the chunker used in `add` method
|
Config for the chunker used in `add` method
|
||||||
@@ -19,6 +21,7 @@ class ChunkerConfig(BaseConfig):
|
|||||||
self.length_function = length_function if length_function else len
|
self.length_function = length_function if length_function else len
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class LoaderConfig(BaseConfig):
|
class LoaderConfig(BaseConfig):
|
||||||
"""
|
"""
|
||||||
Config for the chunker used in `add` method
|
Config for the chunker used in `add` method
|
||||||
@@ -28,6 +31,7 @@ class LoaderConfig(BaseConfig):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class AddConfig(BaseConfig):
|
class AddConfig(BaseConfig):
|
||||||
"""
|
"""
|
||||||
Config for the `add` method.
|
Config for the `add` method.
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
class BaseConfig:
|
from embedchain.helper_classes.json_serializable import JSONSerializable
|
||||||
|
|
||||||
|
|
||||||
|
class BaseConfig(JSONSerializable):
|
||||||
"""
|
"""
|
||||||
Base config.
|
Base config.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from string import Template
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from embedchain.config.QueryConfig import QueryConfig
|
from embedchain.config.QueryConfig import QueryConfig
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
|
|
||||||
DEFAULT_PROMPT = """
|
DEFAULT_PROMPT = """
|
||||||
You are a chatbot having a conversation with a human. You are given chat
|
You are a chatbot having a conversation with a human. You are given chat
|
||||||
@@ -20,6 +21,7 @@ DEFAULT_PROMPT = """
|
|||||||
DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT)
|
DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT)
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class ChatConfig(QueryConfig):
|
class ChatConfig(QueryConfig):
|
||||||
"""
|
"""
|
||||||
Config for the `chat` method, inherits from `QueryConfig`.
|
Config for the `chat` method, inherits from `QueryConfig`.
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from string import Template
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from embedchain.config.BaseConfig import BaseConfig
|
from embedchain.config.BaseConfig import BaseConfig
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
|
|
||||||
DEFAULT_PROMPT = """
|
DEFAULT_PROMPT = """
|
||||||
Use the following pieces of context to answer the query at the end.
|
Use the following pieces of context to answer the query at the end.
|
||||||
@@ -48,6 +49,7 @@ context_re = re.compile(r"\$\{*context\}*")
|
|||||||
history_re = re.compile(r"\$\{*history\}*")
|
history_re = re.compile(r"\$\{*history\}*")
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class QueryConfig(BaseConfig):
|
class QueryConfig(BaseConfig):
|
||||||
"""
|
"""
|
||||||
Config for the `query` method.
|
Config for the `query` method.
|
||||||
|
|||||||
@@ -9,9 +9,12 @@ except RuntimeError:
|
|||||||
use_pysqlite3()
|
use_pysqlite3()
|
||||||
from chromadb.utils import embedding_functions
|
from chromadb.utils import embedding_functions
|
||||||
|
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
|
|
||||||
from .BaseAppConfig import BaseAppConfig
|
from .BaseAppConfig import BaseAppConfig
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class AppConfig(BaseAppConfig):
|
class AppConfig(BaseAppConfig):
|
||||||
"""
|
"""
|
||||||
Config to initialize an embedchain custom `App` instance, with extra config options.
|
Config to initialize an embedchain custom `App` instance, with extra config options.
|
||||||
|
|||||||
@@ -2,10 +2,11 @@ import logging
|
|||||||
|
|
||||||
from embedchain.config.BaseConfig import BaseConfig
|
from embedchain.config.BaseConfig import BaseConfig
|
||||||
from embedchain.config.vectordbs import ElasticsearchDBConfig
|
from embedchain.config.vectordbs import ElasticsearchDBConfig
|
||||||
|
from embedchain.helper_classes.json_serializable import JSONSerializable
|
||||||
from embedchain.models import VectorDatabases, VectorDimensions
|
from embedchain.models import VectorDatabases, VectorDimensions
|
||||||
|
|
||||||
|
|
||||||
class BaseAppConfig(BaseConfig):
|
class BaseAppConfig(BaseConfig, JSONSerializable):
|
||||||
"""
|
"""
|
||||||
Parent config to initialize an instance of `App`, `OpenSourceApp` or `CustomApp`.
|
Parent config to initialize an instance of `App`, `OpenSourceApp` or `CustomApp`.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from chromadb.api.types import Documents, Embeddings
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from embedchain.config.vectordbs import ElasticsearchDBConfig
|
from embedchain.config.vectordbs import ElasticsearchDBConfig
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
from embedchain.models import (EmbeddingFunctions, Providers, VectorDatabases,
|
from embedchain.models import (EmbeddingFunctions, Providers, VectorDatabases,
|
||||||
VectorDimensions)
|
VectorDimensions)
|
||||||
|
|
||||||
@@ -12,6 +13,7 @@ from .BaseAppConfig import BaseAppConfig
|
|||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class CustomAppConfig(BaseAppConfig):
|
class CustomAppConfig(BaseAppConfig):
|
||||||
"""
|
"""
|
||||||
Config to initialize an embedchain custom `App` instance, with extra config options.
|
Config to initialize an embedchain custom `App` instance, with extra config options.
|
||||||
|
|||||||
@@ -2,9 +2,12 @@ from typing import Optional
|
|||||||
|
|
||||||
from chromadb.utils import embedding_functions
|
from chromadb.utils import embedding_functions
|
||||||
|
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
|
|
||||||
from .BaseAppConfig import BaseAppConfig
|
from .BaseAppConfig import BaseAppConfig
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class OpenSourceAppConfig(BaseAppConfig):
|
class OpenSourceAppConfig(BaseAppConfig):
|
||||||
"""
|
"""
|
||||||
Config to initialize an embedchain custom `OpenSourceApp` instance, with extra config options.
|
Config to initialize an embedchain custom `OpenSourceApp` instance, with extra config options.
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
from typing import Dict, List, Union
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
from embedchain.config.BaseConfig import BaseConfig
|
from embedchain.config.BaseConfig import BaseConfig
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class ElasticsearchDBConfig(BaseConfig):
|
class ElasticsearchDBConfig(BaseConfig):
|
||||||
"""
|
"""
|
||||||
Config to initialize an elasticsearch client.
|
Config to initialize an elasticsearch client.
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from embedchain.chunkers.text import TextChunker
|
|||||||
from embedchain.chunkers.web_page import WebPageChunker
|
from embedchain.chunkers.web_page import WebPageChunker
|
||||||
from embedchain.chunkers.youtube_video import YoutubeVideoChunker
|
from embedchain.chunkers.youtube_video import YoutubeVideoChunker
|
||||||
from embedchain.config import AddConfig
|
from embedchain.config import AddConfig
|
||||||
|
from embedchain.helper_classes.json_serializable import JSONSerializable
|
||||||
from embedchain.loaders.docs_site_loader import DocsSiteLoader
|
from embedchain.loaders.docs_site_loader import DocsSiteLoader
|
||||||
from embedchain.loaders.docx_file import DocxFileLoader
|
from embedchain.loaders.docx_file import DocxFileLoader
|
||||||
from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
|
from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
|
||||||
@@ -18,7 +19,7 @@ from embedchain.loaders.youtube_video import YoutubeVideoLoader
|
|||||||
from embedchain.models.data_type import DataType
|
from embedchain.models.data_type import DataType
|
||||||
|
|
||||||
|
|
||||||
class DataFormatter:
|
class DataFormatter(JSONSerializable):
|
||||||
"""
|
"""
|
||||||
DataFormatter is an internal utility class which abstracts the mapping for
|
DataFormatter is an internal utility class which abstracts the mapping for
|
||||||
loaders and chunkers to the data_type entered by the user in their
|
loaders and chunkers to the data_type entered by the user in their
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from embedchain.config import AddConfig, ChatConfig, QueryConfig
|
|||||||
from embedchain.config.apps.BaseAppConfig import BaseAppConfig
|
from embedchain.config.apps.BaseAppConfig import BaseAppConfig
|
||||||
from embedchain.config.QueryConfig import DOCS_SITE_PROMPT_TEMPLATE
|
from embedchain.config.QueryConfig import DOCS_SITE_PROMPT_TEMPLATE
|
||||||
from embedchain.data_formatter import DataFormatter
|
from embedchain.data_formatter import DataFormatter
|
||||||
|
from embedchain.helper_classes.json_serializable import JSONSerializable
|
||||||
from embedchain.loaders.base_loader import BaseLoader
|
from embedchain.loaders.base_loader import BaseLoader
|
||||||
from embedchain.models.data_type import DataType
|
from embedchain.models.data_type import DataType
|
||||||
from embedchain.utils import detect_datatype
|
from embedchain.utils import detect_datatype
|
||||||
@@ -32,7 +33,7 @@ CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
|
|||||||
CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
|
CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
|
||||||
|
|
||||||
|
|
||||||
class EmbedChain:
|
class EmbedChain(JSONSerializable):
|
||||||
def __init__(self, config: BaseAppConfig, system_prompt: Optional[str] = None):
|
def __init__(self, config: BaseAppConfig, system_prompt: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
Initializes the EmbedChain instance, sets up a vector DB client and
|
Initializes the EmbedChain instance, sets up a vector DB client and
|
||||||
|
|||||||
180
embedchain/helper_classes/json_serializable.py
Normal file
180
embedchain/helper_classes/json_serializable.py
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, Type, TypeVar, Union
|
||||||
|
|
||||||
|
T = TypeVar("T", bound="JSONSerializable")
|
||||||
|
|
||||||
|
# NOTE: Through inheritance, all of our classes should be children of JSONSerializable. (highest level)
|
||||||
|
# NOTE: The @register_deserializable decorator should be added to all user facing child classes. (lowest level)
|
||||||
|
|
||||||
|
|
||||||
|
def register_deserializable(cls: Type[T]) -> Type[T]:
    """
    Class decorator that whitelists `cls` for deserialization.

    The decorated class is handed to
    `JSONSerializable.register_class_as_deserializable`, which adds it to
    the registry of classes the JSON decoder is allowed to rebuild.
    Classes that were never registered are rejected when their name shows
    up in a JSON payload, so attributes saved for one class cannot be
    loaded into an unrelated one.

    Example:
        @register_deserializable
        class ChildClass(JSONSerializable):
            def __init__(self, ...):
                # initialization logic

    Args:
        cls (Type): The class to whitelist.

    Returns:
        Type: `cls` itself, unmodified, so the decorator is transparent.
    """
    registered = cls
    JSONSerializable.register_class_as_deserializable(registered)
    return registered
|
||||||
|
|
||||||
|
|
||||||
|
class JSONSerializable:
    """
    A class to represent a JSON serializable object.

    This class provides methods to serialize and deserialize objects,
    as well as save serialized objects to a file and load them back.

    Serialization walks the instance `__dict__`: attributes that the `json`
    module can encode are kept, nested `JSONSerializable` instances are
    encoded recursively, and everything else is silently dropped. Every
    encoded object carries a `"__class__"` key holding its class name, which
    the decoder uses to pick the class to rebuild.
    """

    # Contains classes that are whitelisted for deserialization.
    # Defined once on this base class, so the same set is shared by all subclasses.
    _deserializable_classes = set()

    def serialize(self) -> str:
        """
        Serialize the object to a JSON-formatted string.

        Any error raised while encoding is logged and swallowed.

        Returns:
            str: A JSON string representation of the object, or `"{}"` if
                serialization failed.
        """
        try:
            return json.dumps(self, default=self._auto_encoder, ensure_ascii=False)
        except Exception as e:
            logging.error(f"Serialization error: {e}")
            return "{}"

    @classmethod
    def deserialize(cls, json_str: str) -> Any:
        """
        Deserialize a JSON-formatted string to an object.
        If it fails, a default class is returned instead.
        Note: This *returns* an instance, it's not automatically loaded on the calling class.

        Example:
            app = App.deserialize(json_str)

        Args:
            json_str (str): A JSON string representation of an object.

        Returns:
            Object: The deserialized object. On any decoding error the error
                is logged and `cls()` is returned instead; that fallback call
                passes no arguments, so it raises if the constructor of `cls`
                has required parameters.
        """
        try:
            return json.loads(json_str, object_hook=cls._auto_decoder)
        except Exception as e:
            logging.error(f"Deserialization error: {e}")
            # Return a default instance in case of failure
            return cls()

    @staticmethod
    def _auto_encoder(obj: Any) -> Union[Dict[str, Any], None]:
        """
        Automatically encode an object for JSON serialization.

        Used as the `default=` hook of `json.dumps`, so it is only called for
        objects the `json` module cannot encode by itself.

        Args:
            obj (Object): The object to be encoded.

        Returns:
            dict: A copy of the object's `__dict__` restricted to encodable
                attributes, plus a `"__class__"` entry with the class name.

        Raises:
            TypeError: If `obj` has no `__dict__` to encode.
        """
        if not hasattr(obj, "__dict__"):
            raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

        dct = obj.__dict__.copy()
        # Iterate over a snapshot, because entries are deleted from `dct` inside the loop.
        for key, value in list(dct.items()):
            try:
                if isinstance(value, JSONSerializable):
                    # Recursive: encode the nested object with its own `serialize`
                    # and store the parsed result (a dict), not the JSON string.
                    dct[key] = json.loads(value.serialize())
                else:
                    json.dumps(value)  # Probe only: the result is discarded.
            except (TypeError, ValueError):
                # TypeError: the value (or something inside it) is not encodable.
                # ValueError: `json.dumps` found a circular reference.
                # Either way only this attribute is dropped, not the whole object.
                del dct[key]

        dct["__class__"] = obj.__class__.__name__
        return dct

    @classmethod
    def _auto_decoder(cls, dct: Dict[str, Any]) -> Any:
        """
        Automatically decode a dictionary to an object during JSON deserialization.

        Used as the `object_hook=` of `json.loads`. The `"__class__"` key is
        removed from `dct` in either case. The target instance is created with
        `__new__`, so its `__init__` is not run; attributes are set directly
        from the dictionary.

        Args:
            dct (dict): The dictionary representation of an object.

        Returns:
            Object: The decoded object or the original dictionary if decoding is not possible.

        Raises:
            AttributeError: If `cls` has no registry of deserializable classes.
            KeyError: If the class named in `"__class__"` was never registered.
        """
        class_name = dct.pop("__class__", None)
        if not class_name:
            return dct

        if not hasattr(cls, "_deserializable_classes"):  # Additional safety check
            raise AttributeError(f"`{class_name}` has no registry of allowed deserializations.")

        target_class = next((cl for cl in cls._deserializable_classes if cl.__name__ == class_name), None)
        if target_class is None:
            raise KeyError(f"Deserialization of class `{class_name}` is not allowed.")

        obj = target_class.__new__(target_class)
        for key, value in dct.items():
            # Only a missing value (JSON null) falls back to the class-level default.
            # Falsy but meaningful values such as 0, False, "" and [] must survive the round trip.
            if value is None:
                value = getattr(target_class, key, None)
            setattr(obj, key, value)
        return obj

    def save_to_file(self, filename: str) -> None:
        """
        Save the serialized object to a file.

        The file is written as UTF-8 and overwritten if it already exists.

        Args:
            filename (str): The path to the file where the object should be saved.
        """
        with open(filename, "w", encoding="utf-8") as f:
            f.write(self.serialize())

    @classmethod
    def load_from_file(cls, filename: str) -> Any:
        """
        Load and deserialize an object from a file.

        Errors while opening or reading the file propagate to the caller;
        decoding errors are handled by `deserialize`.

        Args:
            filename (str): The path to the file from which the object should be loaded.

        Returns:
            Object: The deserialized object.
        """
        with open(filename, "r", encoding="utf-8") as f:
            json_str = f.read()
        return cls.deserialize(json_str)

    @classmethod
    def register_class_as_deserializable(cls, target_class: "Type[T]") -> None:
        """
        Register a class as deserializable. This is a classmethod and globally shared.

        This method adds the target class to the set of classes that
        can be deserialized. This is a security measure to ensure only
        whitelisted classes are deserialized.

        Args:
            target_class (Type): The class to be registered.
        """
        cls._deserializable_classes.add(target_class)
|
||||||
@@ -1,4 +1,7 @@
|
|||||||
class BaseLoader:
|
from embedchain.helper_classes.json_serializable import JSONSerializable
|
||||||
|
|
||||||
|
|
||||||
|
class BaseLoader(JSONSerializable):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -4,9 +4,11 @@ from urllib.parse import urljoin, urlparse
|
|||||||
import requests
|
import requests
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
from embedchain.loaders.base_loader import BaseLoader
|
from embedchain.loaders.base_loader import BaseLoader
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class DocsSiteLoader(BaseLoader):
|
class DocsSiteLoader(BaseLoader):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.visited_links = set()
|
self.visited_links = set()
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
from langchain.document_loaders import Docx2txtLoader
|
from langchain.document_loaders import Docx2txtLoader
|
||||||
|
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
from embedchain.loaders.base_loader import BaseLoader
|
from embedchain.loaders.base_loader import BaseLoader
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class DocxFileLoader(BaseLoader):
|
class DocxFileLoader(BaseLoader):
|
||||||
def load_data(self, url):
|
def load_data(self, url):
|
||||||
"""Load data from a .docx file."""
|
"""Load data from a .docx file."""
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
from embedchain.loaders.base_loader import BaseLoader
|
from embedchain.loaders.base_loader import BaseLoader
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class LocalQnaPairLoader(BaseLoader):
|
class LocalQnaPairLoader(BaseLoader):
|
||||||
def load_data(self, content):
|
def load_data(self, content):
|
||||||
"""Load data from a local QnA pair."""
|
"""Load data from a local QnA pair."""
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
from embedchain.loaders.base_loader import BaseLoader
|
from embedchain.loaders.base_loader import BaseLoader
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class LocalTextLoader(BaseLoader):
|
class LocalTextLoader(BaseLoader):
|
||||||
def load_data(self, content):
|
def load_data(self, content):
|
||||||
"""Load data from a local text file."""
|
"""Load data from a local text file."""
|
||||||
|
|||||||
@@ -7,10 +7,12 @@ except ImportError:
|
|||||||
raise ImportError("Notion requires extra dependencies. Install with `pip install embedchain[community]`") from None
|
raise ImportError("Notion requires extra dependencies. Install with `pip install embedchain[community]`") from None
|
||||||
|
|
||||||
|
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
from embedchain.loaders.base_loader import BaseLoader
|
from embedchain.loaders.base_loader import BaseLoader
|
||||||
from embedchain.utils import clean_string
|
from embedchain.utils import clean_string
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class NotionLoader(BaseLoader):
|
class NotionLoader(BaseLoader):
|
||||||
def load_data(self, source):
|
def load_data(self, source):
|
||||||
"""Load data from a PDF file."""
|
"""Load data from a PDF file."""
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
from langchain.document_loaders import PyPDFLoader
|
from langchain.document_loaders import PyPDFLoader
|
||||||
|
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
from embedchain.loaders.base_loader import BaseLoader
|
from embedchain.loaders.base_loader import BaseLoader
|
||||||
from embedchain.utils import clean_string
|
from embedchain.utils import clean_string
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class PdfFileLoader(BaseLoader):
|
class PdfFileLoader(BaseLoader):
|
||||||
def load_data(self, url):
|
def load_data(self, url):
|
||||||
"""Load data from a PDF file."""
|
"""Load data from a PDF file."""
|
||||||
|
|||||||
@@ -4,11 +4,13 @@ import requests
|
|||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
from bs4.builder import ParserRejectedMarkup
|
from bs4.builder import ParserRejectedMarkup
|
||||||
|
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
from embedchain.loaders.base_loader import BaseLoader
|
from embedchain.loaders.base_loader import BaseLoader
|
||||||
from embedchain.loaders.web_page import WebPageLoader
|
from embedchain.loaders.web_page import WebPageLoader
|
||||||
from embedchain.utils import is_readable
|
from embedchain.utils import is_readable
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class SitemapLoader(BaseLoader):
|
class SitemapLoader(BaseLoader):
|
||||||
def load_data(self, sitemap_url):
|
def load_data(self, sitemap_url):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -3,10 +3,12 @@ import logging
|
|||||||
import requests
|
import requests
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
from embedchain.loaders.base_loader import BaseLoader
|
from embedchain.loaders.base_loader import BaseLoader
|
||||||
from embedchain.utils import clean_string
|
from embedchain.utils import clean_string
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class WebPageLoader(BaseLoader):
|
class WebPageLoader(BaseLoader):
|
||||||
def load_data(self, url):
|
def load_data(self, url):
|
||||||
"""Load data from a web page."""
|
"""Load data from a web page."""
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
from langchain.document_loaders import YoutubeLoader
|
from langchain.document_loaders import YoutubeLoader
|
||||||
|
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
from embedchain.loaders.base_loader import BaseLoader
|
from embedchain.loaders.base_loader import BaseLoader
|
||||||
from embedchain.utils import clean_string
|
from embedchain.utils import clean_string
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class YoutubeVideoLoader(BaseLoader):
|
class YoutubeVideoLoader(BaseLoader):
|
||||||
def load_data(self, url):
|
def load_data(self, url):
|
||||||
"""Load data from a Youtube video."""
|
"""Load data from a Youtube video."""
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
class BaseVectorDB:
|
from embedchain.helper_classes.json_serializable import JSONSerializable
|
||||||
|
|
||||||
|
|
||||||
|
class BaseVectorDB(JSONSerializable):
|
||||||
"""Base class for vector database."""
|
"""Base class for vector database."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|||||||
@@ -14,9 +14,11 @@ except RuntimeError:
|
|||||||
|
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
|
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
from embedchain.vectordb.base_vector_db import BaseVectorDB
|
from embedchain.vectordb.base_vector_db import BaseVectorDB
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class ChromaDB(BaseVectorDB):
|
class ChromaDB(BaseVectorDB):
|
||||||
"""Vector database using ChromaDB."""
|
"""Vector database using ChromaDB."""
|
||||||
|
|
||||||
|
|||||||
@@ -9,10 +9,12 @@ except ImportError:
|
|||||||
) from None
|
) from None
|
||||||
|
|
||||||
from embedchain.config import ElasticsearchDBConfig
|
from embedchain.config import ElasticsearchDBConfig
|
||||||
|
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||||
from embedchain.models.VectorDimensions import VectorDimensions
|
from embedchain.models.VectorDimensions import VectorDimensions
|
||||||
from embedchain.vectordb.base_vector_db import BaseVectorDB
|
from embedchain.vectordb.base_vector_db import BaseVectorDB
|
||||||
|
|
||||||
|
|
||||||
|
@register_deserializable
|
||||||
class ElasticsearchDB(BaseVectorDB):
|
class ElasticsearchDB(BaseVectorDB):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
70
tests/helper_classes/test_json_serializable.py
Normal file
70
tests/helper_classes/test_json_serializable.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
import random
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from embedchain import App
|
||||||
|
from embedchain.config import AppConfig
|
||||||
|
from embedchain.helper_classes.json_serializable import (
|
||||||
|
JSONSerializable, register_deserializable)
|
||||||
|
|
||||||
|
|
||||||
|
class TestJsonSerializable(unittest.TestCase):
    """Test serialization and deserialization of `JSONSerializable` classes."""

    def test_base_function(self):
        """Test that the base premise of serialization and deserialization is working."""

        @register_deserializable
        class TestClass(JSONSerializable):
            def __init__(self):
                # A fresh random value per instance lets the test tell a newly
                # constructed object apart from a faithfully deserialized one.
                self.rng = random.random()

        original_class = TestClass()
        serial = original_class.serialize()

        # Negative test to show that a new class does not have the same random number.
        negative_test_class = TestClass()
        self.assertNotEqual(original_class.rng, negative_test_class.rng)

        # Test to show that a deserialized class has the same random number.
        positive_test_class: TestClass = TestClass().deserialize(serial)
        self.assertEqual(original_class.rng, positive_test_class.rng)
        self.assertIsInstance(positive_test_class, TestClass)

        # Test that it works as a static method too (called on the class, not an instance).
        static_test_class: TestClass = TestClass.deserialize(serial)
        self.assertEqual(original_class.rng, static_test_class.rng)

        # TODO: There's no reason it shouldn't work, but serialization to and from file should be tested too.

    def test_registration_required(self):
        """Test that registration is required, and that without registration the default class is returned."""

        class SecondTestClass(JSONSerializable):
            def __init__(self):
                self.default = True

        app = SecondTestClass()
        # Make not default
        app.default = False
        # Serialize
        serial = app.serialize()

        # Deserialize. Due to the way errors are handled, it will not fail but return a default class.
        unregistered_app: SecondTestClass = SecondTestClass().deserialize(serial)
        self.assertTrue(unregistered_app.default)

        # If we register and try again with the same serial, it should work
        SecondTestClass.register_class_as_deserializable(SecondTestClass)
        registered_app: SecondTestClass = SecondTestClass().deserialize(serial)
        self.assertFalse(registered_app.default)

    def test_recursive(self):
        """Test recursiveness with the real app"""
        random_id = str(random.random())
        config = AppConfig(id=random_id)
        # config class is set under app.config.
        app = App(config=config)

        # w/o recursion the nested config would just be its object repr,
        # e.g. <... AppConfig object at x>, and the id could not be recovered.
        s = app.serialize()
        new_app: App = App.deserialize(s)

        # The id of the new app is the same as the first one.
        self.assertEqual(random_id, new_app.config.id)
        # TODO: test deeper recursion
|
||||||
Reference in New Issue
Block a user