Coverage for src/bwd/serialization.py: 100%
25 statements
coverage.py v6.5.0, created at 2024-08-19 16:45 +0000
1import numpy as np
2import json
4from .bwd import BWD
5from .bwd_random import BWDRandom
6from .multi_bwd import MultiBWD
7from .online import Online
# Registry used by ``deserialize``: maps the class name that ``serialize``
# writes as the top-level JSON key (``type(obj).__name__``) back to the class
# to instantiate.
name2class = dict(
    BWD=BWD,
    BWDRandom=BWDRandom,
    MultiBWD=MultiBWD,
    Online=Online,
)
def _normalize_value(value):
    """Convert a single value into a JSON-friendly equivalent (recursively)."""
    if isinstance(value, np.ndarray):
        # tolist() already yields nested lists of built-in Python scalars.
        return value.tolist()
    if isinstance(value, np.generic):
        # NumPy scalars such as np.int64, np.float32 and np.bool_ are not
        # accepted by json.dumps; item() returns the matching Python scalar.
        return value.item()
    if isinstance(value, dict):
        return normalize(value)
    if isinstance(value, list):
        return [_normalize_value(item) for item in value]
    if isinstance(value, tuple):
        # Keep the tuple type; json.dumps emits tuples as arrays anyway.
        return tuple(_normalize_value(item) for item in value)
    return value


def normalize(to_serialize):
    """Return a copy of ``to_serialize`` that ``json.dumps`` can encode.

    NumPy arrays become (nested) lists and NumPy scalars become built-in
    Python scalars. Nested dicts, lists and tuples are processed
    recursively. All other values are passed through unchanged. Keys are not
    modified, and the input mapping is not mutated.

    Args:
        to_serialize: A mapping (anything with ``.items()``) to normalize.

    Returns:
        A new ``dict`` with the same keys and normalized values.
    """
    return {key: _normalize_value(value) for key, value in to_serialize.items()}
def serialize(obj):
    """Encode ``obj`` as a JSON string.

    The output has the shape
    ``{"<ClassName>": {"definition": {...}, "state": {...}}}``, where the
    class name is ``type(obj).__name__`` and both inner mappings are
    ``obj.definition`` and ``obj.state`` passed through ``normalize``.

    Args:
        obj: An object exposing mapping attributes ``definition`` and
            ``state``.

    Returns:
        The JSON text.
    """
    class_name = str(type(obj).__name__)
    body = {
        "definition": normalize(obj.definition),
        "state": normalize(obj.state),
    }
    envelope = {class_name: body}
    return json.dumps(envelope)
def deserialize(str, *, registry=None):
    """Rebuild an object from JSON text produced by ``serialize``.

    The payload must be a JSON object with exactly one key, the class name,
    whose value is ``{"definition": {...}, "state": {...}}``. The class is
    instantiated with ``definition`` as keyword arguments, and
    ``update_state`` is then called with ``state`` as keyword arguments.

    Args:
        str: The JSON text. (The name shadows the builtin but is kept for
            backward compatibility with keyword callers.)
        registry: Optional mapping from class name to class. Defaults to the
            module-level ``name2class`` registry.

    Returns:
        The reconstructed instance.

    Raises:
        json.JSONDecodeError: If the text is not valid JSON.
        ValueError: If the payload is not a JSON object with exactly one key.
        KeyError: If the class name is not present in the registry.
    """
    classes = name2class if registry is None else registry
    payload = json.loads(str)
    # serialize() always writes exactly one top-level key; reject anything
    # else instead of failing with an IndexError or silently dropping keys.
    if not isinstance(payload, dict) or len(payload) != 1:
        raise ValueError(
            "serialized payload must be a JSON object with exactly one key "
            "(the class name)"
        )
    [(cls_name, spec)] = payload.items()
    try:
        cls = classes[cls_name]
    except KeyError:
        raise KeyError(
            f"unknown class {cls_name!r}; known classes: {sorted(classes)}"
        ) from None
    instance = cls(**spec["definition"])
    instance.update_state(**spec["state"])
    return instance