Coverage for src/bwd/bwd_random.py: 98%
50 statements
« prev ^ index » next coverage.py v6.5.0, created at 2024-08-19 16:45 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2024-08-19 16:45 +0000
1import numpy as np
2from .exceptions import SampleSizeExpendedError
# Names of the constructor arguments of ``BWDRandom``; the same names are the
# keys of the dictionary returned by ``BWDRandom.definition``.
SERIALIZED_ATTRIBUTES = [
    "N",
    "D",
    "delta",
    "q",
    "intercept",
    "phi",
]
class BWDRandom(object):
    """**The Balancing Walk Design with Reversion to Bernoulli Randomization**

    This is an algorithm from [Arbour et al (2022)](https://arxiv.org/abs/2203.02025).
    At each step, it adjusts randomization probabilities to ensure that imbalance tends towards zero. In
    particular, if current imbalance is $w$ and the current covariate profile is $x$, then the probability of
    treatment conditional on history will be:

    $$p_i = q \\left(1 - \\phi \\frac{x \\cdot w}{\\alpha}\\right)$$

    $q$ is the desired marginal probability, $\\phi$ is the parameter which controls robustness and
    $\\alpha$ is the normalizing constant which ensures the probability is well-formed.

    !!! important "If $|x \\cdot w| > \\alpha$"
        All future units will be assigned by complete randomization.
    """

    def __init__(
        self,
        N: int,
        D: int,
        delta: float = 0.05,
        q: float = 0.5,
        intercept: bool = True,
        phi: float = 1,
    ) -> None:
        """Initialize the object

        Arguments:
            N: total number of points
            D: dimension of the data
            delta: probability of failure
            q: Target marginal probability of treatment
            intercept: Whether an intercept term be added to covariate profiles
            phi: Robustness parameter. A value of 1 focuses entirely on balance, while a value
                approaching zero does pure randomization.
        """
        self.q = q
        self.intercept = intercept
        self.delta = delta
        self.N = N
        # Internal dimension: one extra slot when an intercept is prepended to
        # every covariate profile in `assign_next`.
        self.D = D + int(self.intercept)

        # Increments applied to the imbalance vector for a treated / control unit.
        self.value_plus = 2 * (1 - self.q)
        self.value_minus = -2 * self.q
        self.phi = phi
        self.reset()

    def set_alpha(self, N: int) -> None:
        """Set normalizing constant for remaining N units

        This is only called from the reversion branch of `assign_next`. It
        stores the sentinel ``alpha = -1``: since ``abs(dot) > -1`` always
        holds, every later call to `assign_next` takes the reversion branch
        again, zeroes the imbalance and therefore assigns treatment with
        probability exactly ``q``.

        Args:
            N: Number of units remaining in the sample

        Raises:
            SampleSizeExpendedError: if ``N`` is negative.
        """
        if N < 0:
            raise SampleSizeExpendedError()
        self.alpha = -1

    def assign_next(self, x: np.ndarray) -> int:
        """Assign treatment to the next point

        Args:
            x: covariate profile of unit to assign treatment

        Returns:
            ``1`` if the unit is assigned to treatment, ``0`` otherwise.

        Raises:
            SampleSizeExpendedError: propagated from `set_alpha` when the
                reversion branch runs with a negative number of remaining units.
        """
        # Accept any array-like profile (e.g. a plain list) in both intercept modes.
        x = np.asarray(x)
        if self.intercept:
            x = np.concatenate(([1], x))

        dot = x @ self.w_i
        if abs(dot) > self.alpha:
            # Imbalance left the admissible range: drop the accumulated
            # imbalance and revert to Bernoulli(q) assignment.
            self.w_i = np.zeros((self.D,))
            self.set_alpha(self.N - self.iterations)
            dot = 0.0

        p_i = self.q * (1 - self.phi * dot / self.alpha)

        treated = np.random.rand() < p_i
        step = self.value_plus if treated else self.value_minus
        self.w_i += step * x
        self.iterations += 1
        return int(treated)

    def assign_all(self, X: np.ndarray) -> np.ndarray:
        """Assign all points

        This assigns units to treatment in the offline setting in which all covariate
        profiles are available prior to assignment. The algorithm assigns as if units
        were still only observed in a stream.

        Args:
            X: array of size n × d of covariate profiles

        Returns:
            Array of length n holding the 0/1 assignment of each row, in order.
        """
        return np.array([self.assign_next(X[i, :]) for i in range(X.shape[0])])

    @property
    def definition(self):
        """Constructor arguments that reproduce this design.

        ``D`` is reported without the intercept column so that
        ``BWDRandom(**obj.definition)`` yields the same internal dimension as
        ``obj`` instead of adding the intercept a second time.
        """
        return {
            "N": self.N,
            "D": self.D - int(self.intercept),
            "delta": self.delta,
            "q": self.q,
            "intercept": self.intercept,
            "phi": self.phi,
        }

    @property
    def state(self):
        """Snapshot of the mutable state (imbalance vector and step counter).

        The imbalance vector is copied: `assign_next` updates it in place, so
        handing out the live array would let later assignments silently alter
        a previously taken snapshot.
        """
        return {"w_i": self.w_i.copy(), "iterations": self.iterations}

    def update_state(self, w_i, iterations):
        """Restore state previously obtained from `state`.

        Args:
            w_i: imbalance vector (any array-like of numbers)
            iterations: number of units assigned so far
        """
        # Force a float array: an integer-valued input (e.g. ``[0, 0, 0]``)
        # would otherwise produce an int array that the in-place float update
        # in `assign_next` cannot write into.
        self.w_i = np.array(w_i, dtype=float)
        self.iterations = iterations

    def reset(self):
        """Return to the initial state: zero imbalance, initial alpha, no units assigned."""
        self.w_i = np.zeros((self.D,))
        self.alpha = np.log(2 * self.N / self.delta) * min(1 / self.q, 9.32)
        self.iterations = 0