fix: serialize non serializable (#589)

This commit is contained in:
cachho
2023-09-12 05:55:46 +02:00
committed by GitHub
parent 7c39d9f0c1
commit 1864f4cb38
2 changed files with 27 additions and 1 deletions

View File

@@ -1,5 +1,6 @@
import json import json
import logging import logging
from string import Template
from typing import Any, Dict, Type, TypeVar, Union from typing import Any, Dict, Type, TypeVar, Union
T = TypeVar("T", bound="JSONSerializable") T = TypeVar("T", bound="JSONSerializable")
@@ -105,6 +106,16 @@ class JSONSerializable:
serialized_value = value.serialize() serialized_value = value.serialize()
# The value is stored as a serialized string. # The value is stored as a serialized string.
dct[key] = json.loads(serialized_value) dct[key] = json.loads(serialized_value)
# Custom rules (subclass is not json serializable by default)
elif isinstance(value, Template):
dct[key] = {"__type__": "Template", "data": value.template}
# Future custom types we can follow a similar pattern
# elif isinstance(value, SomeOtherType):
# dct[key] = {
# "__type__": "SomeOtherType",
# "data": value.some_method()
# }
# NOTE: Keep in mind that this logic needs to be applied to the decoder too.
else: else:
json.dumps(value) # Try to serialize the value. json.dumps(value) # Try to serialize the value.
except TypeError: except TypeError:
@@ -135,6 +146,12 @@ class JSONSerializable:
if target_class: if target_class:
obj = target_class.__new__(target_class) obj = target_class.__new__(target_class)
for key, value in dct.items(): for key, value in dct.items():
if isinstance(value, dict) and "__type__" in value:
if value["__type__"] == "Template":
value = Template(value["data"])
# For future custom types we can follow a similar pattern
# elif value["__type__"] == "SomeOtherType":
# value = SomeOtherType.some_constructor(value["data"])
default_value = getattr(target_class, key, None) default_value = getattr(target_class, key, None)
setattr(obj, key, value or default_value) setattr(obj, key, value or default_value)
return obj return obj

View File

@@ -1,8 +1,9 @@
import random import random
import unittest import unittest
from string import Template
from embedchain import App from embedchain import App
from embedchain.config import AppConfig from embedchain.config import AppConfig, BaseLlmConfig
from embedchain.helper.json_serializable import (JSONSerializable, from embedchain.helper.json_serializable import (JSONSerializable,
register_deserializable) register_deserializable)
@@ -69,3 +70,11 @@ class TestJsonSerializable(unittest.TestCase):
self.assertEqual(random_id, new_app.config.id) self.assertEqual(random_id, new_app.config.id)
# We have proven that a nested class (app.config) can be serialized and deserialized just the same. # We have proven that a nested class (app.config) can be serialized and deserialized just the same.
# TODO: test deeper recursion # TODO: test deeper recursion
def test_special_subclasses(self):
"""Test special subclasses that are not serializable by default."""
# Template
config = BaseLlmConfig(template=Template("My custom template with $query, $context and $history."))
s = config.serialize()
new_config: BaseLlmConfig = BaseLlmConfig.deserialize(s)
self.assertEqual(config.template.template, new_config.template.template)