Coverage for src/bwd/multi_bwd.py: 98%
81 statements
« prev ^ index » next coverage.py v6.5.0, created at 2024-08-19 16:45 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2024-08-19 16:45 +0000
1from typing import Union
2from collections.abc import Iterable
3from .bwd import BWD
5import numpy as np
8def _left(i):
9 return 2 * i + 1
12def _right(i):
13 return 2 * (i + 1)
16def _parent(i):
17 return int(np.floor((i - 1) / 2))
class MultiBWD(object):
    """**The Multi-treatment Balancing Walk Design with Restarts**

    This method implements an extension to the Balancing Walk Design to balance
    across multiple treatments. It accomplishes this by constructing a binary tree.
    At each node in the binary tree, it balances between the treatment groups on the
    left and the right. Thus it ensures balance between any pair of treatment groups.

    The tree is stored in two flat lists indexed by node position (root at 0,
    children located with ``_left`` / ``_right``): ``nodes`` holds a treatment
    index at each leaf and a ``BWD`` instance at each internal node, and
    ``weights`` holds the accumulated weight of the subtree under each node.
    """

    def __init__(
        self,
        N: int,
        D: int,
        delta: float = 0.05,
        q: Union[float, Iterable] = 0.5,
        intercept: bool = True,
        phi: float = 1.0,
    ):
        """
        Args:
            N: total number of points
            D: dimension of the data
            delta: probability of failure
            q: Target marginal probability of treatment. A float configures a
                two-group design; an iterable supplies one weight per treatment
                group and is rescaled so the weights sum to one.
            intercept: Whether an intercept term be added to covariate profiles
            phi: Robustness parameter. A value of 1 focuses entirely on balance, while a value
                approaching zero does pure randomization.

        Raises:
            TypeError: if ``q`` is neither a float nor an iterable.
            ValueError: if ``q`` is an empty iterable.
        """
        self.N = N
        self.D = D
        self.delta = delta
        self.intercept = intercept
        self.phi = phi

        if isinstance(q, (float, np.floating)):
            # Fold q so that it never exceeds 0.5; class 1 gets the folded value.
            q = q if q < 0.5 else 1 - q
            self.qs = [1 - q, q]
            self.classes = [0, 1]
        elif isinstance(q, Iterable):
            # Materialize and sum once: re-summing per element is quadratic and
            # gives wrong results for one-shot iterators (the sum would consume
            # the remainder of the stream).
            raw = list(q)
            if not raw:
                raise ValueError("q must contain at least one treatment weight")
            total = sum(raw)
            self.qs = [pr / total for pr in raw]
            self.classes = list(range(len(self.qs)))
        else:
            raise TypeError(
                "q must be a float or an iterable of weights, got "
                f"{type(q).__name__}"
            )
        num_groups = len(self.qs)
        self.K = num_groups - 1

        # Size a complete binary tree with at least one leaf per group.
        num_levels = int(np.ceil(np.log2(num_groups)))
        num_leaves = int(np.power(2, num_levels))
        extra_leaves = num_leaves - num_groups
        num_nodes = int(np.power(2, num_levels + 1) - 1)
        self.nodes = [None] * num_nodes
        self.weights = [None] * num_nodes

        # Hand the surplus leaves to the earliest groups, always in even-sized
        # runs starting on an even leaf position, so that every leaf is used.
        trt_by_leaf = []
        num_leaves_by_trt = []
        for trt in range(num_groups):
            if len(trt_by_leaf) % 2 == 0 and extra_leaves > 0:
                num_trt = 2 * (int(np.floor((extra_leaves - 1) / 2)) + 1)
                extra_leaves -= num_trt - 1
            else:
                num_trt = 1
            trt_by_leaf += [trt] * num_trt
            num_leaves_by_trt.append(num_trt)

        # Leaves occupy the last `num_leaves` slots of the flat tree. A group's
        # weight is split evenly across the leaves it owns.
        first_leaf = num_nodes - num_leaves
        for leaf, trt in enumerate(trt_by_leaf):
            node = first_leaf + leaf
            self.nodes[node] = trt
            self.weights[node] = 1 / self.qs[trt] / num_leaves_by_trt[trt]

        # Build internal nodes bottom-up (highest index first) so both children
        # are finalized before their parent is visited. Each parent is built
        # exactly once.
        for parent in reversed(range(first_leaf)):
            left = _left(parent)
            right = _right(parent)
            if self.nodes[left] == self.nodes[right]:
                self.nodes[parent] = self.nodes[left]
                self.weights[parent] = self.weights[left] + self.weights[right]
            # NOTE(review): this condition also holds when both children carry
            # the same treatment, so the collapse above is overwritten by a BWD
            # node. Kept as-is because it determines which indices appear in
            # `state`; confirm whether `elif` was intended.
            if self.nodes[left] is not None and self.nodes[right] is not None:
                left_weight = self.weights[left]
                right_weight = self.weights[right]
                pr_right = right_weight / (left_weight + right_weight)
                self.nodes[parent] = BWD(
                    N=N, D=D, intercept=intercept, delta=delta, q=pr_right, phi=phi
                )
                self.weights[parent] = left_weight + right_weight

    def assign_next(self, x: np.ndarray) -> int:
        """Assign treatment to the next point

        Walks down from the root: at each ``BWD`` node a positive assignment
        moves to the right child, anything else to the left, until a leaf
        (treatment index) is reached.

        Args:
            x: covariate profile of unit to assign treatment

        Returns:
            The treatment index stored at the leaf that was reached.
        """
        cur_idx = 0
        while isinstance(self.nodes[cur_idx], BWD):
            assign = self.nodes[cur_idx].assign_next(x)
            cur_idx = _right(cur_idx) if assign > 0 else _left(cur_idx)
        return self.nodes[cur_idx]

    def assign_all(self, X: np.ndarray) -> np.ndarray:
        """Assign all points

        This assigns units to treatment in the offline setting in which all covariate
        profiles are available prior to assignment. The algorithm assigns as if units
        were still only observed in a stream.

        Args:
            X: array of size n × d of covariate profiles

        Returns:
            Array with one treatment assignment per row of ``X``, in row order.
        """
        return np.array([self.assign_next(X[i, :]) for i in range(X.shape[0])])

    @property
    def definition(self):
        """Constructor arguments of this design, with ``q`` given as the
        normalized per-group probabilities."""
        return {
            "N": self.N,
            "D": self.D,
            "delta": self.delta,
            "q": self.qs,
            "intercept": self.intercept,
            "phi": self.phi,
        }

    @property
    def state(self):
        """Mapping from node index to the ``state`` of the ``BWD`` at that node."""
        return {
            idx: node.state
            for idx, node in enumerate(self.nodes)
            if isinstance(node, BWD)
        }

    def update_state(self, **node_state_dict):
        """Push saved state into the tree's ``BWD`` nodes.

        Args:
            **node_state_dict: per-node state keyed by node index; keys are
                converted with ``int`` and each value is forwarded as keyword
                arguments to that node's ``update_state``.
        """
        for node, state in node_state_dict.items():
            self.nodes[int(node)].update_state(**state)

    def reset(self):
        """Reset every ``BWD`` node in the tree.

        Leaves hold plain treatment indices rather than ``BWD`` instances, so
        they are skipped instead of triggering an ``AttributeError``.
        """
        for node in self.nodes:
            if isinstance(node, BWD):
                node.reset()