Feat/serialize deserialize (#508)
Co-authored-by: Taranjeet Singh <reachtotj@gmail.com>
This commit is contained in:
@@ -4,8 +4,10 @@ import openai
|
||||
|
||||
from embedchain.config import AppConfig, ChatConfig
|
||||
from embedchain.embedchain import EmbedChain
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class App(EmbedChain):
|
||||
"""
|
||||
The EmbedChain app.
|
||||
|
||||
@@ -5,9 +5,11 @@ from langchain.schema import BaseMessage
|
||||
|
||||
from embedchain.config import ChatConfig, CustomAppConfig
|
||||
from embedchain.embedchain import EmbedChain
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
from embedchain.models import Providers
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class CustomApp(EmbedChain):
|
||||
"""
|
||||
The custom EmbedChain app.
|
||||
|
||||
@@ -3,10 +3,12 @@ from typing import Iterable, Union, Optional
|
||||
|
||||
from embedchain.config import ChatConfig, OpenSourceAppConfig
|
||||
from embedchain.embedchain import EmbedChain
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
gpt4all_model = None
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class OpenSourceApp(EmbedChain):
|
||||
"""
|
||||
The OpenSource app.
|
||||
|
||||
@@ -6,8 +6,10 @@ from embedchain.config import ChatConfig, QueryConfig
|
||||
from embedchain.config.apps.BaseAppConfig import BaseAppConfig
|
||||
from embedchain.config.QueryConfig import (DEFAULT_PROMPT,
|
||||
DEFAULT_PROMPT_WITH_HISTORY)
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class EmbedChainPersonApp:
|
||||
"""
|
||||
Base class to create a person bot.
|
||||
@@ -50,6 +52,7 @@ class EmbedChainPersonApp:
|
||||
return config
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class PersonApp(EmbedChainPersonApp, App):
|
||||
"""
|
||||
The Person app.
|
||||
@@ -65,6 +68,7 @@ class PersonApp(EmbedChainPersonApp, App):
|
||||
return super().chat(input_query, config, dry_run)
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp):
|
||||
"""
|
||||
The Person app.
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from embedchain import CustomApp
|
||||
from embedchain.config import AddConfig, CustomAppConfig, QueryConfig
|
||||
from embedchain.helper_classes.json_serializable import (
|
||||
JSONSerializable, register_deserializable)
|
||||
from embedchain.models import EmbeddingFunctions, Providers
|
||||
|
||||
|
||||
class BaseBot:
|
||||
@register_deserializable
|
||||
class BaseBot(JSONSerializable):
|
||||
def __init__(self, app_config=None):
|
||||
if app_config is None:
|
||||
app_config = CustomAppConfig(embedding_fn=EmbeddingFunctions.OPENAI, provider=Providers.OPENAI)
|
||||
|
||||
@@ -6,10 +6,12 @@ from typing import List, Optional
|
||||
from fastapi_poe import PoeBot, run
|
||||
|
||||
from embedchain.config import QueryConfig
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
from .base import BaseBot
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class EcPoeBot(BaseBot, PoeBot):
|
||||
def __init__(self):
|
||||
self.history_length = 5
|
||||
|
||||
@@ -6,9 +6,12 @@ import sys
|
||||
from flask import Flask, request
|
||||
from twilio.twiml.messaging_response import MessagingResponse
|
||||
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
from .base import BaseBot
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class WhatsAppBot(BaseBot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import hashlib
|
||||
|
||||
from embedchain.helper_classes.json_serializable import JSONSerializable
|
||||
from embedchain.models.data_type import DataType
|
||||
|
||||
|
||||
class BaseChunker:
|
||||
class BaseChunker(JSONSerializable):
|
||||
def __init__(self, text_splitter):
|
||||
"""Initialize the chunker."""
|
||||
self.text_splitter = text_splitter
|
||||
|
||||
@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.AddConfig import ChunkerConfig
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class DocsSiteChunker(BaseChunker):
|
||||
"""Chunker for code docs site."""
|
||||
|
||||
|
||||
@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.AddConfig import ChunkerConfig
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class DocxFileChunker(BaseChunker):
|
||||
"""Chunker for .docx file."""
|
||||
|
||||
|
||||
@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.AddConfig import ChunkerConfig
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class NotionChunker(BaseChunker):
|
||||
"""Chunker for notion."""
|
||||
|
||||
|
||||
@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.AddConfig import ChunkerConfig
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class PdfFileChunker(BaseChunker):
|
||||
"""Chunker for PDF file."""
|
||||
|
||||
|
||||
@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.AddConfig import ChunkerConfig
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class QnaPairChunker(BaseChunker):
|
||||
"""Chunker for QnA pair."""
|
||||
|
||||
|
||||
@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.AddConfig import ChunkerConfig
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class TextChunker(BaseChunker):
|
||||
"""Chunker for text."""
|
||||
|
||||
|
||||
@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.AddConfig import ChunkerConfig
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class WebPageChunker(BaseChunker):
|
||||
"""Chunker for web page."""
|
||||
|
||||
|
||||
@@ -4,8 +4,10 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
from embedchain.chunkers.base_chunker import BaseChunker
|
||||
from embedchain.config.AddConfig import ChunkerConfig
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class YoutubeVideoChunker(BaseChunker):
|
||||
"""Chunker for Youtube video."""
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from typing import Callable, Optional
|
||||
|
||||
from embedchain.config.BaseConfig import BaseConfig
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class ChunkerConfig(BaseConfig):
|
||||
"""
|
||||
Config for the chunker used in `add` method
|
||||
@@ -19,6 +21,7 @@ class ChunkerConfig(BaseConfig):
|
||||
self.length_function = length_function if length_function else len
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class LoaderConfig(BaseConfig):
|
||||
"""
|
||||
Config for the loader used in `add` method
|
||||
@@ -28,6 +31,7 @@ class LoaderConfig(BaseConfig):
|
||||
pass
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class AddConfig(BaseConfig):
|
||||
"""
|
||||
Config for the `add` method.
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
class BaseConfig:
|
||||
from embedchain.helper_classes.json_serializable import JSONSerializable
|
||||
|
||||
|
||||
class BaseConfig(JSONSerializable):
|
||||
"""
|
||||
Base config.
|
||||
"""
|
||||
|
||||
@@ -2,6 +2,7 @@ from string import Template
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.QueryConfig import QueryConfig
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
DEFAULT_PROMPT = """
|
||||
You are a chatbot having a conversation with a human. You are given chat
|
||||
@@ -20,6 +21,7 @@ DEFAULT_PROMPT = """
|
||||
DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT)
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class ChatConfig(QueryConfig):
|
||||
"""
|
||||
Config for the `chat` method, inherits from `QueryConfig`.
|
||||
|
||||
@@ -3,6 +3,7 @@ from string import Template
|
||||
from typing import Optional
|
||||
|
||||
from embedchain.config.BaseConfig import BaseConfig
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
DEFAULT_PROMPT = """
|
||||
Use the following pieces of context to answer the query at the end.
|
||||
@@ -48,6 +49,7 @@ context_re = re.compile(r"\$\{*context\}*")
|
||||
history_re = re.compile(r"\$\{*history\}*")
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class QueryConfig(BaseConfig):
|
||||
"""
|
||||
Config for the `query` method.
|
||||
|
||||
@@ -9,9 +9,12 @@ except RuntimeError:
|
||||
use_pysqlite3()
|
||||
from chromadb.utils import embedding_functions
|
||||
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
from .BaseAppConfig import BaseAppConfig
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class AppConfig(BaseAppConfig):
|
||||
"""
|
||||
Config to initialize an embedchain custom `App` instance, with extra config options.
|
||||
|
||||
@@ -2,10 +2,11 @@ import logging
|
||||
|
||||
from embedchain.config.BaseConfig import BaseConfig
|
||||
from embedchain.config.vectordbs import ElasticsearchDBConfig
|
||||
from embedchain.helper_classes.json_serializable import JSONSerializable
|
||||
from embedchain.models import VectorDatabases, VectorDimensions
|
||||
|
||||
|
||||
class BaseAppConfig(BaseConfig):
|
||||
class BaseAppConfig(BaseConfig, JSONSerializable):
|
||||
"""
|
||||
Parent config to initialize an instance of `App`, `OpenSourceApp` or `CustomApp`.
|
||||
"""
|
||||
|
||||
@@ -4,6 +4,7 @@ from chromadb.api.types import Documents, Embeddings
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from embedchain.config.vectordbs import ElasticsearchDBConfig
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
from embedchain.models import (EmbeddingFunctions, Providers, VectorDatabases,
|
||||
VectorDimensions)
|
||||
|
||||
@@ -12,6 +13,7 @@ from .BaseAppConfig import BaseAppConfig
|
||||
load_dotenv()
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class CustomAppConfig(BaseAppConfig):
|
||||
"""
|
||||
Config to initialize an embedchain custom `App` instance, with extra config options.
|
||||
|
||||
@@ -2,9 +2,12 @@ from typing import Optional
|
||||
|
||||
from chromadb.utils import embedding_functions
|
||||
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
from .BaseAppConfig import BaseAppConfig
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class OpenSourceAppConfig(BaseAppConfig):
|
||||
"""
|
||||
Config to initialize an embedchain custom `OpenSourceApp` instance, with extra config options.
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from embedchain.config.BaseConfig import BaseConfig
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class ElasticsearchDBConfig(BaseConfig):
|
||||
"""
|
||||
Config to initialize an elasticsearch client.
|
||||
|
||||
@@ -7,6 +7,7 @@ from embedchain.chunkers.text import TextChunker
|
||||
from embedchain.chunkers.web_page import WebPageChunker
|
||||
from embedchain.chunkers.youtube_video import YoutubeVideoChunker
|
||||
from embedchain.config import AddConfig
|
||||
from embedchain.helper_classes.json_serializable import JSONSerializable
|
||||
from embedchain.loaders.docs_site_loader import DocsSiteLoader
|
||||
from embedchain.loaders.docx_file import DocxFileLoader
|
||||
from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
|
||||
@@ -18,7 +19,7 @@ from embedchain.loaders.youtube_video import YoutubeVideoLoader
|
||||
from embedchain.models.data_type import DataType
|
||||
|
||||
|
||||
class DataFormatter:
|
||||
class DataFormatter(JSONSerializable):
|
||||
"""
|
||||
DataFormatter is an internal utility class which abstracts the mapping for
|
||||
loaders and chunkers to the data_type entered by the user in their
|
||||
|
||||
@@ -19,6 +19,7 @@ from embedchain.config import AddConfig, ChatConfig, QueryConfig
|
||||
from embedchain.config.apps.BaseAppConfig import BaseAppConfig
|
||||
from embedchain.config.QueryConfig import DOCS_SITE_PROMPT_TEMPLATE
|
||||
from embedchain.data_formatter import DataFormatter
|
||||
from embedchain.helper_classes.json_serializable import JSONSerializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.models.data_type import DataType
|
||||
from embedchain.utils import detect_datatype
|
||||
@@ -32,7 +33,7 @@ CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
|
||||
CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
|
||||
|
||||
|
||||
class EmbedChain:
|
||||
class EmbedChain(JSONSerializable):
|
||||
def __init__(self, config: BaseAppConfig, system_prompt: Optional[str] = None):
|
||||
"""
|
||||
Initializes the EmbedChain instance, sets up a vector DB client and
|
||||
|
||||
180
embedchain/helper_classes/json_serializable.py
Normal file
180
embedchain/helper_classes/json_serializable.py
Normal file
@@ -0,0 +1,180 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, Type, TypeVar, Union
|
||||
|
||||
T = TypeVar("T", bound="JSONSerializable")
|
||||
|
||||
# NOTE: Through inheritance, all of our classes should be children of JSONSerializable. (highest level)
|
||||
# NOTE: The @register_deserializable decorator should be added to all user facing child classes. (lowest level)
|
||||
|
||||
|
||||
def register_deserializable(cls: Type[T]) -> Type[T]:
    """
    Class decorator that whitelists a class for deserialization.

    The decorated class is added to the registry kept by `JSONSerializable`.
    Only classes found in that registry are instantiated when a JSON string
    is deserialized, so attributes stored for one class are not loaded into
    an unrelated class.

    Example:
        @register_deserializable
        class ChildClass(JSONSerializable):
            def __init__(self, ...):
                # initialization logic

    Args:
        cls (Type): The class to be registered.

    Returns:
        Type: The very same class object, so the decorator is transparent.
    """
    registered = cls
    JSONSerializable.register_class_as_deserializable(registered)
    return registered
|
||||
|
||||
|
||||
class JSONSerializable:
    """
    A class to represent a JSON serializable object.

    This class provides methods to serialize and deserialize objects,
    as well as save serialized objects to a file and load them back.

    Serialization walks the instance `__dict__`; attributes that cannot be
    expressed as JSON are silently left out. Deserialization only rebuilds
    classes that were registered through `register_class_as_deserializable`.
    """

    # Whitelist of classes that may be rebuilt during deserialization.
    # Defined once on this base class, so it is shared by all subclasses.
    _deserializable_classes = set()

    def serialize(self) -> str:
        """
        Serialize the object to a JSON-formatted string.

        Any error is logged and swallowed; in that case an empty JSON
        object is returned instead.

        Returns:
            str: A JSON string representation of the object, or "{}" on failure.
        """
        try:
            return json.dumps(self, default=self._auto_encoder, ensure_ascii=False)
        except Exception as e:
            logging.error("Serialization error: %s", e)
            return "{}"

    @classmethod
    def deserialize(cls, json_str: str) -> Any:
        """
        Deserialize a JSON-formatted string to an object.

        If it fails (invalid JSON, unregistered class, ...), the error is logged
        and a default instance created with `cls()` is returned instead, which
        requires `cls` to be constructible without arguments.
        Note: This *returns* an instance, it's not automatically loaded on the calling class.

        Example:
            app = App.deserialize(json_str)

        Args:
            json_str (str): A JSON string representation of an object.

        Returns:
            Object: The deserialized object.
        """
        try:
            return json.loads(json_str, object_hook=cls._auto_decoder)
        except Exception as e:
            logging.error("Deserialization error: %s", e)
            # Return a default instance in case of failure
            return cls()

    @staticmethod
    def _auto_encoder(obj: Any) -> Union[Dict[str, Any], None]:
        """
        Automatically encode an object for JSON serialization.

        Used as the `default` hook of `json.dumps`. Attributes whose values
        cannot be serialized are dropped from the result.

        Args:
            obj (Object): The object to be encoded.

        Returns:
            dict: The serializable attributes of the object plus a
                `__class__` entry holding the class name.

        Raises:
            TypeError: If the object has no `__dict__` to read attributes from.
        """
        if not hasattr(obj, "__dict__"):
            raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

        dct = {}
        for key, value in obj.__dict__.items():
            try:
                if isinstance(value, JSONSerializable):
                    # Recursive: nested serializable objects are stored as their
                    # own dictionary representation (including `__class__`).
                    dct[key] = json.loads(value.serialize())
                else:
                    json.dumps(value)  # Probe only: raises if the value is not serializable.
                    dct[key] = value
            except (TypeError, ValueError):
                # TypeError: unsupported type. ValueError: circular reference.
                # Either way the attribute is left out instead of failing the whole object.
                continue

        dct["__class__"] = obj.__class__.__name__
        return dct

    @classmethod
    def _auto_decoder(cls, dct: Dict[str, Any]) -> Any:
        """
        Automatically decode a dictionary to an object during JSON deserialization.

        Used as the `object_hook` of `json.loads`. The target instance is created
        with `__new__`, so its `__init__` is not executed.

        Args:
            dct (dict): The dictionary representation of an object.

        Returns:
            Object: The decoded object or the original dictionary if decoding is not possible.

        Raises:
            AttributeError: If there is no registry of allowed classes.
            KeyError: If the class named in `__class__` has not been registered.
        """
        class_name = dct.pop("__class__", None)
        if not class_name:
            return dct

        if not hasattr(cls, "_deserializable_classes"):  # Additional safety check
            raise AttributeError(f"`{class_name}` has no registry of allowed deserializations.")

        allowed = {cl.__name__: cl for cl in cls._deserializable_classes}
        target_class = allowed.get(class_name)
        if target_class is None:
            raise KeyError(f"Deserialization of class `{class_name}` is not allowed.")

        obj = target_class.__new__(target_class)
        for key, value in dct.items():
            # Fall back to the class-level default only when the stored value is null.
            # Falsy values such as False, 0 or "" are real data and must be kept.
            if value is None:
                value = getattr(target_class, key, None)
            setattr(obj, key, value)
        return obj

    def save_to_file(self, filename: str) -> None:
        """
        Save the serialized object to a file.

        Args:
            filename (str): The path to the file where the object should be saved.
        """
        with open(filename, "w", encoding="utf-8") as f:
            f.write(self.serialize())

    @classmethod
    def load_from_file(cls, filename: str) -> Any:
        """
        Load and deserialize an object from a file.

        Args:
            filename (str): The path to the file from which the object should be loaded.

        Returns:
            Object: The deserialized object.
        """
        with open(filename, "r", encoding="utf-8") as f:
            json_str = f.read()
        return cls.deserialize(json_str)

    @classmethod
    def register_class_as_deserializable(cls, target_class: Type["JSONSerializable"]) -> None:
        """
        Register a class as deserializable. This is a classmethod and globally shared.

        This method adds the target class to the set of classes that
        can be deserialized. This is a security measure to ensure only
        whitelisted classes are deserialized.

        Args:
            target_class (Type): The class to be registered.
        """
        cls._deserializable_classes.add(target_class)
|
||||
@@ -1,4 +1,7 @@
|
||||
class BaseLoader:
|
||||
from embedchain.helper_classes.json_serializable import JSONSerializable
|
||||
|
||||
|
||||
class BaseLoader(JSONSerializable):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@@ -4,9 +4,11 @@ from urllib.parse import urljoin, urlparse
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class DocsSiteLoader(BaseLoader):
|
||||
def __init__(self):
|
||||
self.visited_links = set()
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from langchain.document_loaders import Docx2txtLoader
|
||||
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class DocxFileLoader(BaseLoader):
|
||||
def load_data(self, url):
|
||||
"""Load data from a .docx file."""
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class LocalQnaPairLoader(BaseLoader):
|
||||
def load_data(self, content):
|
||||
"""Load data from a local QnA pair."""
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class LocalTextLoader(BaseLoader):
|
||||
def load_data(self, content):
|
||||
"""Load data from a local text file."""
|
||||
|
||||
@@ -7,10 +7,12 @@ except ImportError:
|
||||
raise ImportError("Notion requires extra dependencies. Install with `pip install embedchain[community]`") from None
|
||||
|
||||
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.utils import clean_string
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class NotionLoader(BaseLoader):
|
||||
def load_data(self, source):
|
||||
"""Load data from a PDF file."""
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from langchain.document_loaders import PyPDFLoader
|
||||
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.utils import clean_string
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class PdfFileLoader(BaseLoader):
|
||||
def load_data(self, url):
|
||||
"""Load data from a PDF file."""
|
||||
|
||||
@@ -4,11 +4,13 @@ import requests
|
||||
from bs4 import BeautifulSoup
|
||||
from bs4.builder import ParserRejectedMarkup
|
||||
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.loaders.web_page import WebPageLoader
|
||||
from embedchain.utils import is_readable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class SitemapLoader(BaseLoader):
|
||||
def load_data(self, sitemap_url):
|
||||
"""
|
||||
|
||||
@@ -3,10 +3,12 @@ import logging
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.utils import clean_string
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class WebPageLoader(BaseLoader):
|
||||
def load_data(self, url):
|
||||
"""Load data from a web page."""
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from langchain.document_loaders import YoutubeLoader
|
||||
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
from embedchain.loaders.base_loader import BaseLoader
|
||||
from embedchain.utils import clean_string
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class YoutubeVideoLoader(BaseLoader):
|
||||
def load_data(self, url):
|
||||
"""Load data from a Youtube video."""
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
class BaseVectorDB:
|
||||
from embedchain.helper_classes.json_serializable import JSONSerializable
|
||||
|
||||
|
||||
class BaseVectorDB(JSONSerializable):
|
||||
"""Base class for vector database."""
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -14,9 +14,11 @@ except RuntimeError:
|
||||
|
||||
from chromadb.config import Settings
|
||||
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
from embedchain.vectordb.base_vector_db import BaseVectorDB
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class ChromaDB(BaseVectorDB):
|
||||
"""Vector database using ChromaDB."""
|
||||
|
||||
|
||||
@@ -9,10 +9,12 @@ except ImportError:
|
||||
) from None
|
||||
|
||||
from embedchain.config import ElasticsearchDBConfig
|
||||
from embedchain.helper_classes.json_serializable import register_deserializable
|
||||
from embedchain.models.VectorDimensions import VectorDimensions
|
||||
from embedchain.vectordb.base_vector_db import BaseVectorDB
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class ElasticsearchDB(BaseVectorDB):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
70
tests/helper_classes/test_json_serializable.py
Normal file
70
tests/helper_classes/test_json_serializable.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import random
|
||||
import unittest
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig
|
||||
from embedchain.helper_classes.json_serializable import (
|
||||
JSONSerializable, register_deserializable)
|
||||
|
||||
|
||||
class TestJsonSerializable(unittest.TestCase):
    """Test serialization and deserialization provided by `JSONSerializable`."""

    def test_base_function(self):
        """Test that the base premise of serialization and deserialization is working"""

        @register_deserializable
        class TestClass(JSONSerializable):
            def __init__(self):
                self.rng = random.random()

        original_class = TestClass()
        serial = original_class.serialize()

        # Negative test to show that a new class does not have the same random number.
        negative_test_class = TestClass()
        self.assertNotEqual(original_class.rng, negative_test_class.rng)

        # Test to show that a deserialized class has the same random number.
        positive_test_class: TestClass = TestClass().deserialize(serial)
        self.assertEqual(original_class.rng, positive_test_class.rng)
        self.assertIsInstance(positive_test_class, TestClass)

        # Test that it also works when called on the class itself (it is a classmethod).
        positive_test_class: TestClass = TestClass.deserialize(serial)
        self.assertEqual(original_class.rng, positive_test_class.rng)

        # TODO: There's no reason it shouldn't work, but serialization to and from file should be tested too.

    def test_registration_required(self):
        """Test that registration is required, and that without registration the default class is returned."""

        class SecondTestClass(JSONSerializable):
            def __init__(self):
                self.default = True

        app = SecondTestClass()
        # Make not default
        app.default = False
        # Serialize
        serial = app.serialize()
        # Deserialize. Due to the way errors are handled, it will not fail but return a default class.
        app: SecondTestClass = SecondTestClass().deserialize(serial)
        self.assertTrue(app.default)
        # If we register and try again with the same serial, it should work
        SecondTestClass.register_class_as_deserializable(SecondTestClass)
        app: SecondTestClass = SecondTestClass().deserialize(serial)
        self.assertFalse(app.default)

    def test_recursive(self):
        """Test recursiveness with the real app"""
        random_id = str(random.random())
        config = AppConfig(id=random_id)
        # config class is set under app.config.
        app = App(config=config)
        # w/o recursion the config would just be the repr of the config object (<... object at x>)
        s = app.serialize()
        new_app: App = App.deserialize(s)
        # The id of the new app is the same as the first one.
        self.assertEqual(random_id, new_app.config.id)
        # TODO: test deeper recursion
|
||||
Reference in New Issue
Block a user