Coverage for src/bwd/online.py: 41%
29 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-12 11:56 +0000
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-12 11:56 +0000
1"""Online balancer wrapper with automatic sample size expansion."""
3import numpy as np
5from .exceptions import SampleSizeExpendedError
class Online:
    """Online balancer wrapper with automatic sample size expansion.

    This wrapper allows a balancer to operate in an online setting where the total
    sample size is not known in advance. When the wrapped balancer raises
    ``SampleSizeExpendedError``, the wrapper rebuilds it with a doubled sample size
    and restores the current state, then retries the assignment.

    The wrapped class is expected to expose ``assign_next``, ``definition``,
    ``state``, ``update_state`` and ``reset``, and to accept its own
    ``definition`` back as constructor keyword arguments.
    """

    def __init__(self, cls, **kwargs):
        """
        Parameters
        ----------
        cls : class
            The balancer class to wrap (e.g., BWD, BWDRandom, MultiBWD)
        **kwargs : dict
            Keyword arguments to pass to the balancer class constructor.
            If N is not provided (or is None), it defaults to 1.
        """
        # N is only an initial capacity here; it grows on demand in _expand.
        if kwargs.get("N") is None:
            kwargs["N"] = 1
        self.cls = cls
        self.balancer = cls(**kwargs)

    def assign_next(self, x: np.ndarray) -> np.ndarray:
        """Assign treatment to the next point with automatic expansion

        If the wrapped balancer raises ``SampleSizeExpendedError``, it is rebuilt
        with a larger sample size (see ``_expand``), keeping its current state,
        and the assignment is retried once on the new balancer.

        Parameters
        ----------
        x : np.ndarray
            Covariate profile of unit to assign treatment

        Returns
        -------
        np.ndarray
            Treatment assignment
        """
        try:
            return self.balancer.assign_next(x)
        except SampleSizeExpendedError:
            self._expand()
            return self.balancer.assign_next(x)

    def _expand(self):
        """Replace the wrapped balancer with one of (at least) twice the sample size.

        The definition and state are read from the current balancer, ``N`` is
        doubled (and forced to be at least 1), a new balancer is constructed from
        the updated definition, and the saved state is loaded into it.
        """
        # Copy so the old balancer's (possibly internal) dict is never mutated.
        definition = dict(self.balancer.definition)
        state = self.balancer.state
        # Doubling 0 would leave N at 0 and the retry would fail again.
        definition["N"] = max(2 * definition["N"], 1)
        self.balancer = self.cls(**definition)
        self.balancer.update_state(**state)

    def assign_all(self, X: np.ndarray) -> np.ndarray:
        """Assign all points

        This assigns units to treatment in the offline setting in which all covariate
        profiles are available prior to assignment. The algorithm assigns as if units
        were still only observed in a stream, so sample-size expansion can happen
        part-way through.

        Parameters
        ----------
        X : np.ndarray
            Array of size n × d of covariate profiles

        Returns
        -------
        np.ndarray
            Array of treatment assignments, one per row of ``X``
        """
        return np.array([self.assign_next(X[i, :]) for i in range(X.shape[0])])

    @property
    def definition(self):
        """Get the definition parameters of the wrapped balancer

        The result can be passed back as ``Online(**online.definition)``.

        Returns
        -------
        dict
            Dictionary containing the balancer class (under ``"cls"``) and all
            definition parameters of the wrapped balancer
        """
        return {"cls": self.cls, **self.balancer.definition}

    @property
    def state(self):
        """Get the current state of the wrapped balancer

        Returns
        -------
        dict
            Dictionary containing the current state
        """
        return self.balancer.state

    def update_state(self, **kwargs):
        """Update the state of the wrapped balancer

        Parameters
        ----------
        **kwargs : dict
            State parameters to update
        """
        self.balancer.update_state(**kwargs)

    def reset(self):
        """Reset the wrapped balancer to its initial state.

        This resets the current wrapped balancer, which may already have been
        rebuilt with an expanded N; whether N is restored depends on the
        balancer's own ``reset``.
        """
        self.balancer.reset()