From 0d4ad07d7b6118a2b8c1d4af2ae8452aebae4c60 Mon Sep 17 00:00:00 2001 From: cachho Date: Sun, 3 Sep 2023 21:50:18 +0200 Subject: [PATCH] Feat/serialize deserialize (#508) Co-authored-by: Taranjeet Singh --- embedchain/apps/App.py | 2 + embedchain/apps/CustomApp.py | 2 + embedchain/apps/OpenSourceApp.py | 2 + embedchain/apps/PersonApp.py | 4 + embedchain/bots/base.py | 5 +- embedchain/bots/poe.py | 2 + embedchain/bots/whatsapp.py | 3 + embedchain/chunkers/base_chunker.py | 3 +- embedchain/chunkers/docs_site.py | 2 + embedchain/chunkers/docx_file.py | 2 + embedchain/chunkers/notion.py | 2 + embedchain/chunkers/pdf_file.py | 2 + embedchain/chunkers/qna_pair.py | 2 + embedchain/chunkers/text.py | 2 + embedchain/chunkers/web_page.py | 2 + embedchain/chunkers/youtube_video.py | 2 + embedchain/config/AddConfig.py | 4 + embedchain/config/BaseConfig.py | 5 +- embedchain/config/ChatConfig.py | 2 + embedchain/config/QueryConfig.py | 2 + embedchain/config/apps/AppConfig.py | 3 + embedchain/config/apps/BaseAppConfig.py | 3 +- embedchain/config/apps/CustomAppConfig.py | 2 + embedchain/config/apps/OpenSourceAppConfig.py | 3 + .../config/vectordbs/ElasticsearchDBConfig.py | 2 + embedchain/data_formatter/data_formatter.py | 3 +- embedchain/embedchain.py | 3 +- .../helper_classes/json_serializable.py | 180 ++++++++++++++++++ embedchain/loaders/base_loader.py | 5 +- embedchain/loaders/docs_site_loader.py | 2 + embedchain/loaders/docx_file.py | 2 + embedchain/loaders/local_qna_pair.py | 2 + embedchain/loaders/local_text.py | 2 + embedchain/loaders/notion.py | 2 + embedchain/loaders/pdf_file.py | 2 + embedchain/loaders/sitemap.py | 2 + embedchain/loaders/web_page.py | 2 + embedchain/loaders/youtube_video.py | 2 + embedchain/vectordb/base_vector_db.py | 5 +- embedchain/vectordb/chroma_db.py | 2 + embedchain/vectordb/elasticsearch_db.py | 2 + .../helper_classes/test_json_serializable.py | 70 +++++++ 42 files changed, 345 insertions(+), 8 deletions(-) create mode 100644 
embedchain/helper_classes/json_serializable.py create mode 100644 tests/helper_classes/test_json_serializable.py diff --git a/embedchain/apps/App.py b/embedchain/apps/App.py index 03d9f73a..693fd804 100644 --- a/embedchain/apps/App.py +++ b/embedchain/apps/App.py @@ -4,8 +4,10 @@ import openai from embedchain.config import AppConfig, ChatConfig from embedchain.embedchain import EmbedChain +from embedchain.helper_classes.json_serializable import register_deserializable +@register_deserializable class App(EmbedChain): """ The EmbedChain app. diff --git a/embedchain/apps/CustomApp.py b/embedchain/apps/CustomApp.py index 5478e124..ad20d1eb 100644 --- a/embedchain/apps/CustomApp.py +++ b/embedchain/apps/CustomApp.py @@ -5,9 +5,11 @@ from langchain.schema import BaseMessage from embedchain.config import ChatConfig, CustomAppConfig from embedchain.embedchain import EmbedChain +from embedchain.helper_classes.json_serializable import register_deserializable from embedchain.models import Providers +@register_deserializable class CustomApp(EmbedChain): """ The custom EmbedChain app. diff --git a/embedchain/apps/OpenSourceApp.py b/embedchain/apps/OpenSourceApp.py index a74433fc..bd9f3914 100644 --- a/embedchain/apps/OpenSourceApp.py +++ b/embedchain/apps/OpenSourceApp.py @@ -3,10 +3,12 @@ from typing import Iterable, Union, Optional from embedchain.config import ChatConfig, OpenSourceAppConfig from embedchain.embedchain import EmbedChain +from embedchain.helper_classes.json_serializable import register_deserializable gpt4all_model = None +@register_deserializable class OpenSourceApp(EmbedChain): """ The OpenSource app. 
diff --git a/embedchain/apps/PersonApp.py b/embedchain/apps/PersonApp.py index a229c8db..d38719d6 100644 --- a/embedchain/apps/PersonApp.py +++ b/embedchain/apps/PersonApp.py @@ -6,8 +6,10 @@ from embedchain.config import ChatConfig, QueryConfig from embedchain.config.apps.BaseAppConfig import BaseAppConfig from embedchain.config.QueryConfig import (DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY) +from embedchain.helper_classes.json_serializable import register_deserializable +@register_deserializable class EmbedChainPersonApp: """ Base class to create a person bot. @@ -50,6 +52,7 @@ class EmbedChainPersonApp: return config +@register_deserializable class PersonApp(EmbedChainPersonApp, App): """ The Person app. @@ -65,6 +68,7 @@ class PersonApp(EmbedChainPersonApp, App): return super().chat(input_query, config, dry_run) +@register_deserializable class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp): """ The Person app. diff --git a/embedchain/bots/base.py b/embedchain/bots/base.py index bc94b769..48cfb71c 100644 --- a/embedchain/bots/base.py +++ b/embedchain/bots/base.py @@ -1,9 +1,12 @@ from embedchain import CustomApp from embedchain.config import AddConfig, CustomAppConfig, QueryConfig +from embedchain.helper_classes.json_serializable import ( + JSONSerializable, register_deserializable) from embedchain.models import EmbeddingFunctions, Providers -class BaseBot: +@register_deserializable +class BaseBot(JSONSerializable): def __init__(self, app_config=None): if app_config is None: app_config = CustomAppConfig(embedding_fn=EmbeddingFunctions.OPENAI, provider=Providers.OPENAI) diff --git a/embedchain/bots/poe.py b/embedchain/bots/poe.py index 23ee383a..69b5b914 100644 --- a/embedchain/bots/poe.py +++ b/embedchain/bots/poe.py @@ -6,10 +6,12 @@ from typing import List, Optional from fastapi_poe import PoeBot, run from embedchain.config import QueryConfig +from embedchain.helper_classes.json_serializable import register_deserializable from .base import BaseBot 
+@register_deserializable class EcPoeBot(BaseBot, PoeBot): def __init__(self): self.history_length = 5 diff --git a/embedchain/bots/whatsapp.py b/embedchain/bots/whatsapp.py index 6ea64155..2d241183 100644 --- a/embedchain/bots/whatsapp.py +++ b/embedchain/bots/whatsapp.py @@ -6,9 +6,12 @@ import sys from flask import Flask, request from twilio.twiml.messaging_response import MessagingResponse +from embedchain.helper_classes.json_serializable import register_deserializable + from .base import BaseBot +@register_deserializable class WhatsAppBot(BaseBot): def __init__(self): super().__init__() diff --git a/embedchain/chunkers/base_chunker.py b/embedchain/chunkers/base_chunker.py index 8bee1e54..2a926d3e 100644 --- a/embedchain/chunkers/base_chunker.py +++ b/embedchain/chunkers/base_chunker.py @@ -1,9 +1,10 @@ import hashlib +from embedchain.helper_classes.json_serializable import JSONSerializable from embedchain.models.data_type import DataType -class BaseChunker: +class BaseChunker(JSONSerializable): def __init__(self, text_splitter): """Initialize the chunker.""" self.text_splitter = text_splitter diff --git a/embedchain/chunkers/docs_site.py b/embedchain/chunkers/docs_site.py index c84621c2..e5ea879f 100644 --- a/embedchain/chunkers/docs_site.py +++ b/embedchain/chunkers/docs_site.py @@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.AddConfig import ChunkerConfig +from embedchain.helper_classes.json_serializable import register_deserializable +@register_deserializable class DocsSiteChunker(BaseChunker): """Chunker for code docs site.""" diff --git a/embedchain/chunkers/docx_file.py b/embedchain/chunkers/docx_file.py index bccfb589..51935d84 100644 --- a/embedchain/chunkers/docx_file.py +++ b/embedchain/chunkers/docx_file.py @@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter from embedchain.chunkers.base_chunker import 
BaseChunker from embedchain.config.AddConfig import ChunkerConfig +from embedchain.helper_classes.json_serializable import register_deserializable +@register_deserializable class DocxFileChunker(BaseChunker): """Chunker for .docx file.""" diff --git a/embedchain/chunkers/notion.py b/embedchain/chunkers/notion.py index 3ea8012d..5b473e0f 100644 --- a/embedchain/chunkers/notion.py +++ b/embedchain/chunkers/notion.py @@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.AddConfig import ChunkerConfig +from embedchain.helper_classes.json_serializable import register_deserializable +@register_deserializable class NotionChunker(BaseChunker): """Chunker for notion.""" diff --git a/embedchain/chunkers/pdf_file.py b/embedchain/chunkers/pdf_file.py index ec19166b..1a482553 100644 --- a/embedchain/chunkers/pdf_file.py +++ b/embedchain/chunkers/pdf_file.py @@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.AddConfig import ChunkerConfig +from embedchain.helper_classes.json_serializable import register_deserializable +@register_deserializable class PdfFileChunker(BaseChunker): """Chunker for PDF file.""" diff --git a/embedchain/chunkers/qna_pair.py b/embedchain/chunkers/qna_pair.py index ba9d0991..451ac010 100644 --- a/embedchain/chunkers/qna_pair.py +++ b/embedchain/chunkers/qna_pair.py @@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.AddConfig import ChunkerConfig +from embedchain.helper_classes.json_serializable import register_deserializable +@register_deserializable class QnaPairChunker(BaseChunker): """Chunker for QnA pair.""" diff --git a/embedchain/chunkers/text.py b/embedchain/chunkers/text.py index 44a320d1..0329d959 100644 --- 
a/embedchain/chunkers/text.py +++ b/embedchain/chunkers/text.py @@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.AddConfig import ChunkerConfig +from embedchain.helper_classes.json_serializable import register_deserializable +@register_deserializable class TextChunker(BaseChunker): """Chunker for text.""" diff --git a/embedchain/chunkers/web_page.py b/embedchain/chunkers/web_page.py index fd451d8e..395cf250 100644 --- a/embedchain/chunkers/web_page.py +++ b/embedchain/chunkers/web_page.py @@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.AddConfig import ChunkerConfig +from embedchain.helper_classes.json_serializable import register_deserializable +@register_deserializable class WebPageChunker(BaseChunker): """Chunker for web page.""" diff --git a/embedchain/chunkers/youtube_video.py b/embedchain/chunkers/youtube_video.py index 4f2ad41f..41d27b0f 100644 --- a/embedchain/chunkers/youtube_video.py +++ b/embedchain/chunkers/youtube_video.py @@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter from embedchain.chunkers.base_chunker import BaseChunker from embedchain.config.AddConfig import ChunkerConfig +from embedchain.helper_classes.json_serializable import register_deserializable +@register_deserializable class YoutubeVideoChunker(BaseChunker): """Chunker for Youtube video.""" diff --git a/embedchain/config/AddConfig.py b/embedchain/config/AddConfig.py index fe527a1d..935a6195 100644 --- a/embedchain/config/AddConfig.py +++ b/embedchain/config/AddConfig.py @@ -1,8 +1,10 @@ from typing import Callable, Optional from embedchain.config.BaseConfig import BaseConfig +from embedchain.helper_classes.json_serializable import register_deserializable +@register_deserializable class ChunkerConfig(BaseConfig): """ Config 
for the chunker used in `add` method @@ -19,6 +21,7 @@ class ChunkerConfig(BaseConfig): self.length_function = length_function if length_function else len +@register_deserializable class LoaderConfig(BaseConfig): """ Config for the chunker used in `add` method @@ -28,6 +31,7 @@ class LoaderConfig(BaseConfig): pass +@register_deserializable class AddConfig(BaseConfig): """ Config for the `add` method. diff --git a/embedchain/config/BaseConfig.py b/embedchain/config/BaseConfig.py index 38d53f52..4526214f 100644 --- a/embedchain/config/BaseConfig.py +++ b/embedchain/config/BaseConfig.py @@ -1,4 +1,7 @@ -class BaseConfig: +from embedchain.helper_classes.json_serializable import JSONSerializable + + +class BaseConfig(JSONSerializable): """ Base config. """ diff --git a/embedchain/config/ChatConfig.py b/embedchain/config/ChatConfig.py index fd195a90..0c403869 100644 --- a/embedchain/config/ChatConfig.py +++ b/embedchain/config/ChatConfig.py @@ -2,6 +2,7 @@ from string import Template from typing import Optional from embedchain.config.QueryConfig import QueryConfig +from embedchain.helper_classes.json_serializable import register_deserializable DEFAULT_PROMPT = """ You are a chatbot having a conversation with a human. You are given chat @@ -20,6 +21,7 @@ DEFAULT_PROMPT = """ DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT) +@register_deserializable class ChatConfig(QueryConfig): """ Config for the `chat` method, inherits from `QueryConfig`. diff --git a/embedchain/config/QueryConfig.py b/embedchain/config/QueryConfig.py index 3a0ceef7..a8c703a4 100644 --- a/embedchain/config/QueryConfig.py +++ b/embedchain/config/QueryConfig.py @@ -3,6 +3,7 @@ from string import Template from typing import Optional from embedchain.config.BaseConfig import BaseConfig +from embedchain.helper_classes.json_serializable import register_deserializable DEFAULT_PROMPT = """ Use the following pieces of context to answer the query at the end. 
@@ -48,6 +49,7 @@ context_re = re.compile(r"\$\{*context\}*") history_re = re.compile(r"\$\{*history\}*") +@register_deserializable class QueryConfig(BaseConfig): """ Config for the `query` method. diff --git a/embedchain/config/apps/AppConfig.py b/embedchain/config/apps/AppConfig.py index 995a45ff..5e74957c 100644 --- a/embedchain/config/apps/AppConfig.py +++ b/embedchain/config/apps/AppConfig.py @@ -9,9 +9,12 @@ except RuntimeError: use_pysqlite3() from chromadb.utils import embedding_functions +from embedchain.helper_classes.json_serializable import register_deserializable + from .BaseAppConfig import BaseAppConfig +@register_deserializable class AppConfig(BaseAppConfig): """ Config to initialize an embedchain custom `App` instance, with extra config options. diff --git a/embedchain/config/apps/BaseAppConfig.py b/embedchain/config/apps/BaseAppConfig.py index 5706fd49..0fed691f 100644 --- a/embedchain/config/apps/BaseAppConfig.py +++ b/embedchain/config/apps/BaseAppConfig.py @@ -2,10 +2,11 @@ import logging from embedchain.config.BaseConfig import BaseConfig from embedchain.config.vectordbs import ElasticsearchDBConfig +from embedchain.helper_classes.json_serializable import JSONSerializable from embedchain.models import VectorDatabases, VectorDimensions -class BaseAppConfig(BaseConfig): +class BaseAppConfig(BaseConfig, JSONSerializable): """ Parent config to initialize an instance of `App`, `OpenSourceApp` or `CustomApp`. 
""" diff --git a/embedchain/config/apps/CustomAppConfig.py b/embedchain/config/apps/CustomAppConfig.py index 677c6aeb..332ef0e8 100644 --- a/embedchain/config/apps/CustomAppConfig.py +++ b/embedchain/config/apps/CustomAppConfig.py @@ -4,6 +4,7 @@ from chromadb.api.types import Documents, Embeddings from dotenv import load_dotenv from embedchain.config.vectordbs import ElasticsearchDBConfig +from embedchain.helper_classes.json_serializable import register_deserializable from embedchain.models import (EmbeddingFunctions, Providers, VectorDatabases, VectorDimensions) @@ -12,6 +13,7 @@ from .BaseAppConfig import BaseAppConfig load_dotenv() +@register_deserializable class CustomAppConfig(BaseAppConfig): """ Config to initialize an embedchain custom `App` instance, with extra config options. diff --git a/embedchain/config/apps/OpenSourceAppConfig.py b/embedchain/config/apps/OpenSourceAppConfig.py index 7907deee..a0dd4ca4 100644 --- a/embedchain/config/apps/OpenSourceAppConfig.py +++ b/embedchain/config/apps/OpenSourceAppConfig.py @@ -2,9 +2,12 @@ from typing import Optional from chromadb.utils import embedding_functions +from embedchain.helper_classes.json_serializable import register_deserializable + from .BaseAppConfig import BaseAppConfig +@register_deserializable class OpenSourceAppConfig(BaseAppConfig): """ Config to initialize an embedchain custom `OpenSourceApp` instance, with extra config options. diff --git a/embedchain/config/vectordbs/ElasticsearchDBConfig.py b/embedchain/config/vectordbs/ElasticsearchDBConfig.py index 6e7dd0f9..691bb778 100644 --- a/embedchain/config/vectordbs/ElasticsearchDBConfig.py +++ b/embedchain/config/vectordbs/ElasticsearchDBConfig.py @@ -1,8 +1,10 @@ from typing import Dict, List, Union from embedchain.config.BaseConfig import BaseConfig +from embedchain.helper_classes.json_serializable import register_deserializable +@register_deserializable class ElasticsearchDBConfig(BaseConfig): """ Config to initialize an elasticsearch client. 
diff --git a/embedchain/data_formatter/data_formatter.py b/embedchain/data_formatter/data_formatter.py index c8bed26b..4691e56d 100644 --- a/embedchain/data_formatter/data_formatter.py +++ b/embedchain/data_formatter/data_formatter.py @@ -7,6 +7,7 @@ from embedchain.chunkers.text import TextChunker from embedchain.chunkers.web_page import WebPageChunker from embedchain.chunkers.youtube_video import YoutubeVideoChunker from embedchain.config import AddConfig +from embedchain.helper_classes.json_serializable import JSONSerializable from embedchain.loaders.docs_site_loader import DocsSiteLoader from embedchain.loaders.docx_file import DocxFileLoader from embedchain.loaders.local_qna_pair import LocalQnaPairLoader @@ -18,7 +19,7 @@ from embedchain.loaders.youtube_video import YoutubeVideoLoader from embedchain.models.data_type import DataType -class DataFormatter: +class DataFormatter(JSONSerializable): """ DataFormatter is an internal utility class which abstracts the mapping for loaders and chunkers to the data_type entered by the user in their diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index d6ef82c4..d379863b 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -19,6 +19,7 @@ from embedchain.config import AddConfig, ChatConfig, QueryConfig from embedchain.config.apps.BaseAppConfig import BaseAppConfig from embedchain.config.QueryConfig import DOCS_SITE_PROMPT_TEMPLATE from embedchain.data_formatter import DataFormatter +from embedchain.helper_classes.json_serializable import JSONSerializable from embedchain.loaders.base_loader import BaseLoader from embedchain.models.data_type import DataType from embedchain.utils import detect_datatype @@ -32,7 +33,7 @@ CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain") CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json") -class EmbedChain: +class EmbedChain(JSONSerializable): def __init__(self, config: BaseAppConfig, system_prompt: Optional[str] = None): """ Initializes the EmbedChain 
import json
import logging
from typing import Any, Dict, Set, Type, TypeVar, Union

T = TypeVar("T", bound="JSONSerializable")

# NOTE: Through inheritance, all of our classes should be children of JSONSerializable. (highest level)
# NOTE: The @register_deserializable decorator should be added to all user facing child classes. (lowest level)


def register_deserializable(cls: Type[T]) -> Type[T]:
    """
    Class decorator that whitelists a class for deserialization.

    Only classes registered here can be re-created by
    `JSONSerializable.deserialize`; a payload naming any other class is
    rejected. This keeps a JSON document from instantiating arbitrary
    classes.

    Example:
        @register_deserializable
        class ChildClass(JSONSerializable):
            def __init__(self, ...):
                # initialization logic

    Args:
        cls (Type): The class to be registered.

    Returns:
        Type: The same class, unchanged, after registration.
    """
    JSONSerializable.register_class_as_deserializable(cls)
    return cls


class JSONSerializable:
    """
    Mixin that gives a class JSON serialization and deserialization.

    An object is encoded as a JSON object holding its instance attributes
    plus a `__class__` key with the class name. Attributes that cannot be
    encoded are silently left out. Decoding re-creates registered classes
    without calling their `__init__`.
    """

    # Classes whitelisted for deserialization. Defined once on this base class
    # and never rebound, so every subclass shares the same set.
    _deserializable_classes: Set[type] = set()

    def serialize(self) -> str:
        """
        Serialize the object to a JSON-formatted string.

        Returns:
            str: A JSON string representation of the object. If encoding
            raises, the error is logged and the string "{}" is returned.
        """
        try:
            return json.dumps(self, default=self._auto_encoder, ensure_ascii=False)
        except Exception as e:
            # Deliberate best effort: callers get an empty object instead of an exception.
            logging.error(f"Serialization error: {e}")
            return "{}"

    @classmethod
    def deserialize(cls, json_str: str) -> Any:
        """
        Deserialize a JSON-formatted string to an object.

        Note: This *returns* an instance, it's not automatically loaded on the calling class.
        The type of the result is decided by the `__class__` key in the payload,
        not by the class this method is called on.

        Example:
            app = App.deserialize(json_str)

        Args:
            json_str (str): A JSON string representation of an object.

        Returns:
            Object: The deserialized object. If decoding raises (invalid JSON,
            unregistered class, ...), the error is logged and `cls()` is
            returned, so the fallback only works for classes that can be
            constructed without arguments.
        """
        try:
            return json.loads(json_str, object_hook=cls._auto_decoder)
        except Exception as e:
            logging.error(f"Deserialization error: {e}")
            # Return a default instance in case of failure
            return cls()

    @staticmethod
    def _auto_encoder(obj: Any) -> Union[Dict[str, Any], None]:
        """
        Encode an object as a dictionary for `json.dumps(default=...)`.

        Args:
            obj (Object): The object to be encoded.

        Returns:
            dict: The instance attributes that can be JSON encoded, plus a
            `__class__` key holding the class name.

        Raises:
            TypeError: If `obj` has no `__dict__`.
        """
        if not hasattr(obj, "__dict__"):
            raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

        encoded: Dict[str, Any] = {}
        # Iterate over a snapshot so nested serialization cannot disturb the loop.
        for key, value in list(vars(obj).items()):
            try:
                if isinstance(value, JSONSerializable):
                    # Recursive: nested serializable objects are stored as their
                    # decoded JSON structure, not as an escaped string.
                    encoded[key] = json.loads(value.serialize())
                else:
                    json.dumps(value)  # Probe only: can this value be encoded as is?
                    encoded[key] = value
            except (TypeError, ValueError):
                # TypeError: unsupported type. ValueError: e.g. a circular reference.
                # Either way only this attribute is dropped, not the whole object.
                continue

        encoded["__class__"] = obj.__class__.__name__
        return encoded

    @classmethod
    def _auto_decoder(cls, dct: Dict[str, Any]) -> Any:
        """
        Decode a dictionary to an object; used as `json.loads(object_hook=...)`.

        Args:
            dct (dict): The dictionary representation of an object.

        Returns:
            Object: An instance of the registered class named by `__class__`,
            or the dictionary itself when it carries no class name.

        Raises:
            AttributeError: If no registry of allowed classes exists.
            KeyError: If the named class has not been registered.
        """
        class_name = dct.pop("__class__", None)
        if not class_name:
            return dct

        if not hasattr(cls, "_deserializable_classes"):  # Additional safety check
            raise AttributeError(f"`{class_name}` has no registry of allowed deserializations.")

        # Classes are matched by bare name only; two registered classes sharing
        # a name would be ambiguous.
        target_class = next((cl for cl in cls._deserializable_classes if cl.__name__ == class_name), None)
        if target_class is None:
            raise KeyError(f"Deserialization of class `{class_name}` is not allowed.")

        # Bypass __init__: attributes are restored from the payload instead.
        obj = target_class.__new__(target_class)
        for key, value in dct.items():
            if value is None:
                # Only a missing (null) value falls back to the class-level default.
                # Falsy values such as False, 0 or "" are real data and are kept.
                value = getattr(target_class, key, None)
            setattr(obj, key, value)
        return obj

    def save_to_file(self, filename: str) -> None:
        """
        Save the serialized object to a file.

        Args:
            filename (str): The path to the file where the object should be saved.
        """
        with open(filename, "w", encoding="utf-8") as f:
            f.write(self.serialize())

    @classmethod
    def load_from_file(cls, filename: str) -> Any:
        """
        Load and deserialize an object from a file.

        Args:
            filename (str): The path to the file from which the object should be loaded.

        Returns:
            Object: The deserialized object, as returned by `deserialize`.
        """
        with open(filename, "r", encoding="utf-8") as f:
            json_str = f.read()
            return cls.deserialize(json_str)

    @classmethod
    def register_class_as_deserializable(cls, target_class: Type[T]) -> None:
        """
        Register a class as deserializable. This is a classmethod and globally shared.

        This method adds the target class to the set of classes that
        can be deserialized. This is a security measure to ensure only
        whitelisted classes are deserialized.

        Args:
            target_class (Type): The class to be registered.
        """
        cls._deserializable_classes.add(target_class)
embedchain.helper_classes.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader +@register_deserializable class LocalQnaPairLoader(BaseLoader): def load_data(self, content): """Load data from a local QnA pair.""" diff --git a/embedchain/loaders/local_text.py b/embedchain/loaders/local_text.py index 779b2036..92e26a9e 100644 --- a/embedchain/loaders/local_text.py +++ b/embedchain/loaders/local_text.py @@ -1,6 +1,8 @@ +from embedchain.helper_classes.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader +@register_deserializable class LocalTextLoader(BaseLoader): def load_data(self, content): """Load data from a local text file.""" diff --git a/embedchain/loaders/notion.py b/embedchain/loaders/notion.py index 43ac32ca..bc7e7a37 100644 --- a/embedchain/loaders/notion.py +++ b/embedchain/loaders/notion.py @@ -7,10 +7,12 @@ except ImportError: raise ImportError("Notion requires extra dependencies. 
Install with `pip install embedchain[community]`") from None +from embedchain.helper_classes.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader from embedchain.utils import clean_string +@register_deserializable class NotionLoader(BaseLoader): def load_data(self, source): """Load data from a PDF file.""" diff --git a/embedchain/loaders/pdf_file.py b/embedchain/loaders/pdf_file.py index 06e88ca9..7844b145 100644 --- a/embedchain/loaders/pdf_file.py +++ b/embedchain/loaders/pdf_file.py @@ -1,9 +1,11 @@ from langchain.document_loaders import PyPDFLoader +from embedchain.helper_classes.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader from embedchain.utils import clean_string +@register_deserializable class PdfFileLoader(BaseLoader): def load_data(self, url): """Load data from a PDF file.""" diff --git a/embedchain/loaders/sitemap.py b/embedchain/loaders/sitemap.py index 4442dec5..3a9109b8 100644 --- a/embedchain/loaders/sitemap.py +++ b/embedchain/loaders/sitemap.py @@ -4,11 +4,13 @@ import requests from bs4 import BeautifulSoup from bs4.builder import ParserRejectedMarkup +from embedchain.helper_classes.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader from embedchain.loaders.web_page import WebPageLoader from embedchain.utils import is_readable +@register_deserializable class SitemapLoader(BaseLoader): def load_data(self, sitemap_url): """ diff --git a/embedchain/loaders/web_page.py b/embedchain/loaders/web_page.py index 417898ea..cc562499 100644 --- a/embedchain/loaders/web_page.py +++ b/embedchain/loaders/web_page.py @@ -3,10 +3,12 @@ import logging import requests from bs4 import BeautifulSoup +from embedchain.helper_classes.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader from embedchain.utils import clean_string +@register_deserializable class 
WebPageLoader(BaseLoader): def load_data(self, url): """Load data from a web page.""" diff --git a/embedchain/loaders/youtube_video.py b/embedchain/loaders/youtube_video.py index 5cc6cc0d..af36fba0 100644 --- a/embedchain/loaders/youtube_video.py +++ b/embedchain/loaders/youtube_video.py @@ -1,9 +1,11 @@ from langchain.document_loaders import YoutubeLoader +from embedchain.helper_classes.json_serializable import register_deserializable from embedchain.loaders.base_loader import BaseLoader from embedchain.utils import clean_string +@register_deserializable class YoutubeVideoLoader(BaseLoader): def load_data(self, url): """Load data from a Youtube video.""" diff --git a/embedchain/vectordb/base_vector_db.py b/embedchain/vectordb/base_vector_db.py index 0ed1e3c0..f9740c5f 100644 --- a/embedchain/vectordb/base_vector_db.py +++ b/embedchain/vectordb/base_vector_db.py @@ -1,4 +1,7 @@ -class BaseVectorDB: +from embedchain.helper_classes.json_serializable import JSONSerializable + + +class BaseVectorDB(JSONSerializable): """Base class for vector database.""" def __init__(self): diff --git a/embedchain/vectordb/chroma_db.py b/embedchain/vectordb/chroma_db.py index 715721d0..1f8097f6 100644 --- a/embedchain/vectordb/chroma_db.py +++ b/embedchain/vectordb/chroma_db.py @@ -14,9 +14,11 @@ except RuntimeError: from chromadb.config import Settings +from embedchain.helper_classes.json_serializable import register_deserializable from embedchain.vectordb.base_vector_db import BaseVectorDB +@register_deserializable class ChromaDB(BaseVectorDB): """Vector database using ChromaDB.""" diff --git a/embedchain/vectordb/elasticsearch_db.py b/embedchain/vectordb/elasticsearch_db.py index 4c767b18..6c75909a 100644 --- a/embedchain/vectordb/elasticsearch_db.py +++ b/embedchain/vectordb/elasticsearch_db.py @@ -9,10 +9,12 @@ except ImportError: ) from None from embedchain.config import ElasticsearchDBConfig +from embedchain.helper_classes.json_serializable import register_deserializable from 
import os
import random
import tempfile
import unittest

from embedchain import App
from embedchain.config import AppConfig
from embedchain.helper_classes.json_serializable import (
    JSONSerializable, register_deserializable)


class TestJsonSerializable(unittest.TestCase):
    """Test serialization and deserialization of `JSONSerializable` objects."""

    def test_base_function(self):
        """Test that the base premise of serialization and deserialization is working"""

        @register_deserializable
        class TestClass(JSONSerializable):
            def __init__(self):
                self.rng = random.random()

        original_class = TestClass()
        serial = original_class.serialize()

        # Negative test to show that a new class does not have the same random number.
        negative_test_class = TestClass()
        self.assertNotEqual(original_class.rng, negative_test_class.rng)

        # Test to show that a deserialized class has the same random number.
        # Called on an instance here on purpose; the class-level call is covered below.
        positive_test_class: TestClass = TestClass().deserialize(serial)
        self.assertEqual(original_class.rng, positive_test_class.rng)
        self.assertTrue(isinstance(positive_test_class, TestClass))

        # Test that it works when called on the class too.
        positive_test_class: TestClass = TestClass.deserialize(serial)
        self.assertEqual(original_class.rng, positive_test_class.rng)

    def test_file_round_trip(self):
        """Test that an object saved to a file can be loaded back from it."""

        @register_deserializable
        class FileTestClass(JSONSerializable):
            def __init__(self):
                self.rng = random.random()

        original_class = FileTestClass()
        # The temporary directory (and the file in it) is removed on exit.
        with tempfile.TemporaryDirectory() as tmp_dir:
            filename = os.path.join(tmp_dir, "serialized.json")
            original_class.save_to_file(filename)
            loaded_class = FileTestClass.load_from_file(filename)

        self.assertTrue(isinstance(loaded_class, FileTestClass))
        self.assertEqual(original_class.rng, loaded_class.rng)

    def test_registration_required(self):
        """Test that registration is required, and that without registration the default class is returned."""

        class SecondTestClass(JSONSerializable):
            def __init__(self):
                self.default = True

        app = SecondTestClass()
        # Make not default
        app.default = False
        # Serialize
        serial = app.serialize()
        # Deserialize. Due to the way errors are handled, it will not fail but return a default class.
        app: SecondTestClass = SecondTestClass().deserialize(serial)
        self.assertTrue(app.default)
        # If we register and try again with the same serial, it should work
        SecondTestClass.register_class_as_deserializable(SecondTestClass)
        app: SecondTestClass = SecondTestClass().deserialize(serial)
        self.assertFalse(app.default)

    def test_recursive(self):
        """Test recursiveness with the real app"""
        random_id = str(random.random())
        config = AppConfig(id=random_id)
        # config class is set under app.config.
        app = App(config=config)
        # Without recursion, the nested config would not survive the round trip.
        s = app.serialize()
        new_app: App = App.deserialize(s)
        # The id of the new app is the same as the first one.
        self.assertEqual(random_id, new_app.config.id)
        # TODO: test deeper recursion