Coverage for src/bwd/bwd.py: 98%

61 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-12 11:56 +0000

1"""Balancing Walk Design with Restarts implementation.""" 

2 

3import numpy as np 

4 

5from .exceptions import SampleSizeExpendedError 

6 

# Names of the ``BWD`` constructor arguments; the same names are the keys of
# the dict returned by ``BWD.definition``.
# NOTE(review): presumably consumed by serialization code elsewhere in the
# package -- confirm against callers.
SERIALIZED_ATTRIBUTES = ["N", "D", "delta", "q", "intercept", "phi"]

8 

9 

class BWD:
    """**The Balancing Walk Design with Restarts**

    This is the primary suggested algorithm from [Arbour et al (2022)](https://arxiv.org/abs/2203.02025).
    At each step, it adjusts randomization probabilities to ensure that imbalance tends towards zero. In
    particular, if the current imbalance is $w$ and the current covariate profile is $x$, then the
    probability of treatment conditional on history will be:

    $$p_i = q \\left(1 - \\phi \\frac{x \\cdot w}{\\alpha}\\right)$$

    $q$ is the desired marginal probability, $\\phi$ is the parameter which controls robustness and
    $\\alpha$ is the normalizing constant which ensures the probability is well-formed.

    !!! important "If $|x \\cdot w| > \\alpha$"
        A restart is performed by resetting the algorithm:

        - $w$ is reset to the zero vector
        - $\\alpha$ is reset to a constant based on the number of units remaining in the sample
    """

    q: float  # target marginal probability of treatment
    intercept: bool  # whether a leading 1 is prepended to every covariate profile
    delta: float  # failure probability used when computing alpha
    N: int  # total number of points
    D: int  # dimension of w_i (data dimension plus one when intercept is True)
    value_plus: float  # multiplier applied to x when the unit is treated
    value_minus: float  # multiplier applied to x when the unit is not treated
    phi: float  # robustness parameter
    alpha: float  # current normalizing constant
    w_i: np.ndarray  # current imbalance vector
    iterations: int  # number of units assigned so far

    def __init__(
        self,
        N: int,
        D: int,
        delta: float = 0.05,
        q: float = 0.5,
        intercept: bool = True,
        phi: float = 1,
    ) -> None:
        """
        Parameters
        ----------
        N : int
            Total number of points
        D : int
            Dimension of the data
        delta : float, optional
            Probability of failure, by default 0.05
        q : float, optional
            Target marginal probability of treatment, by default 0.5
        intercept : bool, optional
            Whether an intercept term should be added to covariate profiles, by default True
        phi : float, optional
            Robustness parameter. A value of 1 focuses entirely on balance, while a value
            approaching zero does pure randomization, by default 1
        """
        self.q = q
        self.intercept = intercept
        self.delta = delta
        self.N = N
        # The stored dimension includes the intercept column, if any.
        self.D = D + int(self.intercept)
        self.value_plus = 2 * (1 - self.q)
        self.value_minus = -2 * self.q
        self.phi = phi
        self.reset()

    def set_alpha(self, N: int) -> None:
        """Set normalizing constant for remaining N units

        Parameters
        ----------
        N : int
            Number of units remaining in the sample

        Raises
        ------
        SampleSizeExpendedError
            If N is negative
        """
        if N < 0:
            raise SampleSizeExpendedError()
        # NOTE(review): N == 0 passes the guard above, so the logarithm below is
        # taken of zero; confirm whether the guard was meant to be ``N <= 0``.
        self.alpha = np.log(2 * N / self.delta) * min(1 / self.q, 9.32)

    def assign_next(self, x: np.ndarray) -> int:
        """Assign treatment to the next point

        Parameters
        ----------
        x : np.ndarray
            Covariate profile of unit to assign treatment

        Returns
        -------
        int
            Treatment assignment (0 or 1)

        Raises
        ------
        SampleSizeExpendedError
            If a restart is triggered after more than N units have been assigned
        """
        if self.intercept:
            x = np.concatenate(([1], x))
        dot = x @ self.w_i
        if abs(dot) > self.alpha:
            # Restart: clear the imbalance and recompute alpha for the units
            # that remain in the sample.
            self.w_i = np.zeros((self.D,))
            self.set_alpha(self.N - self.iterations)
            dot = x @ self.w_i

        p_i = self.q * (1 - self.phi * dot / self.alpha)

        treated = np.random.rand() < p_i
        # In-place update; ``state`` hands out copies so that snapshots taken
        # earlier are not affected by this mutation.
        self.w_i += (self.value_plus if treated else self.value_minus) * x
        self.iterations += 1
        return 1 if treated else 0

    def assign_all(self, X: np.ndarray) -> np.ndarray:
        """Assign all points

        This assigns units to treatment in the offline setting in which all covariate
        profiles are available prior to assignment. The algorithm assigns as if units
        were still only observed in a stream.

        Parameters
        ----------
        X : np.ndarray
            Array of size n × d of covariate profiles

        Returns
        -------
        np.ndarray
            Array of treatment assignments
        """
        return np.array([self.assign_next(row) for row in X])

    @property
    def definition(self) -> dict:
        """Get the definition parameters of the balancer

        Returns
        -------
        dict
            Dictionary containing N, D, delta, q, intercept, and phi
        """
        # NOTE(review): "D" is the stored dimension, which already includes the
        # intercept column when ``intercept`` is True. Passing this dict straight
        # back to the constructor would add the intercept a second time --
        # confirm how callers use it.
        return {
            "N": self.N,
            "D": self.D,
            "delta": self.delta,
            "q": self.q,
            "intercept": self.intercept,
            "phi": self.phi,
        }

    @property
    def state(self) -> dict:
        """Get the current state of the balancer

        Returns
        -------
        dict
            Dictionary containing w_i and iterations. ``w_i`` is a copy, so later
            assignments do not alter a previously retrieved state.
        """
        return {"w_i": self.w_i.copy(), "iterations": self.iterations}

    def update_state(self, w_i, iterations) -> None:
        """Update the state of the balancer

        Parameters
        ----------
        w_i : array-like
            Current imbalance vector
        iterations : int
            Current iteration count
        """
        # Coerce to float so that integer input (e.g. ``[0, 0, 0]``) does not
        # make the in-place float update in ``assign_next`` fail.
        self.w_i = np.array(w_i, dtype=float)
        self.iterations = iterations

    def reset(self) -> None:
        """Reset the balancer to initial state

        Resets the imbalance vector to zeros, reinitializes alpha, and sets iterations to 0.
        """
        self.w_i = np.zeros((self.D,))
        self.set_alpha(self.N)
        self.iterations = 0

195 self.set_alpha(self.N) 

196 self.iterations = 0