fix: serialize non serializable (#589)

This commit is contained in:
cachho
2023-09-12 05:55:46 +02:00
committed by GitHub
parent 7c39d9f0c1
commit 1864f4cb38
2 changed files with 27 additions and 1 deletions

View File

@@ -1,5 +1,6 @@
import json
import logging
from string import Template
from typing import Any, Dict, Type, TypeVar, Union
T = TypeVar("T", bound="JSONSerializable")
@@ -105,6 +106,16 @@ class JSONSerializable:
serialized_value = value.serialize()
# The value is stored as a serialized string.
dct[key] = json.loads(serialized_value)
# Custom rules (subclass is not json serializable by default)
elif isinstance(value, Template):
dct[key] = {"__type__": "Template", "data": value.template}
# Future custom types we can follow a similar pattern
# elif isinstance(value, SomeOtherType):
# dct[key] = {
# "__type__": "SomeOtherType",
# "data": value.some_method()
# }
# NOTE: Keep in mind that this logic needs to be applied to the decoder too.
else:
json.dumps(value) # Try to serialize the value.
except TypeError:
@@ -135,6 +146,12 @@ class JSONSerializable:
if target_class:
obj = target_class.__new__(target_class)
for key, value in dct.items():
if isinstance(value, dict) and "__type__" in value:
if value["__type__"] == "Template":
value = Template(value["data"])
# For future custom types we can follow a similar pattern
# elif value["__type__"] == "SomeOtherType":
# value = SomeOtherType.some_constructor(value["data"])
default_value = getattr(target_class, key, None)
setattr(obj, key, value or default_value)
return obj

View File

@@ -1,8 +1,9 @@
import random
import unittest
from string import Template
from embedchain import App
from embedchain.config import AppConfig
from embedchain.config import AppConfig, BaseLlmConfig
from embedchain.helper.json_serializable import (JSONSerializable,
register_deserializable)
@@ -69,3 +70,11 @@ class TestJsonSerializable(unittest.TestCase):
self.assertEqual(random_id, new_app.config.id)
# We have proven that a nested class (app.config) can be serialized and deserialized just the same.
# TODO: test deeper recursion
def test_special_subclasses(self):
"""Test special subclasses that are not serializable by default."""
# Template
config = BaseLlmConfig(template=Template("My custom template with $query, $context and $history."))
s = config.serialize()
new_config: BaseLlmConfig = BaseLlmConfig.deserialize(s)
self.assertEqual(config.template.template, new_config.template.template)