Coverage for src/bwd/multi_bwd.py: 95%

87 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-12 11:56 +0000

1"""Multi-treatment Balancing Walk Design implementation.""" 

2 

3from collections.abc import Iterable 

4from typing import Any 

5 

6import numpy as np 

7 

8from .bwd import BWD 

9 

10 

11def _left(i): 

12 return 2 * i + 1 

13 

14 

15def _right(i): 

16 return 2 * (i + 1) 

17 

18 

19def _parent(i): 

20 return int(np.floor((i - 1) / 2)) 

21 

22 

23class MultiBWD(object): 

24 """**The Multi-treatment Balancing Walk Design with Restarts** 

25 

26 This method implements an extension to the Balancing Walk Design to balance 

27 across multiple treatments. It accomplishes this by constructing a binary tree. 

28 At each node in the binary tree, it balanced between the treatment groups on the 

29 left and the right. Thus it ensures balance between any pair of treatment groups. 

30 """ 

31 

32 def __init__( 

33 self, 

34 N: int, 

35 D: int, 

36 delta: float = 0.05, 

37 q: float | Iterable[float] = 0.5, 

38 intercept: bool = True, 

39 phi: float = 1.0, 

40 ): 

41 """ 

42 Parameters 

43 ---------- 

44 N : int 

45 Total number of points 

46 D : int 

47 Dimension of the data 

48 delta : float, optional 

49 Probability of failure, by default 0.05 

50 q : float | Iterable[float], optional 

51 Target marginal probability of treatment. Can be a single float for binary 

52 treatment or an iterable of probabilities for multiple treatments, by default 0.5 

53 intercept : bool, optional 

54 Whether an intercept term be added to covariate profiles, by default True 

55 phi : float, optional 

56 Robustness parameter. A value of 1 focuses entirely on balance, while a value 

57 approaching zero does pure randomization, by default 1.0 

58 """ 

59 self.N = N 

60 self.D = D 

61 self.delta = delta 

62 self.intercept = intercept 

63 self.phi = phi 

64 

65 if isinstance(q, float): 

66 q = q if q < 0.5 else 1 - q 

67 self.qs = [1 - q, q] 

68 self.classes = [0, 1] 

69 elif isinstance(q, Iterable): 

70 self.qs = [pr / sum(q) for pr in q] 

71 self.classes = [i for i, q in enumerate(self.qs)] 

72 num_groups = len(self.qs) 

73 self.K = num_groups - 1 

74 self.intercept = intercept 

75 

76 num_levels = int(np.ceil(np.log2(num_groups))) 

77 num_leaves = int(np.power(2, num_levels)) 

78 extra_leaves = num_leaves - num_groups 

79 num_nodes = int(np.power(2, num_levels + 1) - 1) 

80 

81 # Use dictionaries for type-stable storage 

82 # nodes: dict mapping index -> BWD object (for internal nodes) or int (for leaf nodes) 

83 # weights: dict mapping index -> float 

84 self.nodes: dict[int, BWD | int] = {} 

85 self.weights: dict[int, float] = {} 

86 

87 trt_by_leaf = [] 

88 num_leaves_by_trt = [] 

89 for trt in range(num_groups): 

90 if len(trt_by_leaf) % 2 == 0 and extra_leaves > 0: 

91 num_trt = 2 * (int(np.floor((extra_leaves - 1) / 2)) + 1) 

92 extra_leaves -= num_trt - 1 

93 else: 

94 num_trt = 1 

95 trt_by_leaf += [trt] * num_trt 

96 num_leaves_by_trt.append(num_trt) 

97 

98 # Initialize leaf nodes with treatment assignments 

99 for leaf, trt in enumerate(trt_by_leaf): 

100 node = num_nodes - num_leaves + leaf 

101 self.nodes[node] = trt 

102 self.weights[node] = 1 / self.qs[trt] / num_leaves_by_trt[trt] 

103 

104 # Build internal nodes from leaves up 

105 for cur_node in range(num_nodes)[::-1]: 

106 if cur_node == 0: 

107 break 

108 parent = _parent(cur_node) 

109 left = _left(parent) 

110 right = _right(parent) 

111 

112 # Skip if children haven't been initialized yet 

113 if left not in self.nodes or right not in self.nodes: 

114 continue 

115 

116 # If both children have the same treatment, propagate it up 

117 if self.nodes[left] == self.nodes[right]: 

118 self.nodes[parent] = self.nodes[left] 

119 self.weights[parent] = self.weights[left] + self.weights[right] 

120 # Otherwise, create a BWD balancer at this node 

121 else: 

122 left_weight = self.weights[left] 

123 right_weight = self.weights[right] 

124 pr_right = right_weight / (left_weight + right_weight) 

125 self.nodes[parent] = BWD( 

126 N=N, D=D, intercept=intercept, delta=delta, q=pr_right, phi=phi 

127 ) 

128 self.weights[parent] = left_weight + right_weight 

129 

130 def assign_next(self, x: np.ndarray) -> int: 

131 """Assign treatment to the next point 

132 

133 Parameters 

134 ---------- 

135 x : np.ndarray 

136 Covariate profile of unit to assign treatment 

137 

138 Returns 

139 ------- 

140 int 

141 Treatment assignment (treatment group index) 

142 """ 

143 cur_idx = 0 

144 while isinstance(self.nodes[cur_idx], BWD): 

145 assign = self.nodes[cur_idx].assign_next(x) 

146 cur_idx = _right(cur_idx) if assign > 0 else _left(cur_idx) 

147 # At this point, we've reached a leaf node which contains an int 

148 result = self.nodes[cur_idx] 

149 assert isinstance(result, int), "Leaf node must be an int" 

150 return result 

151 

152 def assign_all(self, X: np.ndarray) -> np.ndarray: 

153 """Assign all points 

154 

155 This assigns units to treatment in the offline setting in which all covariate 

156 profiles are available prior to assignment. The algorithm assigns as if units 

157 were still only observed in a stream. 

158 

159 Parameters 

160 ---------- 

161 X : np.ndarray 

162 Array of size n × d of covariate profiles 

163 

164 Returns 

165 ------- 

166 np.ndarray 

167 Array of treatment assignments 

168 """ 

169 return np.array([self.assign_next(X[i, :]) for i in range(X.shape[0])]) 

170 

171 @property 

172 def definition(self): 

173 """Get the definition parameters of the balancer 

174 

175 Returns 

176 ------- 

177 dict 

178 Dictionary containing N, D, delta, q, intercept, and phi 

179 """ 

180 return { 

181 "N": self.N, 

182 "D": self.D, 

183 "delta": self.delta, 

184 "q": self.qs, 

185 "intercept": self.intercept, 

186 "phi": self.phi, 

187 } 

188 

189 @property 

190 def state(self): 

191 """Get the current state of all BWD nodes in the tree 

192 

193 Returns 

194 ------- 

195 dict 

196 Dictionary mapping node indices to their states 

197 """ 

198 return { 

199 idx: node.state for idx, node in self.nodes.items() if isinstance(node, BWD) 

200 } 

201 

202 def update_state(self, **node_state_dict: Any) -> None: 

203 """Update the state of BWD nodes in the tree 

204 

205 Parameters 

206 ---------- 

207 **node_state_dict : dict 

208 Dictionary mapping node indices (as strings) to state dictionaries 

209 """ 

210 for node, state in node_state_dict.items(): 

211 node_obj = self.nodes[int(node)] 

212 if isinstance(node_obj, BWD): 

213 node_obj.update_state(**state) 

214 

215 def reset(self): 

216 """Reset all BWD nodes in the tree to initial state""" 

217 for node in self.nodes.values(): 

218 if isinstance(node, BWD): 

219 node.reset()