Coverage for src/bwd/serialization.py: 100%

25 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-12 11:56 +0000

1"""Serialization utilities for BWD objects.""" 

2 

3import json 

4 

5import numpy as np 

6 

7from .bwd import BWD 

8from .bwd_random import BWDRandom 

9from .multi_bwd import MultiBWD 

10from .online import Online 

11 

# Registry used by ``deserialize``: maps the class name that ``serialize``
# writes as the top-level JSON key (``type(obj).__name__``) to the balancer
# class to instantiate. A key must equal its class's ``__name__`` for a
# serialize/deserialize round trip to find the right class.
name2class = dict(
    BWD=BWD,
    BWDRandom=BWDRandom,
    MultiBWD=MultiBWD,
    Online=Online,
)

18 

19 

def _normalize_value(value):
    """Return a JSON-compatible version of a single value.

    - numpy arrays become (nested) Python lists via ``ndarray.tolist()``;
    - numpy scalars (e.g. ``np.int64``, ``np.bool_``) become the matching
      Python scalar via ``.item()``; the ``json`` module cannot encode
      ``np.int64`` / ``np.bool_`` directly;
    - dicts are normalized recursively with :func:`normalize`;
    - lists and tuples are normalized element-wise and returned as lists
      (JSON has no tuple type, so ``json.dumps`` would emit an array anyway);
    - anything else is returned unchanged.
    """
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, dict):
        return normalize(value)
    if isinstance(value, (list, tuple)):
        # Containers may hold arrays or numpy scalars too (e.g. a list of
        # per-group arrays), so recurse instead of passing them through.
        return [_normalize_value(item) for item in value]
    return value


def normalize(to_serialize):
    """Normalize data structures for JSON serialization

    Recursively converts numpy arrays to lists, numpy scalars to Python
    scalars, and normalizes nested dictionaries, lists and tuples so that
    all values are JSON-serializable.

    Parameters
    ----------
    to_serialize : dict
        Dictionary containing data to normalize. Keys are copied unchanged.

    Returns
    -------
    dict
        A new dictionary with JSON-compatible values; the input is not
        modified.
    """
    return {key: _normalize_value(value) for key, value in to_serialize.items()}

44 

45 

def serialize(obj):
    """Serialize a balancer object to a JSON string.

    The result is a JSON object with a single key, the class name of
    ``obj``, whose value holds the normalized ``definition`` and ``state``
    of ``obj``::

        {"<ClassName>": {"definition": {...}, "state": {...}}}

    Parameters
    ----------
    obj : BWD, BWDRandom, MultiBWD, or Online
        The balancer to serialize. Its ``definition`` and ``state``
        attributes are passed through ``normalize`` and so must be
        mappings (``normalize`` iterates them with ``.items()``).

    Returns
    -------
    str
        JSON text intended to be read back with ``deserialize``.
    """
    class_name = type(obj).__name__
    payload = {
        "definition": normalize(obj.definition),
        "state": normalize(obj.state),
    }
    return json.dumps({class_name: payload})

70 

71 

def deserialize(json_str):
    """Deserialize a balancer object from JSON string

    Reconstructs a balancer object from its serialized JSON representation
    (the single-key ``{"<ClassName>": {"definition": ..., "state": ...}}``
    layout written by ``serialize``), restoring both its definition and
    state.

    Parameters
    ----------
    json_str : str
        JSON string containing the serialized balancer

    Returns
    -------
    BWD, BWDRandom, MultiBWD, or Online
        The deserialized balancer object with restored state

    Raises
    ------
    ValueError
        If ``json_str`` is not valid JSON, or its top level is not an
        object with exactly one key.
    KeyError
        If the class name is not registered in ``name2class``, or the
        payload lacks a ``"definition"`` or ``"state"`` entry.
    """
    payload = json.loads(json_str)
    # serialize() always writes exactly one top-level key; anything else is
    # malformed input and should not be silently truncated to its first key.
    if not isinstance(payload, dict) or len(payload) != 1:
        raise ValueError(
            "Expected a JSON object with exactly one key (the balancer "
            "class name), as produced by serialize()"
        )
    cls_name, defs = next(iter(payload.items()))

    try:
        cls = name2class[cls_name]
    except KeyError:
        # Keep KeyError (what a plain lookup raised before) but say what
        # went wrong and which names are accepted.
        raise KeyError(
            f"Unknown balancer class {cls_name!r}; "
            f"expected one of {sorted(name2class)}"
        ) from None

    bal_object = cls(**defs["definition"])
    bal_object.update_state(**defs["state"])
    return bal_object