Skip to content

Commit 223096b

Browse files
committed
Enable expanded inference interface for bootstrap
1 parent 01e52ff commit 223096b

File tree

4 files changed

+80
-37
lines changed

4 files changed

+80
-37
lines changed

econml/bootstrap.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,31 @@ def call(lower=5, upper=95):
176176
return call
177177

178178
def get_inference():
179-
raise NotImplementedError("The {0} method is not yet supported by bootstrap inference; "
180-
"consider using a different inference method if available.".format(name))
179+
# can't import from econml.inference at top level without creating mutual dependencies
180+
from .inference import InferenceResults
181+
# TODO: consider treating percentile bootstrap differently since we can work directly with
182+
# the empirical distribution
183+
prefix = name[: - len("_inference")]
184+
if prefix in ['const_marginal_effect', 'effect']:
185+
inf_type = 'effect'
186+
elif prefix == 'coef_':
187+
inf_type = 'coefficient'
188+
elif prefix == 'intercept_':
189+
inf_type = 'intercept'
190+
else:
191+
raise AttributeError("Unsupported inference: " + name)
192+
193+
def get_inference():
194+
pred = getattr(self._wrapped, prefix)
195+
stderr = getattr(self, prefix + '_std')
196+
d_t = self._wrapped._d_t[0] if self._wrapped._d_t else 1
197+
d_t = 1 if prefix == 'effect' else d_t
198+
d_y = self._wrapped._d_y[0] if self._wrapped._d_y else 1
199+
return InferenceResults(d_t=d_t, d_y=d_y, pred=pred,
200+
pred_stderr=stderr, inf_type=inf_type,
201+
pred_dist=None, fname_transformer=None)
202+
203+
return get_inference
181204

182205
caught = None
183206
m = None
@@ -202,11 +225,6 @@ def get_inference():
202225
return m()
203226
except AttributeError as err:
204227
caught = err
205-
if name.endswith("_inference"):
206-
try:
207-
return get_inference()
208-
except AttributeError as err:
209-
caught = err
210228
if self._compute_means:
211229
return get_mean()
212230

econml/inference.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,50 @@ def fit(self, estimator, *args, **kwargs):
6262
bootstrap_type=self._bootstrap_type)
6363
est.fit(*args, **kwargs)
6464
self._est = est
65+
self._d_t = estimator._d_t
66+
self._d_y = estimator._d_y
67+
self.d_t = self._d_t[0] if self._d_t else 1
68+
self.d_y = self._d_y[0] if self._d_y else 1
6569

6670
def __getattr__(self, name):
6771
if name.startswith('__'):
6872
raise AttributeError()
6973

7074
m = getattr(self._est, name)
75+
if name.endswith('_interval'): # convert alpha to lower/upper
76+
def wrapped(*args, alpha=0.1, **kwargs):
77+
return m(*args, lower=100 * alpha / 2, upper=100 * (1 - alpha / 2), **kwargs)
78+
return wrapped
79+
else:
80+
return m
7181

72-
def wrapped(*args, alpha=0.1, **kwargs):
73-
return m(*args, lower=100 * alpha / 2, upper=100 * (1 - alpha / 2), **kwargs)
74-
return wrapped
82+
def summary(self, alpha=0.1, value=0, decimals=3, feat_name=None):
83+
smry = Summary()
84+
try:
85+
coef_table = self.coef__inference().summary_frame(alpha=alpha,
86+
value=value, decimals=decimals, feat_name=feat_name)
87+
coef_array = coef_table.values
88+
coef_headers = [i + '\n' +
89+
j for (i, j) in coef_table.columns] if self.d_t > 1 else coef_table.columns.tolist()
90+
coef_stubs = [i + ' | ' + j for (i, j) in coef_table.index] if self.d_y > 1 else coef_table.index.tolist()
91+
coef_title = 'Coefficient Results'
92+
smry.add_table(coef_array, coef_headers, coef_stubs, coef_title)
93+
except Exception as e:
94+
print("Coefficient Results: ", str(e))
95+
try:
96+
intercept_table = self.intercept__inference().summary_frame(alpha=alpha,
97+
value=value, decimals=decimals, feat_name=None)
98+
intercept_array = intercept_table.values
99+
intercept_headers = [i + '\n' + j for (i, j)
100+
in intercept_table.columns] if self.d_t > 1 else intercept_table.columns.tolist()
101+
intercept_stubs = [i + ' | ' + j for (i, j)
102+
in intercept_table.index] if self.d_y > 1 else intercept_table.index.tolist()
103+
intercept_title = 'Intercept Results'
104+
smry.add_table(intercept_array, intercept_headers, intercept_stubs, intercept_title)
105+
except Exception as e:
106+
print("Intercept Results: ", str(e))
107+
if len(smry.tables) > 0:
108+
return smry
75109

76110

77111
class GenericModelFinalInference(Inference):

econml/tests/test_bootstrap.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -295,15 +295,3 @@ def test_stratify_orthoiv(self):
295295
inference = BootstrapInference(n_bootstrap_samples=20)
296296
est.fit(Y, T, Z, X=X, inference=inference)
297297
est.const_marginal_effect_interval(X)
298-
299-
def test_inference_throws_helpful_error(self):
300-
"""Test that we see that inference methods are not yet implemented"""
301-
T = np.random.normal(size=(1000, 1))
302-
Y = T + np.random.normal(size=(1000, 1))
303-
304-
opts = BootstrapInference(5, 2)
305-
306-
est = LinearDMLCateEstimator().fit(Y, T, inference=opts)
307-
308-
with self.assertRaises(NotImplementedError):
309-
eff = est.const_marginal_effect_inference()

econml/tests/test_inference.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from sklearn.base import clone
77
from sklearn.preprocessing import PolynomialFeatures
88
from econml.dml import LinearDMLCateEstimator
9+
from econml.inference import BootstrapInference
910

1011

1112
class TestInference(unittest.TestCase):
@@ -26,21 +27,23 @@ def setUpClass(cls):
2627
def test_inference_results(self):
2728
"""Tests the inference results summary."""
2829
# Test inference results when `cate_feature_names` doesn not exist
29-
cate_est = LinearDMLCateEstimator(
30-
featurizer=PolynomialFeatures(degree=1,
31-
include_bias=False)
32-
)
33-
wrapped_est = self._NoFeatNamesEst(cate_est)
34-
wrapped_est.fit(
35-
TestInference.Y,
36-
TestInference.T,
37-
TestInference.X,
38-
TestInference.W,
39-
inference='statsmodels'
40-
)
41-
summary_results = wrapped_est.summary()
42-
coef_rows = np.asarray(summary_results.tables[0].data)[1:, 0]
43-
np.testing.assert_array_equal(coef_rows, ['X{}'.format(i) for i in range(TestInference.d_x)])
30+
31+
for inference in [BootstrapInference(n_bootstrap_samples=5), 'statsmodels']:
32+
cate_est = LinearDMLCateEstimator(
33+
featurizer=PolynomialFeatures(degree=1,
34+
include_bias=False)
35+
)
36+
wrapped_est = self._NoFeatNamesEst(cate_est)
37+
wrapped_est.fit(
38+
TestInference.Y,
39+
TestInference.T,
40+
TestInference.X,
41+
TestInference.W,
42+
inference=inference
43+
)
44+
summary_results = wrapped_est.summary()
45+
coef_rows = np.asarray(summary_results.tables[0].data)[1:, 0]
46+
np.testing.assert_array_equal(coef_rows, ['X{}'.format(i) for i in range(TestInference.d_x)])
4447

4548
class _NoFeatNamesEst:
4649
def __init__(self, cate_est):

0 commit comments

Comments
 (0)