@@ -44,32 +44,62 @@ class BootstrapEstimator:
4444 In case a method ending in '_interval' exists on the wrapped object, whether
4545 that should be preferred (meaning this wrapper will compute the mean of it).
4646 This option only affects behavior if `compute_means` is set to ``True``.
47+
48+ stratify_treatment: bool, default False
49+ Whether to stratify by treatment when calling fit; this will ensure that each stratum of treatment
50+ is subsampled independently, so that each resample will have the same number of entries with each
51+ treatment as the original sample did.
4752 """
4853
def __init__(self, wrapped, n_bootstrap_samples=1000, n_jobs=None,
             compute_means=True, prefer_wrapped=False, stratify_treatment=False):
    """
    Set up the bootstrap wrapper.

    One clone of `wrapped` (made with ``safe=False``) is created per bootstrap
    sample, and the remaining arguments are stored unchanged on the instance
    for later use.
    """
    # Build the per-resample copies first, exactly one for each bootstrap sample.
    replicas = []
    for _ in range(n_bootstrap_samples):
        replicas.append(clone(wrapped, safe=False))
    self._instances = replicas

    # Plain configuration values, kept as private attributes.
    self._n_bootstrap_samples = n_bootstrap_samples
    self._n_jobs = n_jobs
    self._compute_means = compute_means
    self._prefer_wrapped = prefer_wrapped
    self._stratify_treatment = stratify_treatment
5562
5663 # TODO: Add a __dir__ implementation?
5764
def _stratified_indices(self, Y, T, *args, **kwargs):
    """
    Group sample positions by treatment value.

    The signature mirrors the leading arguments passed to `fit` so that the
    treatment can be picked out either positionally or by keyword; `Y` and
    any further arguments are accepted but not used.

    Parameters
    ----------
    Y : unused
    T : array-like with 1 or 2 dimensions
        Treatments. With 2 dimensions each row is one sample's treatment and
        two samples share a stratum only if their whole rows are equal.

    Returns
    -------
    list of 1-d integer arrays
        One array per distinct treatment (in the sorted order produced by
        ``np.unique``), each holding the ascending positions of the samples
        that received that treatment.

    Raises
    ------
    ValueError
        If `T` does not have 1 or 2 dimensions.
    """
    # Convert up front: comparing a plain list against a value yields a scalar
    # False rather than an elementwise mask, which would silently produce
    # empty strata.
    T = np.asarray(T)
    if not 1 <= T.ndim <= 2:
        raise ValueError(
            "T must have 1 or 2 dimensions, but has {}".format(T.ndim))

    # `inverse` labels every sample with the index of its unique treatment,
    # so membership never needs a per-value comparison against T itself.
    unique, inverse = np.unique(T, axis=0, return_inverse=True)
    # Flatten defensively; the inverse's shape when `axis` is given has
    # differed between numpy releases.
    inverse = np.ravel(inverse)

    return [np.flatnonzero(inverse == label) for label in range(len(unique))]
73+
5874 def fit (self , * args , ** named_args ):
5975 """
6076 Fit the model.
6177
6278 The full signature of this method is the same as that of the wrapped object's `fit` method.
6379 """
64- n_samples = np .shape (args [0 ] if args else named_args [(* named_args ,)[0 ]])[0 ]
65- indices = np .random .choice (n_samples , size = (self ._n_bootstrap_samples , n_samples ), replace = True )
80+
81+ if self ._stratify_treatment :
82+ index_chunks = self ._stratified_indices (* args , ** named_args )
83+ else :
84+ n_samples = np .shape (args [0 ] if args else named_args [(* named_args ,)[0 ]])[0 ]
85+ index_chunks = [np .arange (n_samples )] # one chunk with all indices
86+
87+ indices = []
88+ for chunk in index_chunks :
89+ n_samples = len (chunk )
90+ indices .append (chunk [np .random .choice (n_samples ,
91+ size = (self ._n_bootstrap_samples , n_samples ),
92+ replace = True )])
93+
94+ indices = np .hstack (indices )
6695
6796 def fit (x , * args , ** kwargs ):
6897 x .fit (* args , ** kwargs )
6998 return x # Explicitly return x in case fit fails to return its target
7099
71100 def convertArg (arg , inds ):
72- return arg [inds ] if arg is not None else None
101+ return np .asarray (arg )[inds ] if arg is not None else None
102+
73103 self ._instances = Parallel (n_jobs = self ._n_jobs , prefer = 'threads' , verbose = 3 )(
74104 delayed (fit )(obj ,
75105 * [convertArg (arg , inds ) for arg in args ],
@@ -84,6 +114,11 @@ def __getattr__(self, name):
84114
85115 Additionally, the suffix "_interval" is supported for getting an interval instead of a point estimate.
86116 """
117+
118+ # don't proxy special methods
119+ if name .startswith ('__' ):
120+ raise AttributeError (name )
121+
87122 def proxy (make_call , name , summary ):
88123 def summarize_with (f ):
89124 return summary (np .array (Parallel (n_jobs = self ._n_jobs , prefer = 'threads' , verbose = 3 )(
0 commit comments