Feat/serialize deserialize (#508)

Co-authored-by: Taranjeet Singh <reachtotj@gmail.com>
This commit is contained in:
cachho
2023-09-03 21:50:18 +02:00
committed by GitHub
parent 2aa25a5169
commit 0d4ad07d7b
42 changed files with 345 additions and 8 deletions

View File

@@ -4,8 +4,10 @@ import openai
from embedchain.config import AppConfig, ChatConfig
from embedchain.embedchain import EmbedChain
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class App(EmbedChain):
"""
The EmbedChain app.

View File

@@ -5,9 +5,11 @@ from langchain.schema import BaseMessage
from embedchain.config import ChatConfig, CustomAppConfig
from embedchain.embedchain import EmbedChain
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.models import Providers
@register_deserializable
class CustomApp(EmbedChain):
"""
The custom EmbedChain app.

View File

@@ -3,10 +3,12 @@ from typing import Iterable, Union, Optional
from embedchain.config import ChatConfig, OpenSourceAppConfig
from embedchain.embedchain import EmbedChain
from embedchain.helper_classes.json_serializable import register_deserializable
gpt4all_model = None
@register_deserializable
class OpenSourceApp(EmbedChain):
"""
The OpenSource app.

View File

@@ -6,8 +6,10 @@ from embedchain.config import ChatConfig, QueryConfig
from embedchain.config.apps.BaseAppConfig import BaseAppConfig
from embedchain.config.QueryConfig import (DEFAULT_PROMPT,
DEFAULT_PROMPT_WITH_HISTORY)
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class EmbedChainPersonApp:
"""
Base class to create a person bot.
@@ -50,6 +52,7 @@ class EmbedChainPersonApp:
return config
@register_deserializable
class PersonApp(EmbedChainPersonApp, App):
"""
The Person app.
@@ -65,6 +68,7 @@ class PersonApp(EmbedChainPersonApp, App):
return super().chat(input_query, config, dry_run)
@register_deserializable
class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp):
"""
The Person app.

View File

@@ -1,9 +1,12 @@
from embedchain import CustomApp
from embedchain.config import AddConfig, CustomAppConfig, QueryConfig
from embedchain.helper_classes.json_serializable import (
JSONSerializable, register_deserializable)
from embedchain.models import EmbeddingFunctions, Providers
class BaseBot:
@register_deserializable
class BaseBot(JSONSerializable):
def __init__(self, app_config=None):
if app_config is None:
app_config = CustomAppConfig(embedding_fn=EmbeddingFunctions.OPENAI, provider=Providers.OPENAI)

View File

@@ -6,10 +6,12 @@ from typing import List, Optional
from fastapi_poe import PoeBot, run
from embedchain.config import QueryConfig
from embedchain.helper_classes.json_serializable import register_deserializable
from .base import BaseBot
@register_deserializable
class EcPoeBot(BaseBot, PoeBot):
def __init__(self):
self.history_length = 5

View File

@@ -6,9 +6,12 @@ import sys
from flask import Flask, request
from twilio.twiml.messaging_response import MessagingResponse
from embedchain.helper_classes.json_serializable import register_deserializable
from .base import BaseBot
@register_deserializable
class WhatsAppBot(BaseBot):
def __init__(self):
super().__init__()

View File

@@ -1,9 +1,10 @@
import hashlib
from embedchain.helper_classes.json_serializable import JSONSerializable
from embedchain.models.data_type import DataType
class BaseChunker:
class BaseChunker(JSONSerializable):
def __init__(self, text_splitter):
"""Initialize the chunker."""
self.text_splitter = text_splitter

View File

@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.AddConfig import ChunkerConfig
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class DocsSiteChunker(BaseChunker):
"""Chunker for code docs site."""

View File

@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.AddConfig import ChunkerConfig
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class DocxFileChunker(BaseChunker):
"""Chunker for .docx file."""

View File

@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.AddConfig import ChunkerConfig
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class NotionChunker(BaseChunker):
"""Chunker for notion."""

View File

@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.AddConfig import ChunkerConfig
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class PdfFileChunker(BaseChunker):
"""Chunker for PDF file."""

View File

@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.AddConfig import ChunkerConfig
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class QnaPairChunker(BaseChunker):
"""Chunker for QnA pair."""

View File

@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.AddConfig import ChunkerConfig
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class TextChunker(BaseChunker):
"""Chunker for text."""

View File

@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.AddConfig import ChunkerConfig
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class WebPageChunker(BaseChunker):
"""Chunker for web page."""

View File

@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
from embedchain.chunkers.base_chunker import BaseChunker
from embedchain.config.AddConfig import ChunkerConfig
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class YoutubeVideoChunker(BaseChunker):
"""Chunker for Youtube video."""

View File

@@ -1,8 +1,10 @@
from typing import Callable, Optional
from embedchain.config.BaseConfig import BaseConfig
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class ChunkerConfig(BaseConfig):
"""
Config for the chunker used in `add` method
@@ -19,6 +21,7 @@ class ChunkerConfig(BaseConfig):
self.length_function = length_function if length_function else len
@register_deserializable
class LoaderConfig(BaseConfig):
"""
Config for the chunker used in `add` method
@@ -28,6 +31,7 @@ class LoaderConfig(BaseConfig):
pass
@register_deserializable
class AddConfig(BaseConfig):
"""
Config for the `add` method.

View File

@@ -1,4 +1,7 @@
class BaseConfig:
from embedchain.helper_classes.json_serializable import JSONSerializable
class BaseConfig(JSONSerializable):
"""
Base config.
"""

View File

@@ -2,6 +2,7 @@ from string import Template
from typing import Optional
from embedchain.config.QueryConfig import QueryConfig
from embedchain.helper_classes.json_serializable import register_deserializable
DEFAULT_PROMPT = """
You are a chatbot having a conversation with a human. You are given chat
@@ -20,6 +21,7 @@ DEFAULT_PROMPT = """
DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT)
@register_deserializable
class ChatConfig(QueryConfig):
"""
Config for the `chat` method, inherits from `QueryConfig`.

View File

@@ -3,6 +3,7 @@ from string import Template
from typing import Optional
from embedchain.config.BaseConfig import BaseConfig
from embedchain.helper_classes.json_serializable import register_deserializable
DEFAULT_PROMPT = """
Use the following pieces of context to answer the query at the end.
@@ -48,6 +49,7 @@ context_re = re.compile(r"\$\{*context\}*")
history_re = re.compile(r"\$\{*history\}*")
@register_deserializable
class QueryConfig(BaseConfig):
"""
Config for the `query` method.

View File

@@ -9,9 +9,12 @@ except RuntimeError:
use_pysqlite3()
from chromadb.utils import embedding_functions
from embedchain.helper_classes.json_serializable import register_deserializable
from .BaseAppConfig import BaseAppConfig
@register_deserializable
class AppConfig(BaseAppConfig):
"""
Config to initialize an embedchain custom `App` instance, with extra config options.

View File

@@ -2,10 +2,11 @@ import logging
from embedchain.config.BaseConfig import BaseConfig
from embedchain.config.vectordbs import ElasticsearchDBConfig
from embedchain.helper_classes.json_serializable import JSONSerializable
from embedchain.models import VectorDatabases, VectorDimensions
class BaseAppConfig(BaseConfig):
class BaseAppConfig(BaseConfig, JSONSerializable):
"""
Parent config to initialize an instance of `App`, `OpenSourceApp` or `CustomApp`.
"""

View File

@@ -4,6 +4,7 @@ from chromadb.api.types import Documents, Embeddings
from dotenv import load_dotenv
from embedchain.config.vectordbs import ElasticsearchDBConfig
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.models import (EmbeddingFunctions, Providers, VectorDatabases,
VectorDimensions)
@@ -12,6 +13,7 @@ from .BaseAppConfig import BaseAppConfig
load_dotenv()
@register_deserializable
class CustomAppConfig(BaseAppConfig):
"""
Config to initialize an embedchain custom `App` instance, with extra config options.

View File

@@ -2,9 +2,12 @@ from typing import Optional
from chromadb.utils import embedding_functions
from embedchain.helper_classes.json_serializable import register_deserializable
from .BaseAppConfig import BaseAppConfig
@register_deserializable
class OpenSourceAppConfig(BaseAppConfig):
"""
Config to initialize an embedchain custom `OpenSourceApp` instance, with extra config options.

View File

@@ -1,8 +1,10 @@
from typing import Dict, List, Union
from embedchain.config.BaseConfig import BaseConfig
from embedchain.helper_classes.json_serializable import register_deserializable
@register_deserializable
class ElasticsearchDBConfig(BaseConfig):
"""
Config to initialize an elasticsearch client.

View File

@@ -7,6 +7,7 @@ from embedchain.chunkers.text import TextChunker
from embedchain.chunkers.web_page import WebPageChunker
from embedchain.chunkers.youtube_video import YoutubeVideoChunker
from embedchain.config import AddConfig
from embedchain.helper_classes.json_serializable import JSONSerializable
from embedchain.loaders.docs_site_loader import DocsSiteLoader
from embedchain.loaders.docx_file import DocxFileLoader
from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
@@ -18,7 +19,7 @@ from embedchain.loaders.youtube_video import YoutubeVideoLoader
from embedchain.models.data_type import DataType
class DataFormatter:
class DataFormatter(JSONSerializable):
"""
DataFormatter is an internal utility class which abstracts the mapping for
loaders and chunkers to the data_type entered by the user in their

View File

@@ -19,6 +19,7 @@ from embedchain.config import AddConfig, ChatConfig, QueryConfig
from embedchain.config.apps.BaseAppConfig import BaseAppConfig
from embedchain.config.QueryConfig import DOCS_SITE_PROMPT_TEMPLATE
from embedchain.data_formatter import DataFormatter
from embedchain.helper_classes.json_serializable import JSONSerializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.models.data_type import DataType
from embedchain.utils import detect_datatype
@@ -32,7 +33,7 @@ CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
class EmbedChain:
class EmbedChain(JSONSerializable):
def __init__(self, config: BaseAppConfig, system_prompt: Optional[str] = None):
"""
Initializes the EmbedChain instance, sets up a vector DB client and

View File

@@ -0,0 +1,180 @@
import json
import logging
from typing import Any, Dict, Type, TypeVar, Union
T = TypeVar("T", bound="JSONSerializable")
# NOTE: Through inheritance, all of our classes should be children of JSONSerializable. (highest level)
# NOTE: The @register_deserializable decorator should be added to all user facing child classes. (lowest level)
def register_deserializable(cls: Type[T]) -> Type[T]:
"""
A class decorator to register a class as deserializable.
When a class is decorated with @register_deserializable, it becomes
a part of the set of classes that the JSONSerializable class can
deserialize.
Deserialization is in essence loading attributes from a json file.
This decorator is a security measure put in place to make sure that
you don't load attributes that were initially part of another class.
Example:
@register_deserializable
class ChildClass(JSONSerializable):
def __init__(self, ...):
# initialization logic
Args:
cls (Type): The class to be registered.
Returns:
Type: The same class, after registration.
"""
JSONSerializable.register_class_as_deserializable(cls)
return cls
class JSONSerializable:
"""
A class to represent a JSON serializable object.
This class provides methods to serialize and deserialize objects,
as well as save serialized objects to a file and load them back.
"""
_deserializable_classes = set() # Contains classes that are whitelisted for deserialization.
def serialize(self) -> str:
"""
Serialize the object to a JSON-formatted string.
Returns:
str: A JSON string representation of the object.
"""
try:
return json.dumps(self, default=self._auto_encoder, ensure_ascii=False)
except Exception as e:
logging.error(f"Serialization error: {e}")
return "{}"
@classmethod
def deserialize(cls, json_str: str) -> Any:
"""
Deserialize a JSON-formatted string to an object.
If it fails, a default class is returned instead.
Note: This *returns* an instance, it's not automatically loaded on the calling class.
Example:
app = App.deserialize(json_str)
Args:
json_str (str): A JSON string representation of an object.
Returns:
Object: The deserialized object.
"""
try:
return json.loads(json_str, object_hook=cls._auto_decoder)
except Exception as e:
logging.error(f"Deserialization error: {e}")
# Return a default instance in case of failure
return cls()
@staticmethod
def _auto_encoder(obj: Any) -> Union[Dict[str, Any], None]:
"""
Automatically encode an object for JSON serialization.
Args:
obj (Object): The object to be encoded.
Returns:
dict: A dictionary representation of the object.
"""
if hasattr(obj, "__dict__"):
dct = obj.__dict__.copy()
for key, value in list(
dct.items()
): # We use list() to get a copy of items to avoid dictionary size change during iteration.
try:
# Recursive: If the value is an instance of a subclass of JSONSerializable,
# serialize it using the JSONSerializable serialize method.
if isinstance(value, JSONSerializable):
serialized_value = value.serialize()
# The value is stored as a serialized string.
dct[key] = json.loads(serialized_value)
else:
json.dumps(value) # Try to serialize the value.
except TypeError:
del dct[key] # If it fails, remove the key-value pair from the dictionary.
dct["__class__"] = obj.__class__.__name__
return dct
raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
@classmethod
def _auto_decoder(cls, dct: Dict[str, Any]) -> Any:
"""
Automatically decode a dictionary to an object during JSON deserialization.
Args:
dct (dict): The dictionary representation of an object.
Returns:
Object: The decoded object or the original dictionary if decoding is not possible.
"""
class_name = dct.pop("__class__", None)
if class_name:
if not hasattr(cls, "_deserializable_classes"): # Additional safety check
raise AttributeError(f"`{class_name}` has no registry of allowed deserializations.")
if class_name not in {cl.__name__ for cl in cls._deserializable_classes}:
raise KeyError(f"Deserialization of class `{class_name}` is not allowed.")
target_class = next((cl for cl in cls._deserializable_classes if cl.__name__ == class_name), None)
if target_class:
obj = target_class.__new__(target_class)
for key, value in dct.items():
default_value = getattr(target_class, key, None)
setattr(obj, key, value or default_value)
return obj
return dct
def save_to_file(self, filename: str) -> None:
"""
Save the serialized object to a file.
Args:
filename (str): The path to the file where the object should be saved.
"""
with open(filename, "w", encoding="utf-8") as f:
f.write(self.serialize())
@classmethod
def load_from_file(cls, filename: str) -> Any:
"""
Load and deserialize an object from a file.
Args:
filename (str): The path to the file from which the object should be loaded.
Returns:
Object: The deserialized object.
"""
with open(filename, "r", encoding="utf-8") as f:
json_str = f.read()
return cls.deserialize(json_str)
@classmethod
def register_class_as_deserializable(cls, target_class: Type[T]) -> None:
"""
Register a class as deserializable. This is a classmethod and globally shared.
This method adds the target class to the set of classes that
can be deserialized. This is a security measure to ensure only
whitelisted classes are deserialized.
Args:
target_class (Type): The class to be registered.
"""
cls._deserializable_classes.add(target_class)

View File

@@ -1,4 +1,7 @@
class BaseLoader:
from embedchain.helper_classes.json_serializable import JSONSerializable
class BaseLoader(JSONSerializable):
def __init__(self):
pass

View File

@@ -4,9 +4,11 @@ from urllib.parse import urljoin, urlparse
import requests
from bs4 import BeautifulSoup
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
@register_deserializable
class DocsSiteLoader(BaseLoader):
def __init__(self):
self.visited_links = set()

View File

@@ -1,8 +1,10 @@
from langchain.document_loaders import Docx2txtLoader
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
@register_deserializable
class DocxFileLoader(BaseLoader):
def load_data(self, url):
"""Load data from a .docx file."""

View File

@@ -1,6 +1,8 @@
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
@register_deserializable
class LocalQnaPairLoader(BaseLoader):
def load_data(self, content):
"""Load data from a local QnA pair."""

View File

@@ -1,6 +1,8 @@
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
@register_deserializable
class LocalTextLoader(BaseLoader):
def load_data(self, content):
"""Load data from a local text file."""

View File

@@ -7,10 +7,12 @@ except ImportError:
raise ImportError("Notion requires extra dependencies. Install with `pip install embedchain[community]`") from None
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils import clean_string
@register_deserializable
class NotionLoader(BaseLoader):
def load_data(self, source):
"""Load data from a PDF file."""

View File

@@ -1,9 +1,11 @@
from langchain.document_loaders import PyPDFLoader
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils import clean_string
@register_deserializable
class PdfFileLoader(BaseLoader):
def load_data(self, url):
"""Load data from a PDF file."""

View File

@@ -4,11 +4,13 @@ import requests
from bs4 import BeautifulSoup
from bs4.builder import ParserRejectedMarkup
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.loaders.web_page import WebPageLoader
from embedchain.utils import is_readable
@register_deserializable
class SitemapLoader(BaseLoader):
def load_data(self, sitemap_url):
"""

View File

@@ -3,10 +3,12 @@ import logging
import requests
from bs4 import BeautifulSoup
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils import clean_string
@register_deserializable
class WebPageLoader(BaseLoader):
def load_data(self, url):
"""Load data from a web page."""

View File

@@ -1,9 +1,11 @@
from langchain.document_loaders import YoutubeLoader
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader
from embedchain.utils import clean_string
@register_deserializable
class YoutubeVideoLoader(BaseLoader):
def load_data(self, url):
"""Load data from a Youtube video."""

View File

@@ -1,4 +1,7 @@
class BaseVectorDB:
from embedchain.helper_classes.json_serializable import JSONSerializable
class BaseVectorDB(JSONSerializable):
"""Base class for vector database."""
def __init__(self):

View File

@@ -14,9 +14,11 @@ except RuntimeError:
from chromadb.config import Settings
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.vectordb.base_vector_db import BaseVectorDB
@register_deserializable
class ChromaDB(BaseVectorDB):
"""Vector database using ChromaDB."""

View File

@@ -9,10 +9,12 @@ except ImportError:
) from None
from embedchain.config import ElasticsearchDBConfig
from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.models.VectorDimensions import VectorDimensions
from embedchain.vectordb.base_vector_db import BaseVectorDB
@register_deserializable
class ElasticsearchDB(BaseVectorDB):
def __init__(
self,

View File

@@ -0,0 +1,70 @@
import random
import unittest
from embedchain import App
from embedchain.config import AppConfig
from embedchain.helper_classes.json_serializable import (
JSONSerializable, register_deserializable)
class TestJsonSerializable(unittest.TestCase):
"""Test that the datatype detection is working, based on the input."""
def test_base_function(self):
"""Test that the base premise of serialization and deserealization is working"""
@register_deserializable
class TestClass(JSONSerializable):
def __init__(self):
self.rng = random.random()
original_class = TestClass()
serial = original_class.serialize()
# Negative test to show that a new class does not have the same random number.
negative_test_class = TestClass()
self.assertNotEqual(original_class.rng, negative_test_class.rng)
# Test to show that a deserialized class has the same random number.
positive_test_class: TestClass = TestClass().deserialize(serial)
self.assertEqual(original_class.rng, positive_test_class.rng)
self.assertTrue(isinstance(positive_test_class, TestClass))
# Test that it works as a static method too.
positive_test_class: TestClass = TestClass.deserialize(serial)
self.assertEqual(original_class.rng, positive_test_class.rng)
# TODO: There's no reason it shouldn't work, but serialization to and from file should be tested too.
def test_registration_required(self):
"""Test that registration is required, and that without registration the default class is returned."""
class SecondTestClass(JSONSerializable):
def __init__(self):
self.default = True
app = SecondTestClass()
# Make not default
app.default = False
# Serialize
serial = app.serialize()
# Deserialize. Due to the way errors are handled, it will not fail but return a default class.
app: SecondTestClass = SecondTestClass().deserialize(serial)
self.assertTrue(app.default)
# If we register and try again with the same serial, it should work
SecondTestClass.register_class_as_deserializable(SecondTestClass)
app: SecondTestClass = SecondTestClass().deserialize(serial)
self.assertFalse(app.default)
def test_recursive(self):
"""Test recursiveness with the real app"""
random_id = str(random.random())
config = AppConfig(id=random_id)
# config class is set under app.config.
app = App(config=config)
# w/o recursion it would just be <embedchain.config.apps.OpenSourceAppConfig.OpenSourceAppConfig object at x>
s = app.serialize()
new_app: App = App.deserialize(s)
# The id of the new app is the same as the first one.
self.assertEqual(random_id, new_app.config.id)
# TODO: test deeper recursion