Skip to content

Commit 5b069f0

Browse files
committed
remove deprecated args
1 parent 62c065c commit 5b069f0

File tree

8 files changed

+7
-167
lines changed

8 files changed

+7
-167
lines changed

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -396,27 +396,27 @@ See the <a href="#references">References</a> section for more details.
396396
reg = lambda: RandomForestRegressor(min_samples_leaf=20)
397397
clf = lambda: RandomForestClassifier(min_samples_leaf=20)
398398
models = [('ldml', LinearDML(model_y=reg(), model_t=clf(), discrete_treatment=True,
399-
linear_first_stages=False, n_splits=3)),
399+
linear_first_stages=False, cv=3)),
400400
('xlearner', XLearner(models=reg(), cate_models=reg(), propensity_model=clf())),
401401
('dalearner', DomainAdaptationLearner(models=reg(), final_models=reg(), propensity_model=clf())),
402402
('slearner', SLearner(overall_model=reg())),
403403
('drlearner', DRLearner(model_propensity=clf(), model_regression=reg(),
404-
model_final=reg(), n_splits=3)),
404+
model_final=reg(), cv=3)),
405405
('rlearner', NonParamDML(model_y=reg(), model_t=clf(), model_final=reg(),
406-
discrete_treatment=True, n_splits=3)),
406+
discrete_treatment=True, cv=3)),
407407
('dml3dlasso', DML(model_y=reg(), model_t=clf(),
408408
model_final=LassoCV(cv=3, fit_intercept=False),
409409
discrete_treatment=True,
410410
featurizer=PolynomialFeatures(degree=3),
411-
linear_first_stages=False, n_splits=3))
411+
linear_first_stages=False, cv=3))
412412
]
413413

414414
# fit cate models on train data
415415
models = [(name, mdl.fit(Y_train, T_train, X=X_train)) for name, mdl in models]
416416

417417
# score cate models on validation data
418418
scorer = RScorer(model_y=reg(), model_t=clf(),
419-
discrete_treatment=True, n_splits=3, mc_iters=2, mc_agg='median')
419+
discrete_treatment=True, cv=3, mc_iters=2, mc_agg='median')
420420
scorer.fit(Y_val, T_val, X=X_val)
421421
rscore = [scorer.score(mdl) for _, mdl in models]
422422
# select the best model

econml/_ortho_learner.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -428,9 +428,8 @@ def _gen_ortho_learner_model_final(self):
428428

429429
def __init__(self, *,
430430
discrete_treatment, discrete_instrument, categories, cv, random_state,
431-
n_splits='raise', mc_iters=None, mc_agg='mean'):
431+
mc_iters=None, mc_agg='mean'):
432432
self.cv = cv
433-
self.n_splits = n_splits
434433
self.discrete_treatment = discrete_treatment
435434
self.discrete_instrument = discrete_instrument
436435
self.random_state = random_state
@@ -855,18 +854,3 @@ def models_nuisance_(self):
855854
if not hasattr(self, '_models_nuisance'):
856855
raise AttributeError("Model is not fitted!")
857856
return self._models_nuisance
858-
859-
#######################################################
860-
# These should be removed once `n_splits` is deprecated
861-
#######################################################
862-
863-
@property
864-
def n_splits(self):
865-
return self.cv
866-
867-
@n_splits.setter
868-
def n_splits(self, value):
869-
if value != 'raise':
870-
warn("Parameter `n_splits` has been deprecated and will be removed in the next version. "
871-
"Use parameter `cv` instead.")
872-
self.cv = value

econml/dml/_rlearner.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,13 +261,11 @@ def _gen_rlearner_model_final(self):
261261
is multidimensional, then the average of the MSEs for each dimension of Y is returned.
262262
"""
263263

264-
def __init__(self, *, discrete_treatment, categories, cv, random_state,
265-
n_splits='raise', mc_iters=None, mc_agg='mean'):
264+
def __init__(self, *, discrete_treatment, categories, cv, random_state, mc_iters=None, mc_agg='mean'):
266265
super().__init__(discrete_treatment=discrete_treatment,
267266
discrete_instrument=False, # no instrument, so doesn't matter
268267
categories=categories,
269268
cv=cv,
270-
n_splits=n_splits,
271269
random_state=random_state,
272270
mc_iters=mc_iters,
273271
mc_agg=mc_agg)

econml/dml/causal_forest.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,6 @@ def __init__(self, *,
490490
discrete_treatment=False,
491491
categories='auto',
492492
cv=2,
493-
n_crossfit_splits='raise',
494493
mc_iters=None,
495494
mc_agg='mean',
496495
drate=True,
@@ -541,13 +540,9 @@ def __init__(self, *,
541540
self.subforest_size = subforest_size
542541
self.n_jobs = n_jobs
543542
self.verbose = verbose
544-
self.n_crossfit_splits = n_crossfit_splits
545-
if self.n_crossfit_splits != 'raise':
546-
cv = self.n_crossfit_splits
547543
super().__init__(discrete_treatment=discrete_treatment,
548544
categories=categories,
549545
cv=cv,
550-
n_splits=n_crossfit_splits,
551546
mc_iters=mc_iters,
552547
mc_agg=mc_agg,
553548
random_state=random_state)
@@ -971,17 +966,3 @@ def __getitem__(self, index):
971966
def __iter__(self):
972967
"""Return iterator over estimators in the ensemble."""
973968
return self.model_cate.__iter__()
974-
975-
#######################################################
976-
# These should be removed once `n_splits` is deprecated
977-
#######################################################
978-
979-
@property
980-
def n_crossfit_splits(self):
981-
return self.cv
982-
983-
@n_crossfit_splits.setter
984-
def n_crossfit_splits(self, value):
985-
if value != 'raise':
986-
warn("Deprecated by parameter `n_crossfit_splits` and will be removed in next version.")
987-
self.cv = value

econml/dml/dml.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,6 @@ def __init__(self, *,
420420
discrete_treatment=False,
421421
categories='auto',
422422
cv=2,
423-
n_splits='raise',
424423
mc_iters=None,
425424
mc_agg='mean',
426425
random_state=None):
@@ -435,7 +434,6 @@ def __init__(self, *,
435434
super().__init__(discrete_treatment=discrete_treatment,
436435
categories=categories,
437436
cv=cv,
438-
n_splits=n_splits,
439437
mc_iters=mc_iters,
440438
mc_agg=mc_agg,
441439
random_state=random_state)
@@ -596,7 +594,6 @@ def __init__(self, *,
596594
discrete_treatment=False,
597595
categories='auto',
598596
cv=2,
599-
n_splits='raise',
600597
mc_iters=None,
601598
mc_agg='mean',
602599
random_state=None):
@@ -609,7 +606,6 @@ def __init__(self, *,
609606
discrete_treatment=discrete_treatment,
610607
categories=categories,
611608
cv=cv,
612-
n_splits=n_splits,
613609
mc_iters=mc_iters,
614610
mc_agg=mc_agg,
615611
random_state=random_state,)
@@ -790,7 +786,6 @@ def __init__(self, *,
790786
discrete_treatment=False,
791787
categories='auto',
792788
cv=2,
793-
n_splits='raise',
794789
mc_iters=None,
795790
mc_agg='mean',
796791
random_state=None):
@@ -810,7 +805,6 @@ def __init__(self, *,
810805
discrete_treatment=discrete_treatment,
811806
categories=categories,
812807
cv=cv,
813-
n_splits=n_splits,
814808
mc_iters=mc_iters,
815809
mc_agg=mc_agg,
816810
random_state=random_state)
@@ -974,7 +968,6 @@ def __init__(self, model_y='auto', model_t='auto',
974968
dim=20,
975969
bw=1.0,
976970
cv=2,
977-
n_splits='raise',
978971
mc_iters=None, mc_agg='mean',
979972
random_state=None):
980973
self.dim = dim
@@ -987,7 +980,6 @@ def __init__(self, model_y='auto', model_t='auto',
987980
discrete_treatment=discrete_treatment,
988981
categories=categories,
989982
cv=cv,
990-
n_splits=n_splits,
991983
mc_iters=mc_iters,
992984
mc_agg=mc_agg,
993985
random_state=random_state)
@@ -1087,7 +1079,6 @@ def __init__(self, *,
10871079
discrete_treatment=False,
10881080
categories='auto',
10891081
cv=2,
1090-
n_splits='raise',
10911082
mc_iters=None,
10921083
mc_agg='mean',
10931084
random_state=None):
@@ -1101,7 +1092,6 @@ def __init__(self, *,
11011092
super().__init__(discrete_treatment=discrete_treatment,
11021093
categories=categories,
11031094
cv=cv,
1104-
n_splits=n_splits,
11051095
mc_iters=mc_iters,
11061096
mc_agg=mc_agg,
11071097
random_state=random_state)
@@ -1190,7 +1180,6 @@ def ForestDML(model_y, model_t,
11901180
discrete_treatment=False,
11911181
categories='auto',
11921182
cv=2,
1193-
n_crossfit_splits='raise',
11941183
mc_iters=None,
11951184
mc_agg='mean',
11961185
n_estimators=100,
@@ -1245,10 +1234,6 @@ def ForestDML(model_y, model_t,
12451234
Unless an iterable is used, we call `split(concat[W, X], T)` to generate the splits. If all
12461235
W, X are None, then we call `split(ones((T.shape[0], 1)), T)`.
12471236
1248-
n_crossfit_splits: int or 'raise', optional (default='raise')
1249-
Deprecated by parameter `cv` and will be removed in next version. Can be used
1250-
interchangeably with `cv`.
1251-
12521237
mc_iters: int, optional (default=None)
12531238
The number of times to rerun the first stage models to reduce the variance of the nuisances.
12541239
@@ -1375,7 +1360,6 @@ def ForestDML(model_y, model_t,
13751360
discrete_treatment=discrete_treatment,
13761361
categories=categories,
13771362
cv=cv,
1378-
n_crossfit_splits=n_crossfit_splits,
13791363
mc_iters=mc_iters,
13801364
mc_agg=mc_agg,
13811365
n_estimators=n_estimators,

0 commit comments

Comments
 (0)