Coverage for src/bwd/online.py: 41%

29 statements  

« prev     ^ index     » next       coverage.py v7.11.3, created at 2025-11-12 11:56 +0000

1"""Online balancer wrapper with automatic sample size expansion.""" 

2 

3import numpy as np 

4 

5from .exceptions import SampleSizeExpendedError 

6 

7 

class Online:
    """Online balancer wrapper with automatic sample size expansion.

    This wrapper lets a balancer operate in an online setting where the total
    sample size is not known in advance. Whenever the wrapped balancer reports
    that its sample size has been used up (by raising
    ``SampleSizeExpendedError``), the wrapper rebuilds it with twice the sample
    size and carries the current state over to the new instance.
    """

    def __init__(self, cls: type, **kwargs):
        """
        Parameters
        ----------
        cls : class
            The balancer class to wrap (e.g., BWD, BWDRandom, MultiBWD)
        **kwargs : dict
            Keyword arguments to pass to the balancer class constructor.
            If N is not provided, it defaults to 1.
        """
        kwargs.setdefault("N", 1)
        self.cls = cls
        self.balancer = cls(**kwargs)

    def _expand(self) -> None:
        """Rebuild the wrapped balancer with double the sample size, keeping its state."""
        # Copy the definition. A balancer may expose a live reference to its own
        # configuration, and writing into that would corrupt the old instance.
        new_definition = dict(self.balancer.definition)
        # max(..., 1) guarantees the sample size actually grows, even for N == 0.
        new_definition["N"] = max(new_definition["N"] * 2, 1)
        old_state = self.balancer.state
        self.balancer = self.cls(**new_definition)
        self.balancer.update_state(**old_state)

    def assign_next(self, x: np.ndarray) -> np.ndarray:
        """Assign treatment to the next point with automatic expansion

        If the wrapped balancer raises ``SampleSizeExpendedError``, the sample
        size is doubled (keeping the current state) and the assignment is
        retried once on the expanded balancer.

        Parameters
        ----------
        x : np.ndarray
            Covariate profile of unit to assign treatment

        Returns
        -------
        np.ndarray
            Treatment assignment
        """
        try:
            return self.balancer.assign_next(x)
        except SampleSizeExpendedError:
            self._expand()
        # Retry outside the except clause so a failure here is not reported as
        # "during handling of the above exception".
        return self.balancer.assign_next(x)

    def assign_all(self, X: np.ndarray) -> np.ndarray:
        """Assign all points

        This assigns units to treatment in the offline setting in which all covariate
        profiles are available prior to assignment. The algorithm assigns as if units
        were still only observed in a stream, one row at a time, in order.

        Parameters
        ----------
        X : np.ndarray
            Array of size n × d of covariate profiles

        Returns
        -------
        np.ndarray
            Array of treatment assignments
        """
        rows = np.asarray(X)
        return np.array([self.assign_next(row) for row in rows])

    @property
    def definition(self) -> dict:
        """Get the definition parameters of the wrapped balancer

        Returns
        -------
        dict
            Dictionary containing the balancer class (under ``"cls"``) and all
            definition parameters of the wrapped balancer
        """
        return {"cls": self.cls, **self.balancer.definition}

    @property
    def state(self) -> dict:
        """Get the current state of the wrapped balancer

        Returns
        -------
        dict
            Dictionary containing the current state
        """
        return self.balancer.state

    def update_state(self, **kwargs) -> None:
        """Update the state of the wrapped balancer

        Parameters
        ----------
        **kwargs : dict
            State parameters to update, forwarded to the balancer's
            ``update_state``
        """
        self.balancer.update_state(**kwargs)

    def reset(self) -> None:
        """Reset the wrapped balancer to its initial state"""
        self.balancer.reset()