Coverage for src/bwd/multi_bwd.py: 98%

81 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2024-08-19 16:45 +0000

1from typing import Union 

2from collections.abc import Iterable 

3from .bwd import BWD 

4 

5import numpy as np 

6 

7 

8def _left(i): 

9 return 2 * i + 1 

10 

11 

12def _right(i): 

13 return 2 * (i + 1) 

14 

15 

16def _parent(i): 

17 return int(np.floor((i - 1) / 2)) 

18 

19 

class MultiBWD(object):
    """**The Multi-treatment Balancing Walk Design with Restarts**

    This method implements an extension to the Balancing Walk Design to balance
    across multiple treatments. It accomplishes this by constructing a binary tree.
    At each node in the binary tree, it balances between the treatment groups on the
    left and the right. Thus it ensures balance between any pair of treatment groups.

    The tree is stored as two flat, heap-ordered lists (``nodes`` and
    ``weights``): the children of node ``i`` live at ``_left(i)`` and
    ``_right(i)``. Leaves hold integer treatment labels; the nodes above them
    hold ``BWD`` instances that decide whether to descend left or right.
    """

    def __init__(
        self,
        N: int,
        D: int,
        delta: float = 0.05,
        q: Union[float, Iterable] = 0.5,
        intercept: bool = True,
        phi: float = 1.0,
    ):
        """
        Args:
            N: total number of points
            D: dimension of the data
            delta: probability of failure
            q: Target marginal probability of treatment. Either a single float
                (two treatment groups) or an iterable of non-normalized
                per-group weights (one group per element).
            intercept: Whether an intercept term be added to covariate profiles
            phi: Robustness parameter. A value of 1 focuses entirely on balance, while a value
                approaching zero does pure randomization.

        Raises:
            ValueError: if ``q`` is an empty iterable.
            TypeError: if ``q`` is neither a float nor an iterable.
        """
        self.N = N
        self.D = D
        self.delta = delta
        self.intercept = intercept
        self.phi = phi

        if isinstance(q, float):
            # Fold q so that the value stored for class 1 is at most 0.5.
            q = q if q < 0.5 else 1 - q
            self.qs = [1 - q, q]
            self.classes = [0, 1]
        elif isinstance(q, Iterable):
            # Materialize once: q may be a one-shot iterator, and the total
            # must be computed a single time, not once per element.
            raw_weights = list(q)
            if not raw_weights:
                raise ValueError("q must contain at least one treatment weight")
            total = sum(raw_weights)
            self.qs = [pr / total for pr in raw_weights]
            self.classes = list(range(len(self.qs)))
        else:
            raise TypeError(
                f"q must be a float or an iterable of weights, got {type(q).__name__}"
            )
        num_groups = len(self.qs)
        self.K = num_groups - 1

        # Size a complete binary tree with at least one leaf per group.
        num_levels = int(np.ceil(np.log2(num_groups)))
        num_leaves = 2 ** num_levels
        extra_leaves = num_leaves - num_groups
        num_nodes = 2 * num_leaves - 1
        self.nodes = [None] * num_nodes
        self.weights = [None] * num_nodes

        # Hand the surplus leaves to the earliest treatments so that every
        # leaf of the complete tree is occupied by some treatment label.
        trt_by_leaf = []
        num_leaves_by_trt = []
        for trt in range(num_groups):
            if len(trt_by_leaf) % 2 == 0 and extra_leaves > 0:
                num_trt = 2 * ((extra_leaves - 1) // 2 + 1)
                extra_leaves -= num_trt - 1
            else:
                num_trt = 1
            trt_by_leaf += [trt] * num_trt
            num_leaves_by_trt.append(num_trt)

        # Leaves occupy the last ``num_leaves`` slots of the flat arrays. A
        # treatment's weight (1 / q) is split evenly across its leaves.
        first_leaf = num_nodes - num_leaves
        for leaf, trt in enumerate(trt_by_leaf):
            node = first_leaf + leaf
            self.nodes[node] = trt
            self.weights[node] = 1 / self.qs[trt] / num_leaves_by_trt[trt]

        # Walk from the last node up towards (but excluding) the root, filling
        # in each node's parent from the parent's two children.
        for cur_node in range(num_nodes - 1, 0, -1):
            parent = _parent(cur_node)
            left = _left(parent)
            right = _right(parent)
            if self.nodes[left] == self.nodes[right]:
                self.nodes[parent] = self.nodes[left]
                self.weights[parent] = self.weights[left] + self.weights[right]
            # NOTE(review): this is a separate ``if`` (not ``elif``) in the
            # original, so a parent collapsed to a shared label just above is
            # replaced by a BWD node whenever both children are set. Block
            # structure was reconstructed from a flattened listing -- confirm
            # against the canonical file before changing it.
            if self.nodes[left] is not None and self.nodes[right] is not None:
                left_weight = self.weights[left]
                right_weight = self.weights[right]
                pr_right = right_weight / (left_weight + right_weight)
                self.nodes[parent] = BWD(
                    N=N, D=D, intercept=intercept, delta=delta, q=pr_right, phi=phi
                )
                self.weights[parent] = left_weight + right_weight

    def assign_next(self, x: np.ndarray) -> int:
        """Assign treatment to the next point

        Starting at the root, each ``BWD`` node is asked to assign ``x``; a
        positive assignment descends to the right child, anything else to the
        left child. The walk stops at the first node that is not a ``BWD``.

        Args:
            x: covariate profile of unit to assign treatment

        Returns:
            The treatment label stored at the node where the walk stops.
        """
        cur_idx = 0
        while isinstance(self.nodes[cur_idx], BWD):
            assign = self.nodes[cur_idx].assign_next(x)
            cur_idx = _right(cur_idx) if assign > 0 else _left(cur_idx)
        return self.nodes[cur_idx]

    def assign_all(self, X: np.ndarray) -> np.ndarray:
        """Assign all points

        This assigns units to treatment in the offline setting in which all covariate
        profiles are available prior to assignment. The algorithm assigns as if units
        were still only observed in a stream.

        Args:
            X: array of size n × d of covariate profiles

        Returns:
            Array of length n with the treatment assigned to each row, in order.
        """
        return np.array([self.assign_next(X[i, :]) for i in range(X.shape[0])])

    @property
    def definition(self):
        """Constructor parameters of this design (``q`` is the normalized list)."""
        return {
            "N": self.N,
            "D": self.D,
            "delta": self.delta,
            "q": self.qs,
            "intercept": self.intercept,
            "phi": self.phi,
        }

    @property
    def state(self):
        """Mapping from node index to the ``state`` of each ``BWD`` node."""
        return {
            idx: node.state
            for idx, node in enumerate(self.nodes)
            if isinstance(node, BWD)
        }

    def update_state(self, **node_state_dict):
        """Forward each per-node state to the node at that index.

        Keys are converted with ``int`` before indexing, so string keys (as
        required for keyword arguments) are accepted.
        """
        for node, state in node_state_dict.items():
            self.nodes[int(node)].update_state(**state)

    def reset(self):
        """Reset every ``BWD`` node in the tree.

        Leaf entries are plain treatment labels with no ``reset`` method, so
        only ``BWD`` nodes are touched.
        """
        for node in self.nodes:
            if isinstance(node, BWD):
                node.reset()

155 node.reset()