Coverage for src/bwd/serialization.py: 100%
25 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-12 11:56 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-12 11:56 +0000
1"""Serialization utilities for BWD objects."""
3import json
5import numpy as np
7from .bwd import BWD
8from .bwd_random import BWDRandom
9from .multi_bwd import MultiBWD
10from .online import Online
# Registry mapping the class name written by ``serialize`` (the JSON
# top-level key) to the class that ``deserialize`` instantiates.
# Only the classes listed here can be restored from JSON.
name2class = dict(
    BWD=BWD,
    BWDRandom=BWDRandom,
    MultiBWD=MultiBWD,
    Online=Online,
)
def normalize(to_serialize):
    """Normalize a dictionary so that ``json.dumps`` can encode it.

    Recursively walks the values of ``to_serialize`` and replaces NumPy
    objects, which the standard :mod:`json` encoder rejects, with built-in
    Python equivalents:

    * ``np.ndarray`` -> nested ``list`` (object-dtype arrays are normalized
      element by element as well)
    * NumPy scalars (``np.int64``, ``np.float32``, ``np.bool_``, ...) -> the
      corresponding Python scalar
    * ``dict`` -> a new, recursively normalized ``dict``
    * ``list`` / ``tuple`` -> a new ``list`` / ``tuple`` with normalized items

    Every other value is passed through unchanged.

    Parameters
    ----------
    to_serialize : dict
        Dictionary containing data to normalize. It is not modified.

    Returns
    -------
    dict
        New dictionary with the same keys and JSON-compatible values.
    """
    return {k: _normalize_value(v) for k, v in to_serialize.items()}


def _normalize_value(value):
    """Convert a single value to a JSON-compatible built-in type."""
    if isinstance(value, np.ndarray):
        as_list = value.tolist()
        # tolist() already yields Python scalars for numeric dtypes; only
        # object arrays can still hold NumPy objects or dicts inside.
        if value.dtype == object:
            return _normalize_value(as_list)
        return as_list
    if isinstance(value, np.generic):
        # NumPy scalars such as np.int64 / np.bool_ are not JSON-encodable.
        return value.item()
    if isinstance(value, dict):
        return normalize(value)
    if isinstance(value, list):
        return [_normalize_value(item) for item in value]
    if isinstance(value, tuple):
        return tuple(_normalize_value(item) for item in value)
    return value
def serialize(obj):
    """Encode a balancer object as a JSON string.

    The result is a JSON object with a single key, the class name of
    ``obj``. Its value is an object with two entries, ``"definition"`` and
    ``"state"``, holding ``obj.definition`` and ``obj.state`` after they
    have been passed through ``normalize``.

    Parameters
    ----------
    obj : BWD, BWDRandom, MultiBWD, or Online
        The balancer object to serialize

    Returns
    -------
    str
        JSON text describing the object's definition and state
    """
    payload = {
        "definition": normalize(obj.definition),
        "state": normalize(obj.state),
    }
    return json.dumps({str(type(obj).__name__): payload})
def deserialize(json_str):
    """Rebuild a balancer object from its JSON representation.

    The first top-level key of the decoded JSON object is looked up in
    ``name2class`` to choose the class. The class is constructed from the
    ``"definition"`` mapping, then ``update_state`` is called with the
    ``"state"`` mapping.

    Parameters
    ----------
    json_str : str
        JSON string containing the serialized balancer

    Returns
    -------
    BWD, BWDRandom, MultiBWD, or Online
        The reconstructed balancer with its state restored

    Raises
    ------
    IndexError
        If the decoded JSON object has no keys.
    KeyError
        If the class name is not registered in ``name2class`` or the
        ``"definition"`` / ``"state"`` entries are missing.
    """
    payload = json.loads(json_str)
    # NOTE: only the first top-level key is used; any others are ignored.
    cls_name = tuple(payload.keys())[0]
    body = payload[cls_name]
    cls = name2class[cls_name]
    # NOTE(review): arrays arrive here as plain lists after the JSON round
    # trip; presumably the constructor / update_state convert them back —
    # confirm.
    balancer = cls(**body["definition"])
    balancer.update_state(**body["state"])
    return balancer