Skip to content

Commit 715adbe

Browse files
committed
Provide helpful error on missing inference methods
1 parent f45f454 commit 715adbe

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

econml/bootstrap.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,10 @@ def call(lower=5, upper=95):
154154
return call_with_bounds(can_call, lower, upper)
155155
return call
156156

157+
def get_inference():
158+
raise NotImplementedError("The {0} method is not yet supported by bootstrap inference; "
159+
"consider using a different inference method if available.".format(name))
160+
157161
caught = None
158162
if self._compute_means and self._prefer_wrapped:
159163
try:
@@ -162,13 +166,20 @@ def call(lower=5, upper=95):
162166
caught = err
163167
if name.endswith("_interval"):
164168
return get_interval()
169+
elif name.endswith("_inference"):
170+
return get_inference()
165171
else:
166172
# try to get interval first if appropriate, since we don't prefer a wrapped method with this name
167173
if name.endswith("_interval"):
168174
try:
169175
return get_interval()
170176
except AttributeError as err:
171177
caught = err
178+
if name.endswith("_inference"):
179+
try:
180+
return get_inference()
181+
except AttributeError as err:
182+
caught = err
172183
if self._compute_means:
173184
return get_mean()
174185

econml/tests/test_bootstrap.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,3 +294,15 @@ def test_stratify_orthoiv(self):
294294
inference = BootstrapInference(n_bootstrap_samples=20)
295295
est.fit(Y, T, Z, X=X, inference=inference)
296296
est.const_marginal_effect_interval(X)
297+
298+
def test_inference_throws_helpful_error(self):
299+
"""Test that we see that inference methods are not yet implemented"""
300+
T = np.random.normal(size=(1000, 1))
301+
Y = T + np.random.normal(size=(1000, 1))
302+
303+
opts = BootstrapInference(5, 2)
304+
305+
est = LinearDMLCateEstimator().fit(Y, T, inference=opts)
306+
307+
with self.assertRaises(NotImplementedError):
308+
eff = est.const_marginal_effect_inference()

0 commit comments

Comments
 (0)