diff --git a/embedchain/helper/json_serializable.py b/embedchain/helper/json_serializable.py index 9537e9f9..b5fb7c41 100644 --- a/embedchain/helper/json_serializable.py +++ b/embedchain/helper/json_serializable.py @@ -1,5 +1,6 @@ import json import logging +from string import Template from typing import Any, Dict, Type, TypeVar, Union T = TypeVar("T", bound="JSONSerializable") @@ -105,6 +106,16 @@ class JSONSerializable: serialized_value = value.serialize() # The value is stored as a serialized string. dct[key] = json.loads(serialized_value) + # Custom rules (subclass is not json serializable by default) + elif isinstance(value, Template): + dct[key] = {"__type__": "Template", "data": value.template} + # Future custom types we can follow a similar pattern + # elif isinstance(value, SomeOtherType): + # dct[key] = { + # "__type__": "SomeOtherType", + # "data": value.some_method() + # } + # NOTE: Keep in mind that this logic needs to be applied to the decoder too. else: json.dumps(value) # Try to serialize the value. except TypeError: @@ -135,6 +146,12 @@ class JSONSerializable: if target_class: obj = target_class.__new__(target_class) for key, value in dct.items(): + if isinstance(value, dict) and "__type__" in value: + if value["__type__"] == "Template": + value = Template(value["data"]) + # For future custom types we can follow a similar pattern + # elif value["__type__"] == "SomeOtherType": + # value = SomeOtherType.some_constructor(value["data"]) default_value = getattr(target_class, key, None) setattr(obj, key, value or default_value) return obj diff --git a/tests/helper_classes/test_json_serializable.py b/tests/helper_classes/test_json_serializable.py index c99bfd3e..5153d122 100644 --- a/tests/helper_classes/test_json_serializable.py +++ b/tests/helper_classes/test_json_serializable.py @@ -1,8 +1,9 @@ import random import unittest +from string import Template from embedchain import App -from embedchain.config import AppConfig +from embedchain.config import AppConfig, BaseLlmConfig from embedchain.helper.json_serializable import (JSONSerializable, register_deserializable) @@ -69,3 +70,11 @@ class TestJsonSerializable(unittest.TestCase): self.assertEqual(random_id, new_app.config.id) # We have proven that a nested class (app.config) can be serialized and deserialized just the same. # TODO: test deeper recursion + + def test_special_subclasses(self): + """Test special subclasses that are not serializable by default.""" + # Template + config = BaseLlmConfig(template=Template("My custom template with $query, $context and $history.")) + s = config.serialize() + new_config: BaseLlmConfig = BaseLlmConfig.deserialize(s) + self.assertEqual(config.template.template, new_config.template.template)