Coverage for src/bwd/serialization.py: 100%

25 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2024-08-19 16:45 +0000

1import numpy as np 

2import json 

3 

4from .bwd import BWD 

5from .bwd_random import BWDRandom 

6from .multi_bwd import MultiBWD 

7from .online import Online 

8 

# Registry used by `deserialize`: maps the class-name key found at the top
# level of a serialized payload to the class that should be instantiated.
name2class = dict(
    [
        ("BWD", BWD),
        ("BWDRandom", BWDRandom),
        ("MultiBWD", MultiBWD),
        ("Online", Online),
    ]
)

15 

16 

def _normalize_value(value):
    """Convert one value into a JSON-serializable equivalent (recursively)."""
    if isinstance(value, np.ndarray):
        # tolist() also turns the array's elements into built-in Python scalars.
        return value.tolist()
    if isinstance(value, np.generic):
        # NumPy scalars (np.int64, np.bool_, ...) are rejected by json.dumps;
        # .item() yields the matching built-in Python scalar.
        return value.item()
    if isinstance(value, dict):
        return normalize(value)
    if isinstance(value, (list, tuple)):
        # Containers may hold arrays/scalars too; JSON encodes tuples as arrays
        # anyway, so emitting a list keeps the JSON output identical.
        return [_normalize_value(item) for item in value]
    return value


def normalize(to_serialize):
    """Return a copy of a mapping with NumPy data converted for ``json.dumps``.

    NumPy arrays become (nested) Python lists, NumPy scalars become built-in
    Python scalars, and nested dicts, lists and tuples are processed
    recursively. All other values are passed through unchanged. The input
    mapping is not modified.

    Args:
        to_serialize: A mapping (anything with ``.items()``) to normalize.

    Returns:
        A new ``dict`` with the same keys and normalized values.
    """
    return {key: _normalize_value(value) for key, value in to_serialize.items()}

26 

27 

def serialize(obj):
    """Encode ``obj`` as a JSON string keyed by its class name.

    The resulting document has the shape
    ``{"<ClassName>": {"definition": ..., "state": ...}}``, where both
    sections are ``obj.definition`` and ``obj.state`` after passing through
    ``normalize`` (which turns NumPy arrays into lists).
    """
    class_name = str(type(obj).__name__)
    sections = {
        "definition": normalize(obj.definition),
        "state": normalize(obj.state),
    }
    document = {class_name: sections}
    return json.dumps(document)

37 

38 

def deserialize(str):
    """Rebuild an object from a JSON string produced by ``serialize``.

    Args:
        str: JSON text of the form
            ``{"<ClassName>": {"definition": {...}, "state": {...}}}``.
            The parameter name shadows the builtin ``str``; it is kept so
            existing keyword callers keep working.

    Returns:
        A new instance of the class registered under ``<ClassName>`` in
        ``name2class``. It is constructed with ``cls(**definition)`` and then
        updated through ``update_state(**state)``.

    Raises:
        IndexError: if the top-level JSON object is empty.
        KeyError: if the class name is not registered in ``name2class``, or
            the payload lacks a ``"definition"`` or ``"state"`` entry.
    """
    payload = json.loads(str)
    # Only the first top-level key is used; serialize() emits exactly one.
    cls_name = list(payload)[0]
    fields = payload[cls_name]

    try:
        cls = name2class[cls_name]
    except KeyError:
        # Still a KeyError (callers may catch it), but with an actionable message.
        raise KeyError(
            f"unknown serialized class {cls_name!r}; "
            f"expected one of {sorted(name2class)}"
        ) from None

    instance = cls(**fields["definition"])
    instance.update_state(**fields["state"])
    return instance