Coverage for src/bwd/bwd.py: 98%
50 statements
« prev ^ index » next coverage.py v6.5.0, created at 2024-08-19 16:45 +0000
1import numpy as np
2from .exceptions import SampleSizeExpendedError
# Names of the ``BWD`` constructor arguments (N, D, delta, q, intercept, phi).
# NOTE(review): presumably consumed by serialization code elsewhere in the
# package — confirm against callers.
SERIALIZED_ATTRIBUTES = [
    "N",
    "D",
    "delta",
    "q",
    "intercept",
    "phi",
]
class BWD:
    """**The Balancing Walk Design with Restarts**

    This is the primary suggested algorithm from [Arbour et al (2022)](https://arxiv.org/abs/2203.02025).
    At each step, it adjusts randomization probabilities to ensure that imbalance tends towards zero. In
    particular, if current imbalance is $w$ and the current covariate profile is $x$, then the probability of
    treatment conditional on history will be:

    $$p_i = q \\left(1 - \\phi \\frac{x \\cdot w}{\\alpha}\\right)$$

    $q$ is the desired marginal probability, $\\phi$ is the parameter which controls robustness and
    $\\alpha$ is the normalizing constant which ensures the probability is well-formed.

    !!! important "If $|x \\cdot w| > \\alpha$"
        A restart is performed by resetting the algorithm:

        - $w$ is reset to the zero vector
        - $\\alpha$ is reset to a constant based on the number of units remaining in the sample
    """

    def __init__(
        self,
        N: int,
        D: int,
        delta: float = 0.05,
        q: float = 0.5,
        intercept: bool = True,
        phi: float = 1,
    ) -> None:
        """
        Args:
            N: total number of points
            D: dimension of the data (not counting the intercept term)
            delta: probability of failure
            q: Target marginal probability of treatment
            intercept: Whether an intercept term be added to covariate profiles
            phi: Robustness parameter. A value of 1 focuses entirely on balance, while a value
                approaching zero does pure randomization.
        """
        self.q = q
        self.intercept = intercept
        self.delta = delta
        self.N = N
        # Internal dimension: the intercept (if any) is an extra leading coordinate of x and w.
        self.D = D + int(self.intercept)
        # Imbalance increments for treatment / control. They are weighted so that
        # q * value_plus + (1 - q) * value_minus == 0, i.e. the walk has zero drift
        # when units are treated with probability q.
        self.value_plus = 2 * (1 - self.q)
        self.value_minus = -2 * self.q
        self.phi = phi
        self.reset()

    def set_alpha(self, N: int) -> None:
        """Set normalizing constant for remaining N units

        Args:
            N: Number of units remaining in the sample

        Raises:
            SampleSizeExpendedError: if ``N`` is negative, i.e. more units have been
                assigned than the ``N`` given at construction.
        """
        if N < 0:
            raise SampleSizeExpendedError()
        # NOTE(review): N == 0 gives log(0) = -inf (numpy emits a RuntimeWarning), e.g.
        # when a restart happens while assigning unit N + 1.
        self.alpha = np.log(2 * N / self.delta) * min(1 / self.q, 9.32)

    def assign_next(self, x: np.ndarray) -> int:
        """Assign treatment to the next point

        Args:
            x: covariate profile of unit to assign treatment. Any array-like of length
                ``D`` is accepted (a scalar when ``D == 1``).

        Returns:
            1 if the unit is assigned to treatment, 0 if assigned to control.
        """
        # Coerce to a 1-D float vector so list / scalar inputs work with and without an
        # intercept (``float * list`` would otherwise raise in the update below).
        x = np.atleast_1d(np.asarray(x, dtype=float))
        if self.intercept:
            x = np.concatenate(([1.0], x))
        dot = x @ self.w_i
        if abs(dot) > self.alpha:
            # Restart: forget accumulated imbalance and re-normalize for the remaining units.
            self.w_i = np.zeros((self.D,))
            self.set_alpha(self.N - self.iterations)
            dot = x @ self.w_i

        p_i = self.q * (1 - self.phi * dot / self.alpha)

        treated = np.random.rand() < p_i
        self.w_i += (self.value_plus if treated else self.value_minus) * x
        self.iterations += 1
        return int(treated)

    def assign_all(self, X: np.ndarray) -> np.ndarray:
        """Assign all points

        This assigns units to treatment in the offline setting in which all covariate
        profiles are available prior to assignment. The algorithm assigns as if units
        were still only observed in a stream.

        Args:
            X: array-like of size n × d of covariate profiles (one row per unit)

        Returns:
            Integer array of length n with 1 for treatment and 0 for control.
        """
        return np.array([self.assign_next(row) for row in np.asarray(X)], dtype=int)

    @property
    def definition(self) -> dict:
        """Constructor arguments that recreate this design via ``BWD(**definition)``."""
        return {
            "N": self.N,
            # ``self.D`` already includes the intercept column; report the user-facing
            # dimension so passing it back to ``__init__`` does not add it twice.
            "D": self.D - int(self.intercept),
            "delta": self.delta,
            "q": self.q,
            "intercept": self.intercept,
            "phi": self.phi,
        }

    @property
    def state(self) -> dict:
        """Snapshot of the mutable walk state.

        ``w_i`` is a copy, so later assignments do not alter a snapshot already taken.
        ``alpha`` is included because it depends on when the last restart happened.
        """
        return {
            "w_i": self.w_i.copy(),
            "iterations": self.iterations,
            "alpha": self.alpha,
        }

    def update_state(self, w_i, iterations, alpha=None) -> None:
        """Restore walk state, e.g. from a previous ``state`` snapshot.

        Args:
            w_i: current imbalance vector (array-like of length ``self.D``)
            iterations: number of units already assigned
            alpha: normalizing constant in effect. If omitted (snapshots that did not
                record it), the current value is kept.
        """
        # Force a float copy: an integer array would make the in-place ``+=`` with
        # fractional increments in ``assign_next`` fail.
        self.w_i = np.array(w_i, dtype=float)
        self.iterations = iterations
        if alpha is not None:
            self.alpha = alpha

    def reset(self) -> None:
        """Reset the walk: zero imbalance, alpha for all N units, no units assigned."""
        self.w_i = np.zeros((self.D,))
        self.set_alpha(self.N)
        self.iterations = 0