Coverage for src/bwd/multi_bwd.py: 95%
87 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-12 11:56 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-12 11:56 +0000
1"""Multi-treatment Balancing Walk Design implementation."""
3from collections.abc import Iterable
4from typing import Any
6import numpy as np
8from .bwd import BWD
11def _left(i):
12 return 2 * i + 1
15def _right(i):
16 return 2 * (i + 1)
19def _parent(i):
20 return int(np.floor((i - 1) / 2))
23class MultiBWD(object):
24 """**The Multi-treatment Balancing Walk Design with Restarts**
26 This method implements an extension to the Balancing Walk Design to balance
27 across multiple treatments. It accomplishes this by constructing a binary tree.
28 At each node in the binary tree, it balanced between the treatment groups on the
29 left and the right. Thus it ensures balance between any pair of treatment groups.
30 """
32 def __init__(
33 self,
34 N: int,
35 D: int,
36 delta: float = 0.05,
37 q: float | Iterable[float] = 0.5,
38 intercept: bool = True,
39 phi: float = 1.0,
40 ):
41 """
42 Parameters
43 ----------
44 N : int
45 Total number of points
46 D : int
47 Dimension of the data
48 delta : float, optional
49 Probability of failure, by default 0.05
50 q : float | Iterable[float], optional
51 Target marginal probability of treatment. Can be a single float for binary
52 treatment or an iterable of probabilities for multiple treatments, by default 0.5
53 intercept : bool, optional
54 Whether an intercept term be added to covariate profiles, by default True
55 phi : float, optional
56 Robustness parameter. A value of 1 focuses entirely on balance, while a value
57 approaching zero does pure randomization, by default 1.0
58 """
59 self.N = N
60 self.D = D
61 self.delta = delta
62 self.intercept = intercept
63 self.phi = phi
65 if isinstance(q, float):
66 q = q if q < 0.5 else 1 - q
67 self.qs = [1 - q, q]
68 self.classes = [0, 1]
69 elif isinstance(q, Iterable):
70 self.qs = [pr / sum(q) for pr in q]
71 self.classes = [i for i, q in enumerate(self.qs)]
72 num_groups = len(self.qs)
73 self.K = num_groups - 1
74 self.intercept = intercept
76 num_levels = int(np.ceil(np.log2(num_groups)))
77 num_leaves = int(np.power(2, num_levels))
78 extra_leaves = num_leaves - num_groups
79 num_nodes = int(np.power(2, num_levels + 1) - 1)
81 # Use dictionaries for type-stable storage
82 # nodes: dict mapping index -> BWD object (for internal nodes) or int (for leaf nodes)
83 # weights: dict mapping index -> float
84 self.nodes: dict[int, BWD | int] = {}
85 self.weights: dict[int, float] = {}
87 trt_by_leaf = []
88 num_leaves_by_trt = []
89 for trt in range(num_groups):
90 if len(trt_by_leaf) % 2 == 0 and extra_leaves > 0:
91 num_trt = 2 * (int(np.floor((extra_leaves - 1) / 2)) + 1)
92 extra_leaves -= num_trt - 1
93 else:
94 num_trt = 1
95 trt_by_leaf += [trt] * num_trt
96 num_leaves_by_trt.append(num_trt)
98 # Initialize leaf nodes with treatment assignments
99 for leaf, trt in enumerate(trt_by_leaf):
100 node = num_nodes - num_leaves + leaf
101 self.nodes[node] = trt
102 self.weights[node] = 1 / self.qs[trt] / num_leaves_by_trt[trt]
104 # Build internal nodes from leaves up
105 for cur_node in range(num_nodes)[::-1]:
106 if cur_node == 0:
107 break
108 parent = _parent(cur_node)
109 left = _left(parent)
110 right = _right(parent)
112 # Skip if children haven't been initialized yet
113 if left not in self.nodes or right not in self.nodes:
114 continue
116 # If both children have the same treatment, propagate it up
117 if self.nodes[left] == self.nodes[right]:
118 self.nodes[parent] = self.nodes[left]
119 self.weights[parent] = self.weights[left] + self.weights[right]
120 # Otherwise, create a BWD balancer at this node
121 else:
122 left_weight = self.weights[left]
123 right_weight = self.weights[right]
124 pr_right = right_weight / (left_weight + right_weight)
125 self.nodes[parent] = BWD(
126 N=N, D=D, intercept=intercept, delta=delta, q=pr_right, phi=phi
127 )
128 self.weights[parent] = left_weight + right_weight
130 def assign_next(self, x: np.ndarray) -> int:
131 """Assign treatment to the next point
133 Parameters
134 ----------
135 x : np.ndarray
136 Covariate profile of unit to assign treatment
138 Returns
139 -------
140 int
141 Treatment assignment (treatment group index)
142 """
143 cur_idx = 0
144 while isinstance(self.nodes[cur_idx], BWD):
145 assign = self.nodes[cur_idx].assign_next(x)
146 cur_idx = _right(cur_idx) if assign > 0 else _left(cur_idx)
147 # At this point, we've reached a leaf node which contains an int
148 result = self.nodes[cur_idx]
149 assert isinstance(result, int), "Leaf node must be an int"
150 return result
152 def assign_all(self, X: np.ndarray) -> np.ndarray:
153 """Assign all points
155 This assigns units to treatment in the offline setting in which all covariate
156 profiles are available prior to assignment. The algorithm assigns as if units
157 were still only observed in a stream.
159 Parameters
160 ----------
161 X : np.ndarray
162 Array of size n × d of covariate profiles
164 Returns
165 -------
166 np.ndarray
167 Array of treatment assignments
168 """
169 return np.array([self.assign_next(X[i, :]) for i in range(X.shape[0])])
171 @property
172 def definition(self):
173 """Get the definition parameters of the balancer
175 Returns
176 -------
177 dict
178 Dictionary containing N, D, delta, q, intercept, and phi
179 """
180 return {
181 "N": self.N,
182 "D": self.D,
183 "delta": self.delta,
184 "q": self.qs,
185 "intercept": self.intercept,
186 "phi": self.phi,
187 }
189 @property
190 def state(self):
191 """Get the current state of all BWD nodes in the tree
193 Returns
194 -------
195 dict
196 Dictionary mapping node indices to their states
197 """
198 return {
199 idx: node.state for idx, node in self.nodes.items() if isinstance(node, BWD)
200 }
202 def update_state(self, **node_state_dict: Any) -> None:
203 """Update the state of BWD nodes in the tree
205 Parameters
206 ----------
207 **node_state_dict : dict
208 Dictionary mapping node indices (as strings) to state dictionaries
209 """
210 for node, state in node_state_dict.items():
211 node_obj = self.nodes[int(node)]
212 if isinstance(node_obj, BWD):
213 node_obj.update_state(**state)
215 def reset(self):
216 """Reset all BWD nodes in the tree to initial state"""
217 for node in self.nodes.values():
218 if isinstance(node, BWD):
219 node.reset()