Skip to content

Commit 949c246

Browse files
vsyrgkanisvasilismsrkbattocchi
authored
Vasilis/dmlcoef (#283)
Added coef_ and intercept_ to DMLCateEstimator by inheriting LinearModelFinalCateEstimatorMixin Co-authored-by: Vasilis <[email protected]> Co-authored-by: Keith Battocchi <[email protected]>
1 parent ffdfda8 commit 949c246

File tree

2 files changed

+15
-13
lines changed

2 files changed

+15
-13
lines changed

econml/dml.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
from sklearn.utils import check_random_state
5151
from .cate_estimator import (BaseCateEstimator, LinearCateEstimator,
5252
TreatmentExpansionMixin, StatsModelsCateEstimatorMixin,
53-
DebiasedLassoCateEstimatorMixin)
53+
LinearModelFinalCateEstimatorMixin, DebiasedLassoCateEstimatorMixin)
5454
from .inference import StatsModelsInference, GenericSingleTreatmentModelFinalInference
5555
from ._rlearner import _RLearner
5656
from .sklearn_extensions.model_selection import WeightedStratifiedKFold
@@ -296,7 +296,7 @@ def cate_feature_names(self, input_feature_names=None):
296296
raise AttributeError("Featurizer does not have a method: get_feature_names!")
297297

298298

299-
class DMLCateEstimator(_BaseDMLCateEstimator):
299+
class DMLCateEstimator(LinearModelFinalCateEstimatorMixin, _BaseDMLCateEstimator):
300300
"""
301301
The base class for parametric Double ML estimators. The estimator is a special
302302
case of an :class:`._RLearner` estimator, which in turn is a special case

econml/tests/test_dml.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,16 @@ def make_random(n, is_discrete, d):
110110
all_infs.append(BootstrapInference(1))
111111

112112
for est, multi, infs in\
113-
[(LinearDMLCateEstimator(model_y=Lasso(),
113+
[(DMLCateEstimator(model_y=Lasso(),
114+
model_t=model_t,
115+
model_final=Lasso(alpha=0.1, fit_intercept=False),
116+
featurizer=featurizer,
117+
fit_cate_intercept=fit_cate_intercept,
118+
discrete_treatment=is_discrete),
119+
True,
120+
[None] +
121+
([BootstrapInference(n_bootstrap_samples=20)] if not is_discrete else [])),
122+
(LinearDMLCateEstimator(model_y=Lasso(),
114123
model_t='auto',
115124
featurizer=featurizer,
116125
fit_cate_intercept=fit_cate_intercept,
@@ -171,8 +180,7 @@ def make_random(n, is_discrete, d):
171180
eff = est.effect(X, T0=T0, T1=T)
172181
self.assertEqual(shape(eff), effect_shape)
173182

174-
if isinstance(est, LinearDMLCateEstimator) or\
175-
isinstance(est, SparseLinearDMLCateEstimator):
183+
if not isinstance(est, KernelDMLCateEstimator):
176184
self.assertEqual(shape(est.coef_), coef_shape)
177185
if fit_cate_intercept:
178186
self.assertEqual(shape(est.intercept_), intercept_shape)
@@ -189,10 +197,7 @@ def make_random(n, is_discrete, d):
189197
(2,) + const_marginal_effect_shape)
190198
self.assertEqual(shape(est.effect_interval(X, T0=T0, T1=T)),
191199
(2,) + effect_shape)
192-
if (isinstance(est,
193-
LinearDMLCateEstimator) or
194-
isinstance(est,
195-
SparseLinearDMLCateEstimator)):
200+
if not isinstance(est, KernelDMLCateEstimator):
196201
self.assertEqual(shape(est.coef__interval()),
197202
(2,) + coef_shape)
198203
if fit_cate_intercept:
@@ -267,10 +272,7 @@ def make_random(n, is_discrete, d):
267272
marg_effect_inf.population_summary()._repr_html_()
268273

269274
# test coef__inference and intercept__inference
270-
if (isinstance(est,
271-
LinearDMLCateEstimator) or
272-
isinstance(est,
273-
SparseLinearDMLCateEstimator)):
275+
if not isinstance(est, KernelDMLCateEstimator):
274276
if X is None:
275277
cm = pytest.raises(AttributeError)
276278
else:

0 commit comments

Comments
 (0)