Skip to content

Commit 50800f1

Browse files
committed
Implement more feedback
1 parent 63daf39 commit 50800f1

File tree

2 files changed

+37
-65
lines changed

2 files changed

+37
-65
lines changed

econml/bootstrap.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,15 @@ class BootstrapEstimator:
4848
that should be preferred (meaning this wrapper will compute the mean of it).
4949
This option only affects behavior if `compute_means` is set to ``True``.
5050
51-
bootstrap_type: 'percentile', 'pivot', or 'normal', default 'percentile'
51+
bootstrap_type: 'percentile', 'pivot', or 'normal', default 'pivot'
5252
Bootstrap method used to compute results. 'percentile' will result in using the empiracal CDF of
5353
the replicated computations of the statistics. 'pivot' will also use the replicates but create a pivot
5454
interval that also relies on the estimate over the entire dataset. 'normal' will instead compute an interval
5555
assuming the replicates are normally distributed.
5656
"""
5757

5858
def __init__(self, wrapped, n_bootstrap_samples=1000, n_jobs=None, compute_means=True, prefer_wrapped=False,
59-
bootstrap_type='percentile'):
59+
bootstrap_type='pivot'):
6060
self._instances = [clone(wrapped, safe=False) for _ in range(n_bootstrap_samples)]
6161
self._n_bootstrap_samples = n_bootstrap_samples
6262
self._n_jobs = n_jobs
@@ -162,8 +162,7 @@ def percentile_bootstrap(arr, _):
162162
def pivot_bootstrap(arr, est):
163163
return 2 * est - np.percentile(arr, upper, axis=0), 2 * est - np.percentile(arr, lower, axis=0)
164164

165-
def normal_bootstrap(arr, _):
166-
est = np.mean(arr, axis=0)
165+
def normal_bootstrap(arr, est):
167166
std = np.std(arr, axis=0)
168167
return est - norm.ppf(upper / 100) * std, est - norm.ppf(lower / 100) * std
169168

econml/inference.py

Lines changed: 34 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,37 @@ def fit(self, estimator, *args, **kwargs):
3232
pass
3333

3434

35-
class BootstrapInference(Inference):
35+
class _SummaryMixin:
36+
def summary(self, alpha=0.1, value=0, decimals=3, feat_name=None):
37+
smry = Summary()
38+
try:
39+
coef_table = self.coef__inference().summary_frame(alpha=alpha,
40+
value=value, decimals=decimals, feat_name=feat_name)
41+
coef_array = coef_table.values
42+
coef_headers = [i + '\n' +
43+
j for (i, j) in coef_table.columns] if self.d_t > 1 else coef_table.columns.tolist()
44+
coef_stubs = [i + ' | ' + j for (i, j) in coef_table.index] if self.d_y > 1 else coef_table.index.tolist()
45+
coef_title = 'Coefficient Results'
46+
smry.add_table(coef_array, coef_headers, coef_stubs, coef_title)
47+
except Exception as e:
48+
print("Coefficient Results: ", str(e))
49+
try:
50+
intercept_table = self.intercept__inference().summary_frame(alpha=alpha,
51+
value=value, decimals=decimals, feat_name=None)
52+
intercept_array = intercept_table.values
53+
intercept_headers = [i + '\n' + j for (i, j)
54+
in intercept_table.columns] if self.d_t > 1 else intercept_table.columns.tolist()
55+
intercept_stubs = [i + ' | ' + j for (i, j)
56+
in intercept_table.index] if self.d_y > 1 else intercept_table.index.tolist()
57+
intercept_title = 'Intercept Results'
58+
smry.add_table(intercept_array, intercept_headers, intercept_stubs, intercept_title)
59+
except Exception as e:
60+
print("Intercept Results: ", str(e))
61+
if len(smry.tables) > 0:
62+
return smry
63+
64+
65+
class BootstrapInference(_SummaryMixin, Inference):
3666
"""
3767
Inference instance to perform bootstrapping.
3868
@@ -46,21 +76,20 @@ class BootstrapInference(Inference):
4676
n_jobs: int, optional (default -1)
4777
The maximum number of concurrently running jobs, as in joblib.Parallel.
4878
49-
bootstrap_type: 'percentile', 'pivot', or 'normal', default 'percentile'
79+
bootstrap_type: 'percentile', 'pivot', or 'normal', default 'pivot'
5080
Bootstrap method used to compute results.
5181
'percentile' will result in using the empiracal CDF of the replicated computations of the statistics.
5282
'pivot' will also use the replicates but create a pivot interval that also relies on the estimate
5383
over the entire dataset.
5484
'normal' will instead compute a pivot interval assuming the replicates are normally distributed.
5585
"""
5686

57-
def __init__(self, n_bootstrap_samples=100, n_jobs=-1, bootstrap_type='percentile'):
87+
def __init__(self, n_bootstrap_samples=100, n_jobs=-1, bootstrap_type='pivot'):
5888
self._n_bootstrap_samples = n_bootstrap_samples
5989
self._n_jobs = n_jobs
6090
self._bootstrap_type = bootstrap_type
6191

6292
def fit(self, estimator, *args, **kwargs):
63-
discrete_treatment = estimator._discrete_treatment if hasattr(estimator, '_discrete_treatment') else False
6493
est = BootstrapEstimator(estimator, self._n_bootstrap_samples, self._n_jobs, compute_means=False,
6594
bootstrap_type=self._bootstrap_type)
6695
est.fit(*args, **kwargs)
@@ -82,34 +111,6 @@ def wrapped(*args, alpha=0.1, **kwargs):
82111
else:
83112
return m
84113

85-
def summary(self, alpha=0.1, value=0, decimals=3, feat_name=None):
86-
smry = Summary()
87-
try:
88-
coef_table = self.coef__inference().summary_frame(alpha=alpha,
89-
value=value, decimals=decimals, feat_name=feat_name)
90-
coef_array = coef_table.values
91-
coef_headers = [i + '\n' +
92-
j for (i, j) in coef_table.columns] if self.d_t > 1 else coef_table.columns.tolist()
93-
coef_stubs = [i + ' | ' + j for (i, j) in coef_table.index] if self.d_y > 1 else coef_table.index.tolist()
94-
coef_title = 'Coefficient Results'
95-
smry.add_table(coef_array, coef_headers, coef_stubs, coef_title)
96-
except Exception as e:
97-
print("Coefficient Results: ", str(e))
98-
try:
99-
intercept_table = self.intercept__inference().summary_frame(alpha=alpha,
100-
value=value, decimals=decimals, feat_name=None)
101-
intercept_array = intercept_table.values
102-
intercept_headers = [i + '\n' + j for (i, j)
103-
in intercept_table.columns] if self.d_t > 1 else intercept_table.columns.tolist()
104-
intercept_stubs = [i + ' | ' + j for (i, j)
105-
in intercept_table.index] if self.d_y > 1 else intercept_table.index.tolist()
106-
intercept_title = 'Intercept Results'
107-
smry.add_table(intercept_array, intercept_headers, intercept_stubs, intercept_title)
108-
except Exception as e:
109-
print("Intercept Results: ", str(e))
110-
if len(smry.tables) > 0:
111-
return smry
112-
113114

114115
class GenericModelFinalInference(Inference):
115116
"""
@@ -215,7 +216,7 @@ def effect_inference(self, X, *, T0, T1):
215216
pred_stderr=e_stderr, inf_type='effect', fname_transformer=None)
216217

217218

218-
class LinearModelFinalInference(GenericModelFinalInference):
219+
class LinearModelFinalInference(_SummaryMixin, GenericModelFinalInference):
219220
"""
220221
Inference based on predict_interval of the model_final model. Assumes that estimator
221222
class has a model_final method and that model is linear. Thus, the predict(cross_product(X, T1 - T0)) gives
@@ -318,34 +319,6 @@ def intercept__inference(self):
318319
return NormalInferenceResults(d_t=self.d_t, d_y=self.d_y, pred=intercept, pred_stderr=intercept_stderr,
319320
inf_type='intercept', fname_transformer=None)
320321

321-
def summary(self, alpha=0.1, value=0, decimals=3, feat_name=None):
322-
smry = Summary()
323-
try:
324-
coef_table = self.coef__inference().summary_frame(alpha=alpha,
325-
value=value, decimals=decimals, feat_name=feat_name)
326-
coef_array = coef_table.values
327-
coef_headers = [i + '\n' +
328-
j for (i, j) in coef_table.columns] if self.d_t > 1 else coef_table.columns.tolist()
329-
coef_stubs = [i + ' | ' + j for (i, j) in coef_table.index] if self.d_y > 1 else coef_table.index.tolist()
330-
coef_title = 'Coefficient Results'
331-
smry.add_table(coef_array, coef_headers, coef_stubs, coef_title)
332-
except Exception as e:
333-
print("Coefficient Results: ", str(e))
334-
try:
335-
intercept_table = self.intercept__inference().summary_frame(alpha=alpha,
336-
value=value, decimals=decimals, feat_name=None)
337-
intercept_array = intercept_table.values
338-
intercept_headers = [i + '\n' + j for (i, j)
339-
in intercept_table.columns] if self.d_t > 1 else intercept_table.columns.tolist()
340-
intercept_stubs = [i + ' | ' + j for (i, j)
341-
in intercept_table.index] if self.d_y > 1 else intercept_table.index.tolist()
342-
intercept_title = 'Intercept Results'
343-
smry.add_table(intercept_array, intercept_headers, intercept_stubs, intercept_title)
344-
except Exception as e:
345-
print("Intercept Results: ", str(e))
346-
if len(smry.tables) > 0:
347-
return smry
348-
349322

350323
class StatsModelsInference(LinearModelFinalInference):
351324
"""Stores statsmodels covariance options.

0 commit comments

Comments
 (0)