Coverage for src/bwd/bwd.py: 98%
61 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-12 11:56 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-12 11:56 +0000
1"""Balancing Walk Design with Restarts implementation."""
3import numpy as np
5from .exceptions import SampleSizeExpendedError
# Names of the parameters that define a balancer's configuration; these are the
# same keys that ``BWD.definition`` returns.
SERIALIZED_ATTRIBUTES = [
    "N",
    "D",
    "delta",
    "q",
    "intercept",
    "phi",
]
class BWD:
    """**The Balancing Walk Design with Restarts**

    This is the primary suggested algorithm from [Arbour et al (2022)](https://arxiv.org/abs/2203.02025).
    At each step, it adjusts randomization probabilities to ensure that imbalance tends towards zero. In
    particular, if the current imbalance is $w$ and the current covariate profile is $x$, then the
    probability of treatment conditional on history will be:

    $$p_i = q \\left(1 - \\phi \\frac{x \\cdot w}{\\alpha}\\right)$$

    $q$ is the desired marginal probability, $\\phi$ is the parameter which controls robustness and
    $\\alpha$ is the normalizing constant which ensures the probability is well-formed.

    !!! important "If $|x \\cdot w| > \\alpha$"
        A restart is performed by resetting the algorithm:

        - $w$ is reset to the zero vector
        - $\\alpha$ is reset to a constant based on the number of units remaining in the sample
    """

    q: float
    intercept: bool
    delta: float
    N: int
    D: int  # internal dimension: covariate dimension plus 1 if an intercept is used
    value_plus: float
    value_minus: float
    phi: float
    alpha: float
    w_i: np.ndarray
    iterations: int

    def __init__(
        self,
        N: int,
        D: int,
        delta: float = 0.05,
        q: float = 0.5,
        intercept: bool = True,
        phi: float = 1,
    ) -> None:
        """
        Parameters
        ----------
        N : int
            Total number of points
        D : int
            Dimension of the data (not counting the intercept column)
        delta : float, optional
            Probability of failure, by default 0.05
        q : float, optional
            Target marginal probability of treatment, by default 0.5
        intercept : bool, optional
            Whether an intercept term should be added to covariate profiles, by default True
        phi : float, optional
            Robustness parameter. A value of 1 focuses entirely on balance, while a value
            approaching zero does pure randomization, by default 1
        """
        self.q = q
        self.intercept = intercept
        self.delta = delta
        self.N = N
        # The imbalance vector carries one extra coordinate for the intercept.
        self.D = D + int(self.intercept)
        # Increments to w for treatment / control. With P(treat) = q, the expected
        # increment q * value_plus + (1 - q) * value_minus is zero.
        self.value_plus = 2 * (1 - self.q)
        self.value_minus = -2 * self.q
        self.phi = phi
        self.reset()

    def set_alpha(self, N: int) -> None:
        """Set normalizing constant for remaining N units

        Parameters
        ----------
        N : int
            Number of units remaining in the sample

        Raises
        ------
        SampleSizeExpendedError
            If N is negative
        """
        if N < 0:
            raise SampleSizeExpendedError()
        # NOTE(review): N == 0 gives np.log(0) == -inf (with a RuntimeWarning), so alpha
        # becomes -inf; this only happens when assigning beyond the declared sample size.
        self.alpha = np.log(2 * N / self.delta) * min(1 / self.q, 9.32)

    def assign_next(self, x: np.ndarray) -> int:
        """Assign treatment to the next point

        Parameters
        ----------
        x : np.ndarray
            Covariate profile of unit to assign treatment (any array-like of length D)

        Returns
        -------
        int
            Treatment assignment (0 or 1)
        """
        # Coerce to a float array so list inputs work and the in-place float update
        # of w_i below is always well-typed.
        x = np.asarray(x, dtype=float)
        if self.intercept:
            x = np.concatenate(([1.0], x))
        dot = x @ self.w_i
        if abs(dot) > self.alpha:
            # Restart: discard accumulated imbalance and rescale alpha to the units
            # still remaining (including the current one).
            self.w_i = np.zeros((self.D,))
            self.set_alpha(self.N - self.iterations)
            dot = x @ self.w_i

        p_i = self.q * (1 - self.phi * dot / self.alpha)

        if np.random.rand() < p_i:
            self.w_i += self.value_plus * x
            assignment = 1
        else:
            self.w_i += self.value_minus * x
            assignment = 0
        self.iterations += 1
        return assignment

    def assign_all(self, X: np.ndarray) -> np.ndarray:
        """Assign all points

        This assigns units to treatment in the offline setting in which all covariate
        profiles are available prior to assignment. The algorithm assigns as if units
        were still only observed in a stream.

        Parameters
        ----------
        X : np.ndarray
            Array of size n × d of covariate profiles

        Returns
        -------
        np.ndarray
            Array of treatment assignments
        """
        return np.array([self.assign_next(row) for row in X])

    @property
    def definition(self) -> dict:
        """Get the definition parameters of the balancer

        ``D`` is reported as passed to the constructor, i.e. without the intercept
        column, so ``BWD(**balancer.definition)`` rebuilds an equivalent balancer.

        Returns
        -------
        dict
            Dictionary containing N, D, delta, q, intercept, and phi
        """
        return {
            "N": self.N,
            # self.D includes the intercept column; undo that so the value matches the
            # constructor argument and is not counted twice on reconstruction.
            "D": self.D - int(self.intercept),
            "delta": self.delta,
            "q": self.q,
            "intercept": self.intercept,
            "phi": self.phi,
        }

    @property
    def state(self) -> dict:
        """Get the current state of the balancer

        The imbalance vector is returned as a copy, so the snapshot is not altered by
        subsequent calls to ``assign_next`` (which update ``w_i`` in place).

        Returns
        -------
        dict
            Dictionary containing w_i and iterations
        """
        return {"w_i": self.w_i.copy(), "iterations": self.iterations}

    def update_state(self, w_i, iterations) -> None:
        """Update the state of the balancer

        Parameters
        ----------
        w_i : array-like
            Current imbalance vector
        iterations : int
            Current iteration count
        """
        # Force a float dtype: an integer vector (e.g. [0, 0] from JSON) would make the
        # in-place float update in assign_next fail with a casting error.
        self.w_i = np.array(w_i, dtype=float)
        self.iterations = iterations

    def reset(self) -> None:
        """Reset the balancer to initial state

        Resets the imbalance vector to zeros, reinitializes alpha, and sets iterations to 0.
        """
        self.w_i = np.zeros((self.D,))
        self.set_alpha(self.N)
        self.iterations = 0