fix: serialize non serializable (#589)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from string import Template
|
||||
from typing import Any, Dict, Type, TypeVar, Union
|
||||
|
||||
T = TypeVar("T", bound="JSONSerializable")
|
||||
@@ -105,6 +106,16 @@ class JSONSerializable:
|
||||
serialized_value = value.serialize()
|
||||
# The value is stored as a serialized string.
|
||||
dct[key] = json.loads(serialized_value)
|
||||
# Custom rules (subclass is not json serializable by default)
|
||||
elif isinstance(value, Template):
|
||||
dct[key] = {"__type__": "Template", "data": value.template}
|
||||
# Future custom types we can follow a similar pattern
|
||||
# elif isinstance(value, SomeOtherType):
|
||||
# dct[key] = {
|
||||
# "__type__": "SomeOtherType",
|
||||
# "data": value.some_method()
|
||||
# }
|
||||
# NOTE: Keep in mind that this logic needs to be applied to the decoder too.
|
||||
else:
|
||||
json.dumps(value) # Try to serialize the value.
|
||||
except TypeError:
|
||||
@@ -135,6 +146,12 @@ class JSONSerializable:
|
||||
if target_class:
|
||||
obj = target_class.__new__(target_class)
|
||||
for key, value in dct.items():
|
||||
if isinstance(value, dict) and "__type__" in value:
|
||||
if value["__type__"] == "Template":
|
||||
value = Template(value["data"])
|
||||
# For future custom types we can follow a similar pattern
|
||||
# elif value["__type__"] == "SomeOtherType":
|
||||
# value = SomeOtherType.some_constructor(value["data"])
|
||||
default_value = getattr(target_class, key, None)
|
||||
setattr(obj, key, value or default_value)
|
||||
return obj
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import random
|
||||
import unittest
|
||||
from string import Template
|
||||
|
||||
from embedchain import App
|
||||
from embedchain.config import AppConfig
|
||||
from embedchain.config import AppConfig, BaseLlmConfig
|
||||
from embedchain.helper.json_serializable import (JSONSerializable,
|
||||
register_deserializable)
|
||||
|
||||
@@ -69,3 +70,11 @@ class TestJsonSerializable(unittest.TestCase):
|
||||
self.assertEqual(random_id, new_app.config.id)
|
||||
# We have proven that a nested class (app.config) can be serialized and deserialized just the same.
|
||||
# TODO: test deeper recursion
|
||||
|
||||
def test_special_subclasses(self):
|
||||
"""Test special subclasses that are not serializable by default."""
|
||||
# Template
|
||||
config = BaseLlmConfig(template=Template("My custom template with $query, $context and $history."))
|
||||
s = config.serialize()
|
||||
new_config: BaseLlmConfig = BaseLlmConfig.deserialize(s)
|
||||
self.assertEqual(config.template.template, new_config.template.template)
|
||||
|
||||
Reference in New Issue
Block a user