Coverage for src/bwd/bwd_random.py: 98%

50 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-12 11:56 +0000

1"""Balancing Walk Design with reversion to Bernoulli randomization.""" 

2 

3import numpy as np 

4 

5from .exceptions import SampleSizeExpendedError 

6 

# Names of the BWDRandom constructor arguments; the same names are used as the
# keys of the dictionary returned by ``BWDRandom.definition``.
SERIALIZED_ATTRIBUTES = [
    "N",
    "D",
    "delta",
    "q",
    "intercept",
    "phi",
]

8 

9 

class BWDRandom(object):
    """**The Balancing Walk Design with Reversion to Bernoulli Randomization**

    This is an algorithm from [Arbour et al (2022)](https://arxiv.org/abs/2203.02025).
    At each step, it adjusts randomization probabilities to ensure that imbalance tends towards zero. In
    particular, if current imbalance is $w$ and the current covariate profile is $x$, then the probability of
    treatment conditional on history will be:

    $$p_i = q \\left(1 - \\phi \\frac{x \\cdot w}{\\alpha}\\right)$$

    $q$ is the desired marginal probability, $\\phi$ is the parameter which controls robustness and
    $\\alpha$ is the normalizing constant which ensures the probability is well-formed.

    !!! important "If $|x \\cdot w| > \\alpha$"
        All future units will be assigned by complete randomization.
    """

    def __init__(
        self,
        N: int,
        D: int,
        delta: float = 0.05,
        q: float = 0.5,
        intercept: bool = True,
        phi: float = 1,
    ) -> None:
        """
        Parameters
        ----------
        N : int
            Total number of points
        D : int
            Dimension of the data
        delta : float, optional
            Probability of failure, by default 0.05
        q : float, optional
            Target marginal probability of treatment, by default 0.5
        intercept : bool, optional
            Whether an intercept term be added to covariate profiles, by default True
        phi : float, optional
            Robustness parameter. A value of 1 focuses entirely on balance, while a value
            approaching zero does pure randomization, by default 1
        """
        self.q = q
        self.intercept = intercept
        self.delta = delta
        self.N = N
        # Internally D counts the intercept column that assign_next prepends.
        self.D = D + int(self.intercept)

        # Increments applied to the imbalance vector for treated / control units.
        self.value_plus = 2 * (1 - self.q)
        self.value_minus = -2 * self.q
        self.phi = phi
        self.reset()

    def set_alpha(self, N: int) -> None:
        """Set normalizing constant for remaining N units

        For this design the constant is always set to -1. Because
        ``abs(dot) > -1`` holds for every unit, each later call to
        ``assign_next`` zeroes the imbalance and assigns treatment with
        probability ``q``.

        Parameters
        ----------
        N : int
            Number of units remaining in the sample

        Raises
        ------
        SampleSizeExpendedError
            If N is negative
        """
        if N < 0:
            raise SampleSizeExpendedError()
        self.alpha = -1

    def assign_next(self, x: np.ndarray) -> int:
        """Assign treatment to the next point

        Parameters
        ----------
        x : np.ndarray
            Covariate profile of unit to assign treatment

        Returns
        -------
        int
            Treatment assignment (0 or 1)

        Raises
        ------
        SampleSizeExpendedError
            Propagated from ``set_alpha`` when the imbalance bound is exceeded
            (or the design has already reverted) and more than ``N`` units
            have been assigned.
        """
        if self.intercept:
            x = np.concatenate(([1], x))
        dot = x @ self.w_i
        if abs(dot) > self.alpha:
            # Bound exceeded: drop the accumulated imbalance and revert to
            # Bernoulli randomization (see set_alpha).
            self.w_i = np.zeros((self.D,))
            self.set_alpha(self.N - self.iterations)
            dot = 0.0

        p_i = self.q * (1 - self.phi * dot / self.alpha)

        if np.random.rand() < p_i:
            value = self.value_plus
            assignment = 1
        else:
            value = self.value_minus
            assignment = 0
        self.w_i += value * x
        self.iterations += 1
        return assignment

    def assign_all(self, X: np.ndarray) -> np.ndarray:
        """Assign all points

        This assigns units to treatment in the offline setting in which all covariate
        profiles are available prior to assignment. The algorithm assigns as if units
        were still only observed in a stream.

        Parameters
        ----------
        X : np.ndarray
            Array of size n × d of covariate profiles

        Returns
        -------
        np.ndarray
            Array of treatment assignments
        """
        return np.array([self.assign_next(X[i, :]) for i in range(X.shape[0])])

    @property
    def definition(self) -> dict:
        """Get the definition parameters of the balancer

        The values are expressed as constructor arguments, so
        ``BWDRandom(**balancer.definition)`` rebuilds an equivalent balancer.

        Returns
        -------
        dict
            Dictionary containing N, D, delta, q, intercept, and phi
        """
        return {
            "N": self.N,
            # self.D includes the intercept column; report the data dimension
            # so the constructor does not add the intercept a second time.
            "D": self.D - int(self.intercept),
            "delta": self.delta,
            "q": self.q,
            "intercept": self.intercept,
            "phi": self.phi,
        }

    @property
    def state(self) -> dict:
        """Get the current state of the balancer

        Returns
        -------
        dict
            Dictionary containing w_i and iterations
        """
        return {"w_i": self.w_i, "iterations": self.iterations}

    def update_state(self, w_i, iterations) -> None:
        """Update the state of the balancer

        Parameters
        ----------
        w_i : array-like
            Current imbalance vector
        iterations : int
            Current iteration count
        """
        # Force a float array: assign_next accumulates float increments in
        # place, which fails on an integer array.
        self.w_i = np.array(w_i, dtype=float)
        self.iterations = iterations

    def reset(self) -> None:
        """Reset the balancer to initial state

        Resets the imbalance vector to zeros, initializes alpha, and sets iterations to 0.
        """
        self.w_i = np.zeros((self.D,))
        self.alpha = np.log(2 * self.N / self.delta) * min(1 / self.q, 9.32)
        self.iterations = 0

182 self.iterations = 0