Coverage for src/bwd/online.py: 39%

31 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2024-08-19 16:45 +0000

1import numpy as np 

2 

3from .exceptions import SampleSizeExpendedError 

4 

5 

class Online(object):
    """Wrap a balancer so it can keep assigning past its planned sample size.

    The wrapped balancer is built as ``cls(**kwargs)``. When its
    ``assign_next`` raises ``SampleSizeExpendedError``, a replacement is
    built from the old balancer's ``definition`` with ``N`` doubled. The old
    balancer's ``state`` is loaded into it, and the assignment is retried.
    """

    def __init__(self, cls, **kwargs):
        """Build the wrapped balancer.

        Args:
            cls: balancer class. It is called as ``cls(**kwargs)`` here and as
                ``cls(**definition)`` each time the sample size is grown.
            **kwargs: constructor arguments for ``cls``; ``N`` defaults to 1
                when not supplied.
        """
        kwargs.setdefault("N", 1)
        self.cls = cls
        self.balancer = cls(**kwargs)

    def _grow(self):
        """Swap in a balancer with twice the sample size, carrying over its state."""
        # Copy so we never mutate a dict the old balancer may still own.
        definition = dict(self.balancer.definition)
        state = self.balancer.state
        # max(..., 1): doubling N=0 would stay 0 and the retry would fail again.
        definition["N"] = max(2 * definition["N"], 1)
        self.balancer = self.cls(**definition)
        self.balancer.update_state(**state)

    def assign_next(self, x: np.ndarray) -> np.ndarray:
        """Assign treatment to a single point, growing the balancer if needed.

        Args:
            x: covariate profile of the next unit

        Returns:
            whatever the wrapped balancer's ``assign_next`` returns for ``x``
        """
        try:
            return self.balancer.assign_next(x)
        except SampleSizeExpendedError:
            self._grow()
        # Retry outside the handler so a failure here isn't chained to the
        # (already handled) SampleSizeExpendedError.
        return self.balancer.assign_next(x)

    def assign_all(self, X: np.ndarray) -> np.ndarray:
        """Assign all points

        This assigns units to treatment in the offline setting in which all covariate
        profiles are available prior to assignment. The algorithm assigns as if units
        were still only observed in a stream.

        Args:
            X: array of size n × d of covariate profiles

        Returns:
            array built from the ``assign_next`` result for each row of ``X``,
            in row order
        """
        # Each row is passed through unchanged, exactly as a streamed point
        # would be. Any intercept handling belongs to the wrapped balancer;
        # adding a column here would make batch and streaming assignment differ.
        return np.array([self.assign_next(row) for row in np.asarray(X)])

    @property
    def definition(self):
        """The wrapped balancer's definition plus the balancer class under ``"cls"``."""
        return {"cls": self.cls, **self.balancer.definition}

    @property
    def state(self):
        """The wrapped balancer's current state."""
        return self.balancer.state

    def update_state(self, **kwargs):
        """Forward state updates to the wrapped balancer."""
        self.balancer.update_state(**kwargs)

    def reset(self):
        """Reset the wrapped balancer (its current ``N`` is kept)."""
        self.balancer.reset()