Coverage for src/bwd/bwd.py: 98%

50 statements  

coverage.py v6.5.0, created at 2024-08-19 16:45 +0000

import numpy as np
from .exceptions import SampleSizeExpendedError


SERIALIZED_ATTRIBUTES = ["N", "D", "delta", "q", "intercept", "phi"]


class BWD(object):
    """**The Balancing Walk Design with Restarts**

    This is the primary suggested algorithm from [Arbour et al (2022)](https://arxiv.org/abs/2203.02025).
    At each step, it adjusts the randomization probabilities so that imbalance tends towards zero. In
    particular, if the current imbalance is $w$ and the current covariate profile is $x$, then the
    probability of treatment conditional on history will be:

    $$p_i = q \\left(1 - \\phi \\frac{x \\cdot w}{\\alpha}\\right)$$

    $q$ is the desired marginal probability of treatment, $\\phi$ is the parameter which controls
    robustness, and $\\alpha$ is the normalizing constant which ensures the probability is well-formed.

    !!! important "If $|x \\cdot w| > \\alpha$"
        A restart is performed by resetting the algorithm:

        - $w$ is reset to the zero vector
        - $\\alpha$ is reset to a constant based on the number of units remaining in the sample
    """
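    # Worked example of the assignment rule above (illustrative numbers): with
    # q = 0.5, phi = 1 and a normalized imbalance x . w / alpha = 0.4, the
    # treatment probability is p_i = 0.5 * (1 - 1 * 0.4) = 0.3, so the next unit
    # is steered away from the direction in which imbalance has already accumulated.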

    def __init__(
        self,
        N: int,
        D: int,
        delta: float = 0.05,
        q: float = 0.5,
        intercept: bool = True,
        phi: float = 1,
    ) -> None:
        """
        Args:
            N: total number of points
            D: dimension of the data
            delta: probability of failure
            q: Target marginal probability of treatment
            intercept: Whether an intercept term should be added to covariate profiles
            phi: Robustness parameter. A value of 1 focuses entirely on balance, while a value
                approaching zero yields pure randomization.
        """
        self.q = q
        self.intercept = intercept
        self.delta = delta
        self.N = N
        self.D = D + int(self.intercept)
        # Per-assignment imbalance increments for treatment (+) and control (-);
        # their expectation is zero when treatment is assigned with probability q.
        self.value_plus = 2 * (1 - self.q)
        self.value_minus = -2 * self.q
        self.phi = phi
        self.reset()

    def set_alpha(self, N: int) -> None:
        """Set the normalizing constant for the remaining N units

        Args:
            N: Number of units remaining in the sample
        """
        if N < 0:
            raise SampleSizeExpendedError()
        # The normalizing constant grows logarithmically in the remaining sample
        # size and in 1 / delta, scaled by 1 / q (capped at a fixed constant).
        self.alpha = np.log(2 * N / self.delta) * min(1 / self.q, 9.32)
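    # Illustrative magnitude: with N = 100, delta = 0.05 and q = 0.5, the constant is
    # log(2 * 100 / 0.05) * min(1 / 0.5, 9.32) = log(4000) * 2 ~= 16.6, so a restart
    # is triggered once |x . w| exceeds roughly 16.6.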

    def assign_next(self, x: np.ndarray) -> int:
        """Assign treatment to the next point

        Args:
            x: covariate profile of the unit to assign treatment

        Returns:
            1 if the unit is assigned to treatment, 0 otherwise
        """
        if self.intercept:
            x = np.concatenate(([1], x))
        dot = x @ self.w_i
        if abs(dot) > self.alpha:
            # Restart: reset the imbalance vector and recompute alpha for the
            # units remaining in the stream.
            self.w_i = np.zeros((self.D,))
            self.set_alpha(self.N - self.iterations)
            dot = x @ self.w_i

        p_i = self.q * (1 - self.phi * dot / self.alpha)

        if np.random.rand() < p_i:
            value = self.value_plus
            assignment = 1
        else:
            value = self.value_minus
            assignment = -1
        self.w_i += value * x
        self.iterations += 1
        return int((assignment + 1) / 2)

    def assign_all(self, X: np.ndarray) -> np.ndarray:
        """Assign all points

        This assigns units to treatment in the offline setting in which all covariate
        profiles are available prior to assignment. The algorithm assigns as if units
        were still only observed in a stream.

        Args:
            X: array of size n × d of covariate profiles
        """
        return np.array([self.assign_next(X[i, :]) for i in range(X.shape[0])])

    @property
    def definition(self):
        """Serializable attributes of the design (the keys in SERIALIZED_ATTRIBUTES)."""
        return {
            "N": self.N,
            "D": self.D,
            "delta": self.delta,
            "q": self.q,
            "intercept": self.intercept,
            "phi": self.phi,
        }

    @property
    def state(self):
        """Mutable state of the assignment stream: imbalance vector and iteration count."""
        return {"w_i": self.w_i, "iterations": self.iterations}

    def update_state(self, w_i, iterations):
        """Restore the mutable state, e.g. when resuming a serialized design."""
        self.w_i = np.array(w_i)
        self.iterations = iterations

    def reset(self):
        """Clear the imbalance vector, renormalize alpha, and restart the iteration count."""
        self.w_i = np.zeros((self.D,))
        self.set_alpha(self.N)
        self.iterations = 0
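
A minimal usage sketch (an illustration, not part of the covered module): it assumes the class above is importable as `bwd.bwd.BWD`, streams a small random covariate matrix through the design, and then restores the mutable state on a fresh instance via `state` and `update_state`. Shapes and parameter values are arbitrary.

import numpy as np
from bwd.bwd import BWD  # assumed import path; adjust to the actual package layout

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))       # 100 units, 3 covariates (intercept added internally)

design = BWD(N=100, D=3, q=0.5, phi=1)
assignments = design.assign_all(X)  # length-100 array of 0/1 treatment indicators
print(assignments.mean())           # roughly the target marginal probability q

config = design.definition          # serializable dict of the attributes in SERIALIZED_ATTRIBUTES
snapshot = design.state             # {"w_i": ..., "iterations": ...}
resumed = BWD(N=100, D=3, q=0.5, phi=1)
resumed.update_state(**snapshot)    # resumed carries the same imbalance vector and iteration count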