Coverage for src/bwd/bwd_random.py: 98%

50 statements  


import numpy as np
from .exceptions import SampleSizeExpendedError


SERIALIZED_ATTRIBUTES = ["N", "D", "delta", "q", "intercept", "phi"]


class BWDRandom(object):
9 """**The Balancing Walk Design with Reversion to Bernoulli Randomization** 

10 

11 This is an algorithm from [Arbour et al (2022)](https://arxiv.org/abs/2203.02025). 

12 At each step, it adjusts randomization probabilities to ensure that imbalance tends towards zero. In 

13 particular, if current imbalance is w and the current covariate profile is $x$, then the probability of 

14 treatment conditional on history will be: 

15 

16 $$p_i = q \\left(1 - \\phi \\frac{x \\cdot w}{\\alpha}\\right)$$ 

17 

18 $q$ is the desired marginal probability, $\\phi$ is the parameter which controls robustness and 

19 $\\alpha$ is the normalizing constant which ensures the probability is well-formed. 

20 

21 !!! important "If $|x \\cdot w| > \\alpha$" 

22 All future units will be assigned by complete randomization. 

23 """ 

24 

    def __init__(
        self,
        N: int,
        D: int,
        delta: float = 0.05,
        q: float = 0.5,
        intercept: bool = True,
        phi: float = 1,
    ) -> None:
        """Initialize the object

        Arguments:
            N: total number of points
            D: dimension of the data
            delta: probability of failure
            q: target marginal probability of treatment
            intercept: whether an intercept term should be added to covariate profiles
            phi: robustness parameter. A value of 1 focuses entirely on balance, while a value
                approaching zero approaches pure randomization.
        """
        self.q = q
        self.intercept = intercept
        self.delta = delta
        self.N = N
        self.D = D + int(self.intercept)

        self.value_plus = 2 * (1 - self.q)
        self.value_minus = -2 * self.q
        self.phi = phi
        self.reset()

    def set_alpha(self, N: int) -> None:
        """Set normalizing constant for remaining N units

        Args:
            N: Number of units remaining in the sample
        """
        if N < 0:
            raise SampleSizeExpendedError()
        # A sentinel value of -1 ensures |x . w| > alpha for every remaining unit,
        # so each later call to `assign_next` resets the imbalance and assigns
        # treatment with the marginal probability q (reversion to Bernoulli).
        self.alpha = -1

    def assign_next(self, x: np.ndarray) -> int:
        """Assign treatment to the next point

        Args:
            x: covariate profile of unit to assign treatment
        """
        if self.intercept:
            x = np.concatenate(([1], x))
        dot = x @ self.w_i
        if abs(dot) > self.alpha:
            # Imbalance exceeds the normalizing constant: clear it and revert to
            # Bernoulli randomization for the remaining units.
            self.w_i = np.zeros((self.D,))
            self.set_alpha(self.N - self.iterations)
            dot = 0.0

        p_i = self.q * (1 - self.phi * dot / self.alpha)

        if np.random.rand() < p_i:
            value = self.value_plus
            assignment = 1
        else:
            value = self.value_minus
            assignment = -1
        self.w_i += value * x
        self.iterations += 1
        # Map the {-1, +1} assignment to a {0, 1} treatment indicator.
        return int((assignment + 1) / 2)

    def assign_all(self, X: np.ndarray) -> np.ndarray:
        """Assign all points

        This assigns units to treatment in the offline setting in which all covariate
        profiles are available prior to assignment. The algorithm assigns as if units
        were still only observed in a stream.

        Args:
            X: array of size n × d of covariate profiles
        """
        return np.array([self.assign_next(X[i, :]) for i in range(X.shape[0])])

    @property
    def definition(self):
        return {
            "N": self.N,
            "D": self.D,
            "delta": self.delta,
            "q": self.q,
            "intercept": self.intercept,
            "phi": self.phi,
        }

    @property
    def state(self):
        return {"w_i": self.w_i, "iterations": self.iterations}

    def update_state(self, w_i, iterations):
        self.w_i = np.array(w_i)
        self.iterations = iterations

    def reset(self):
        self.w_i = np.zeros((self.D,))
        self.alpha = np.log(2 * self.N / self.delta) * min(1 / self.q, 9.32)
        self.iterations = 0
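
The class docstring above describes the streaming assignment rule. As a quick illustration (not part of the covered module), the sketch below feeds simulated covariates through `assign_next` one unit at a time; the import path `bwd.bwd_random` is an assumption based on the file path in the report header, as are the simulated data and parameter choices.

# Illustrative sketch only: the import path is assumed from src/bwd/bwd_random.py.
import numpy as np

from bwd.bwd_random import BWDRandom

rng = np.random.default_rng(0)
N, D = 100, 3

# Online setting: units arrive one at a time and are assigned on the spot.
design = BWDRandom(N=N, D=D, q=0.5, phi=1.0)
assignments = [design.assign_next(rng.normal(size=D)) for _ in range(N)]
print(sum(assignments), "of", N, "units assigned to treatment")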
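
For the offline path described in the `assign_all` docstring, a parallel sketch (same assumed import path, simulated data) passes the full covariate matrix at once; the method simply replays the rows through the streaming rule above.

# Offline sketch: all covariate profiles are available up front.
import numpy as np

from bwd.bwd_random import BWDRandom  # assumed import path

rng = np.random.default_rng(1)
N, D = 100, 3
X = rng.normal(size=(N, D))

design = BWDRandom(N=N, D=D)
treatment = design.assign_all(X)  # array of N entries in {0, 1}
print(treatment.mean())  # should land near the target marginal probability q = 0.5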