Coverage for src/bwd/online.py: 39%
31 statements
« prev ^ index » next coverage.py v6.5.0, created at 2024-08-19 16:45 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2024-08-19 16:45 +0000
1import numpy as np
3from .exceptions import SampleSizeExpendedError
class Online(object):
    """Wrap a fixed-sample-size balancer so it can grow on demand.

    The wrapper builds ``cls(**kwargs)`` (with ``N`` defaulting to 1) and
    forwards assignments to it. When the wrapped balancer raises
    ``SampleSizeExpendedError``, a new balancer is built from the old one's
    ``definition`` with ``N`` doubled, the old ``state`` is loaded into it,
    and the assignment is retried.
    """

    def __init__(self, cls, **kwargs):
        """Create the wrapped balancer.

        Args:
            cls: balancer class; instantiated as ``cls(**kwargs)``
            **kwargs: constructor arguments for ``cls``; ``N`` defaults to 1
                when not supplied
        """
        kwargs.setdefault("N", 1)
        self.cls = cls
        self.balancer = cls(**kwargs)

    def assign_next(self, x: np.ndarray) -> np.ndarray:
        """Assign a single unit, doubling ``N`` if the balancer is exhausted.

        Args:
            x: covariate profile forwarded unchanged to the wrapped balancer

        Returns:
            Whatever the wrapped balancer's ``assign_next`` returns.
        """
        try:
            return self.balancer.assign_next(x)
        except SampleSizeExpendedError:
            # Copy the definition before editing it so that a balancer which
            # hands back its own stored dict is not mutated behind its back.
            grown_def = dict(self.balancer.definition)
            carried_state = self.balancer.state
            grown_def["N"] = grown_def["N"] * 2
            self.balancer = self.cls(**grown_def)
            self.balancer.update_state(**carried_state)
            return self.balancer.assign_next(x)

    def assign_all(self, X: np.ndarray) -> np.ndarray:
        """Assign all points

        This assigns units to treatment in the offline setting in which all covariate
        profiles are available prior to assignment. The algorithm assigns as if units
        were still only observed in a stream.

        A column of ones is appended to ``X`` only when an ``intercept``
        attribute has been set to a truthy value on this wrapper. The wrapper
        itself never sets one, so by default each row reaches the balancer
        exactly as it would through ``assign_next``.
        NOTE(review): whether the wrapped balancer adds its own intercept is
        not visible from this class -- confirm before setting ``intercept``.

        Args:
            X: array of size n × d of covariate profiles

        Returns:
            Array with one ``assign_next`` result per row of ``X``.
        """
        # ``intercept`` is not initialised by ``__init__``; reading it
        # directly raised AttributeError on every call.
        if getattr(self, "intercept", False):
            X = np.hstack((X, np.ones((X.shape[0], 1))))
        return np.array([self.assign_next(X[i, :]) for i in range(X.shape[0])])

    @property
    def definition(self):
        """Wrapped balancer's definition plus the ``cls`` used to build it."""
        return {"cls": self.cls, **self.balancer.definition}

    @property
    def state(self):
        """State of the wrapped balancer."""
        return self.balancer.state

    def update_state(self, **kwargs):
        """Forward ``kwargs`` to the wrapped balancer's ``update_state``."""
        self.balancer.update_state(**kwargs)

    def reset(self):
        """Reset the wrapped balancer."""
        self.balancer.reset()