Skip to content

Commit aeb6948

Browse files
committed
Enable stratification in bootstrap
1 parent 0f5ddfe commit aeb6948

File tree

4 files changed: +59 additions, -11 deletions

econml/bootstrap.py

Lines changed: 39 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -44,32 +44,62 @@ class BootstrapEstimator:
4444
In case a method ending in '_interval' exists on the wrapped object, whether
4545
that should be preferred (meaning this wrapper will compute the mean of it).
4646
This option only affects behavior if `compute_means` is set to ``True``.
47+
48+
stratify_treatment: bool, default False
49+
Whether to stratify by treatment when calling fit; this will ensure that each stratum of treatment
50+
is subsampled independently, so that each resample will have the same number of entries with each
51+
treatment as the original sample did.
4752
"""
4853

49-
def __init__(self, wrapped, n_bootstrap_samples=1000, n_jobs=None, compute_means=True, prefer_wrapped=False):
54+
def __init__(self, wrapped, n_bootstrap_samples=1000, n_jobs=None,
55+
compute_means=True, prefer_wrapped=False, stratify_treatment=False):
5056
self._instances = [clone(wrapped, safe=False) for _ in range(n_bootstrap_samples)]
5157
self._n_bootstrap_samples = n_bootstrap_samples
5258
self._n_jobs = n_jobs
5359
self._compute_means = compute_means
5460
self._prefer_wrapped = prefer_wrapped
61+
self._stratify_treatment = stratify_treatment
5562

5663
# TODO: Add a __dir__ implementation?
5764

65+
def _stratified_indices(self, Y, T, *args, **kwargs):
66+
assert 1 <= np.ndim(T) <= 2
67+
unique = np.unique(T, axis=0)
68+
indices = []
69+
for el in unique:
70+
ind, = np.where(np.all(T == el, axis=1) if np.ndim(T) == 2 else T == el)
71+
indices.append(ind)
72+
return indices
73+
5874
def fit(self, *args, **named_args):
5975
"""
6076
Fit the model.
6177
6278
The full signature of this method is the same as that of the wrapped object's `fit` method.
6379
"""
64-
n_samples = np.shape(args[0] if args else named_args[(*named_args,)[0]])[0]
65-
indices = np.random.choice(n_samples, size=(self._n_bootstrap_samples, n_samples), replace=True)
80+
81+
if self._stratify_treatment:
82+
index_chunks = self._stratified_indices(*args, **named_args)
83+
else:
84+
n_samples = np.shape(args[0] if args else named_args[(*named_args,)[0]])[0]
85+
index_chunks = [np.arange(n_samples)] # one chunk with all indices
86+
87+
indices = []
88+
for chunk in index_chunks:
89+
n_samples = len(chunk)
90+
indices.append(chunk[np.random.choice(n_samples,
91+
size=(self._n_bootstrap_samples, n_samples),
92+
replace=True)])
93+
94+
indices = np.hstack(indices)
6695

6796
def fit(x, *args, **kwargs):
6897
x.fit(*args, **kwargs)
6998
return x # Explicitly return x in case fit fails to return its target
7099

71100
def convertArg(arg, inds):
72-
return arg[inds] if arg is not None else None
101+
return np.asarray(arg)[inds] if arg is not None else None
102+
73103
self._instances = Parallel(n_jobs=self._n_jobs, prefer='threads', verbose=3)(
74104
delayed(fit)(obj,
75105
*[convertArg(arg, inds) for arg in args],
@@ -84,6 +114,11 @@ def __getattr__(self, name):
84114
85115
Additionally, the suffix "_interval" is supported for getting an interval instead of a point estimate.
86116
"""
117+
118+
# don't proxy special methods
119+
if name.startswith('__'):
120+
raise AttributeError(name)
121+
87122
def proxy(make_call, name, summary):
88123
def summarize_with(f):
89124
return summary(np.array(Parallel(n_jobs=self._n_jobs, prefer='threads', verbose=3)(

econml/inference.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ def __init__(self, n_bootstrap_samples=100, n_jobs=-1):
5252
self._n_jobs = n_jobs
5353

5454
def fit(self, estimator, *args, **kwargs):
55-
est = BootstrapEstimator(estimator, self._n_bootstrap_samples, self._n_jobs, compute_means=False)
55+
discrete_treatment = estimator._discrete_treatment if hasattr(estimator, '_discrete_treatment') else False
56+
est = BootstrapEstimator(estimator, self._n_bootstrap_samples, self._n_jobs, compute_means=False,
57+
stratify_treatment=discrete_treatment)
5658
est.fit(*args, **kwargs)
5759
self._est = est
5860

econml/tests/test_bootstrap.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from econml.inference import BootstrapInference
66
from econml.dml import LinearDMLCateEstimator
77
from econml.two_stage_least_squares import NonparametricTwoStageLeastSquares
8-
from sklearn.linear_model import LinearRegression
8+
from sklearn.linear_model import LinearRegression, LogisticRegression
99
from sklearn.preprocessing import PolynomialFeatures
1010
import numpy as np
1111
import unittest
@@ -265,3 +265,18 @@ def test_internal_options(self):
265265

266266
# TODO: test that the estimated effect is usually within the bounds
267267
# and that the true effect is also usually within the bounds
268+
269+
def test_stratify(self):
270+
"""Test that we can properly stratify by treatment"""
271+
T = [1, 0, 1, 2, 0, 2]
272+
Y = [1, 2, 3, 4, 5, 6]
273+
X = np.array([1, 1, 2, 2, 1, 2]).reshape(-1, 1)
274+
est = LinearDMLCateEstimator(model_y=LinearRegression(), model_t=LogisticRegression(), discrete_treatment=True)
275+
est.fit(Y, T, inference='bootstrap')
276+
est.const_marginal_effect_interval()
277+
278+
est.fit(Y, T, X=X, inference='bootstrap')
279+
est.const_marginal_effect_interval(X)
280+
281+
est.fit(Y, np.asarray(T).reshape(-1, 1), inference='bootstrap') # test stratifying 2D treatment
282+
est.const_marginal_effect_interval()

econml/tests/test_dml.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,7 @@ def make_random(is_discrete, d):
9797

9898
model_t = LogisticRegression() if is_discrete else Lasso()
9999

100-
# TODO: add stratification to bootstrap so that we can use it
101-
# even with discrete treatments
102-
all_infs = [None, 'statsmodels']
103-
if not is_discrete:
104-
all_infs.append(BootstrapInference(1))
100+
all_infs = [None, 'statsmodels', BootstrapInference(1)]
105101

106102
for est, multi, infs in\
107103
[(LinearDMLCateEstimator(model_y=Lasso(),

0 commit comments

Comments
 (0)