Coverage for src/bwd/bwd_random.py: 98%
50 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-12 11:56 +0000
1"""Balancing Walk Design with reversion to Bernoulli randomization."""
3import numpy as np
5from .exceptions import SampleSizeExpendedError
# Names of the balancer's defining parameters. They match both the keyword
# arguments of ``BWDRandom.__init__`` and the keys of ``BWDRandom.definition``
# (presumably consumed by serialization code elsewhere in the package — verify).
SERIALIZED_ATTRIBUTES = [
    "N",
    "D",
    "delta",
    "q",
    "intercept",
    "phi",
]
class BWDRandom(object):
    """**The Balancing Walk Design with Reversion to Bernoulli Randomization**

    This is an algorithm from [Arbour et al (2022)](https://arxiv.org/abs/2203.02025).
    At each step, it adjusts randomization probabilities to ensure that imbalance tends towards zero. In
    particular, if the current imbalance is $w$ and the current covariate profile is $x$, then the
    probability of treatment conditional on history will be:

    $$p_i = q \\left(1 - \\phi \\frac{x \\cdot w}{\\alpha}\\right)$$

    $q$ is the desired marginal probability, $\\phi$ is the parameter which controls robustness and
    $\\alpha$ is the normalizing constant which ensures the probability is well-formed.

    !!! important "If $|x \\cdot w| > \\alpha$"
        All future units will be assigned by complete randomization.
    """

    def __init__(
        self,
        N: int,
        D: int,
        delta: float = 0.05,
        q: float = 0.5,
        intercept: bool = True,
        phi: float = 1,
    ) -> None:
        """
        Parameters
        ----------
        N : int
            Total number of points
        D : int
            Dimension of the data (not counting the intercept, which is added
            internally when ``intercept`` is True)
        delta : float, optional
            Probability of failure, by default 0.05
        q : float, optional
            Target marginal probability of treatment, by default 0.5
        intercept : bool, optional
            Whether an intercept term be added to covariate profiles, by default True
        phi : float, optional
            Robustness parameter. A value of 1 focuses entirely on balance, while a value
            approaching zero does pure randomization, by default 1
        """
        self.q = q
        self.intercept = intercept
        self.delta = delta
        self.N = N
        # Internal dimension: includes the prepended intercept column, if any.
        self.D = D + int(self.intercept)

        # Increments applied to the imbalance vector on treatment / control.
        # With P(treatment) = q the expected increment is zero:
        # q * 2(1 - q) + (1 - q) * (-2q) = 0.
        self.value_plus = 2 * (1 - self.q)
        self.value_minus = -2 * self.q
        self.phi = phi
        self.reset()

    def set_alpha(self, N: int) -> None:
        """Revert to Bernoulli randomization for the remaining N units

        Despite the name, this does not recompute a normalizing constant. It sets
        ``alpha`` to the sentinel value -1. Because ``abs(x @ w) > -1`` holds for
        every finite dot product, every later call to ``assign_next`` takes its
        reversion branch: the imbalance is zeroed and the unit is treated with
        probability exactly ``q``.

        Parameters
        ----------
        N : int
            Number of units remaining in the sample

        Raises
        ------
        SampleSizeExpendedError
            If N is negative
        """
        if N < 0:
            raise SampleSizeExpendedError()
        self.alpha = -1

    def assign_next(self, x: np.ndarray) -> int:
        """Assign treatment to the next point

        Parameters
        ----------
        x : np.ndarray
            Covariate profile of unit to assign treatment (length D, without the
            intercept term)

        Returns
        -------
        int
            Treatment assignment (0 or 1)

        Raises
        ------
        SampleSizeExpendedError
            If reversion is triggered after more than N units have been assigned
        """
        x = np.asarray(x, dtype=float)
        if self.intercept:
            x = np.concatenate(([1.0], x))
        dot = x @ self.w_i
        if abs(dot) > self.alpha:
            # Imbalance bound exceeded (or already reverted, alpha == -1):
            # drop the accumulated imbalance and switch to Bernoulli(q) for good.
            self.w_i = np.zeros((self.D,))
            self.set_alpha(self.N - self.iterations)
            dot = 0.0

        p_i = self.q * (1 - self.phi * dot / self.alpha)

        if np.random.rand() < p_i:
            self.w_i += self.value_plus * x
            treated = 1
        else:
            self.w_i += self.value_minus * x
            treated = 0
        self.iterations += 1
        return treated

    def assign_all(self, X: np.ndarray) -> np.ndarray:
        """Assign all points

        This assigns units to treatment in the offline setting in which all covariate
        profiles are available prior to assignment. The algorithm assigns as if units
        were still only observed in a stream.

        Parameters
        ----------
        X : np.ndarray
            Array of size n × d of covariate profiles

        Returns
        -------
        np.ndarray
            Integer array of n treatment assignments (0 or 1)
        """
        rows = np.asarray(X)
        # dtype=int keeps the result integer-typed even when there are no rows.
        return np.array([self.assign_next(row) for row in rows], dtype=int)

    @property
    def definition(self):
        """Get the definition parameters of the balancer

        The values are the constructor arguments, so ``BWDRandom(**b.definition)``
        rebuilds an equivalently configured balancer.

        Returns
        -------
        dict
            Dictionary containing N, D, delta, q, intercept, and phi
        """
        return {
            "N": self.N,
            # __init__ adds the intercept column itself; report the caller-facing
            # dimension so a rebuild does not count the intercept twice.
            "D": self.D - int(self.intercept),
            "delta": self.delta,
            "q": self.q,
            "intercept": self.intercept,
            "phi": self.phi,
        }

    @property
    def state(self):
        """Get the current state of the balancer

        Returns
        -------
        dict
            Dictionary containing w_i (a copy, so the snapshot is not mutated by
            later assignments), iterations, and alpha (-1 once the balancer has
            reverted to Bernoulli randomization)
        """
        return {
            "w_i": self.w_i.copy(),
            "iterations": self.iterations,
            "alpha": self.alpha,
        }

    def update_state(self, w_i, iterations, alpha=None):
        """Update the state of the balancer

        Parameters
        ----------
        w_i : array-like
            Current imbalance vector (converted to a float array so later
            in-place float updates cannot fail on integer input)
        iterations : int
            Current iteration count
        alpha : float, optional
            Normalizing constant to restore (e.g. -1 for a reverted balancer).
            When None, the current value is kept.
        """
        self.w_i = np.array(w_i, dtype=float)
        self.iterations = iterations
        if alpha is not None:
            self.alpha = alpha

    def reset(self):
        """Reset the balancer to initial state

        Resets the imbalance vector to zeros, initializes alpha, and sets iterations to 0.
        """
        self.w_i = np.zeros((self.D,))
        # NOTE(review): log(2N/delta) * min(1/q, 9.32) appears to be the normalizing
        # constant from Arbour et al. (2022); confirm the 9.32 cap against the paper.
        self.alpha = np.log(2 * self.N / self.delta) * min(1 / self.q, 9.32)
        self.iterations = 0