Skip to content

Commit 48613ea

Browse files
authored
Merge branch 'master' into kebatt/bootstrapFixes
2 parents 50800f1 + 949c246 commit 48613ea

File tree

4 files changed

+21
-16
lines changed

4 files changed

+21
-16
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,12 @@ Such questions arise frequently in customer segmentation (what is the effect of
9191

9292
# News
9393

94-
**March 6, 2020:** Release v0.7.0, see release notes [here](https://github.com/Microsoft/EconML/releases/tag/v0.7.0)
94+
**September 4, 2020:** Release v0.8.0b1, see release notes [here](https://github.com/Microsoft/EconML/releases/tag/v0.8.0b1)
9595

9696
<details><summary>Previous releases</summary>
9797

98+
**March 6, 2020:** Release v0.7.0, see release notes [here](https://github.com/Microsoft/EconML/releases/tag/v0.7.0)
99+
98100
**February 18, 2020:** Release v0.7.0b1, see release notes [here](https://github.com/Microsoft/EconML/releases/tag/v0.7.0b1)
99101

100102
**January 10, 2020:** Release v0.6.1, see release notes [here](https://github.com/Microsoft/EconML/releases/tag/v0.6.1)

econml/dml.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
from sklearn.utils import check_random_state
5151
from .cate_estimator import (BaseCateEstimator, LinearCateEstimator,
5252
TreatmentExpansionMixin, StatsModelsCateEstimatorMixin,
53-
DebiasedLassoCateEstimatorMixin)
53+
LinearModelFinalCateEstimatorMixin, DebiasedLassoCateEstimatorMixin)
5454
from .inference import StatsModelsInference, GenericSingleTreatmentModelFinalInference
5555
from ._rlearner import _RLearner
5656
from .sklearn_extensions.model_selection import WeightedStratifiedKFold
@@ -296,7 +296,7 @@ def cate_feature_names(self, input_feature_names=None):
296296
raise AttributeError("Featurizer does not have a method: get_feature_names!")
297297

298298

299-
class DMLCateEstimator(_BaseDMLCateEstimator):
299+
class DMLCateEstimator(LinearModelFinalCateEstimatorMixin, _BaseDMLCateEstimator):
300300
"""
301301
The base class for parametric Double ML estimators. The estimator is a special
302302
case of an :class:`._RLearner` estimator, which in turn is a special case

econml/tests/test_dml.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,16 @@ def make_random(n, is_discrete, d):
106106
all_infs = [None, 'statsmodels', BootstrapInference(1)]
107107

108108
for est, multi, infs in\
109-
[(LinearDMLCateEstimator(model_y=Lasso(),
109+
[(DMLCateEstimator(model_y=Lasso(),
110+
model_t=model_t,
111+
model_final=Lasso(alpha=0.1, fit_intercept=False),
112+
featurizer=featurizer,
113+
fit_cate_intercept=fit_cate_intercept,
114+
discrete_treatment=is_discrete),
115+
True,
116+
[None] +
117+
([BootstrapInference(n_bootstrap_samples=20)] if not is_discrete else [])),
118+
(LinearDMLCateEstimator(model_y=Lasso(),
110119
model_t='auto',
111120
featurizer=featurizer,
112121
fit_cate_intercept=fit_cate_intercept,
@@ -167,8 +176,7 @@ def make_random(n, is_discrete, d):
167176
eff = est.effect(X, T0=T0, T1=T)
168177
self.assertEqual(shape(eff), effect_shape)
169178

170-
if isinstance(est, LinearDMLCateEstimator) or\
171-
isinstance(est, SparseLinearDMLCateEstimator):
179+
if not isinstance(est, KernelDMLCateEstimator):
172180
self.assertEqual(shape(est.coef_), coef_shape)
173181
if fit_cate_intercept:
174182
self.assertEqual(shape(est.intercept_), intercept_shape)
@@ -185,10 +193,7 @@ def make_random(n, is_discrete, d):
185193
(2,) + const_marginal_effect_shape)
186194
self.assertEqual(shape(est.effect_interval(X, T0=T0, T1=T)),
187195
(2,) + effect_shape)
188-
if (isinstance(est,
189-
LinearDMLCateEstimator) or
190-
isinstance(est,
191-
SparseLinearDMLCateEstimator)):
196+
if not isinstance(est, KernelDMLCateEstimator):
192197
self.assertEqual(shape(est.coef__interval()),
193198
(2,) + coef_shape)
194199
if fit_cate_intercept:
@@ -263,10 +268,7 @@ def make_random(n, is_discrete, d):
263268
marg_effect_inf.population_summary()._repr_html_()
264269

265270
# test coef__inference and intercept__inference
266-
if (isinstance(est,
267-
LinearDMLCateEstimator) or
268-
isinstance(est,
269-
SparseLinearDMLCateEstimator)):
271+
if not isinstance(est, KernelDMLCateEstimator):
270272
if X is None:
271273
cm = pytest.raises(AttributeError)
272274
else:

setup.cfg

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ max-line-length=119
77
convention=numpy
88

99
[build_sphinx]
10-
version = 0.7.0
10+
version = 0.8.0b1
1111

1212
[metadata]
1313
name = econml
14-
version = 0.7.0
14+
version = 0.8.0b1
1515
author = Microsoft Corporation
1616
description = This package contains several methods for calculating Conditional Average Treatment Effects
1717
long_description = file: README.md
@@ -59,6 +59,7 @@ tests_require =
5959
pytest-cov
6060
jupyter
6161
nbconvert < 6
62+
traitlets < 5; python_version <= '3.5'
6263
seaborn
6364
lightgbm
6465
dowhy

0 commit comments

Comments
 (0)