@@ -110,7 +110,16 @@ def make_random(n, is_discrete, d):
110110 all_infs .append (BootstrapInference (1 ))
111111
112112 for est , multi , infs in \
113- [(LinearDMLCateEstimator (model_y = Lasso (),
113+ [(DMLCateEstimator (model_y = Lasso (),
114+ model_t = model_t ,
115+ model_final = Lasso (alpha = 0.1 , fit_intercept = False ),
116+ featurizer = featurizer ,
117+ fit_cate_intercept = fit_cate_intercept ,
118+ discrete_treatment = is_discrete ),
119+ True ,
120+ [None ] +
121+ ([BootstrapInference (n_bootstrap_samples = 20 )] if not is_discrete else [])),
122+ (LinearDMLCateEstimator (model_y = Lasso (),
114123 model_t = 'auto' ,
115124 featurizer = featurizer ,
116125 fit_cate_intercept = fit_cate_intercept ,
@@ -171,8 +180,7 @@ def make_random(n, is_discrete, d):
171180 eff = est .effect (X , T0 = T0 , T1 = T )
172181 self .assertEqual (shape (eff ), effect_shape )
173182
174- if isinstance (est , LinearDMLCateEstimator ) or \
175- isinstance (est , SparseLinearDMLCateEstimator ):
183+ if not isinstance (est , KernelDMLCateEstimator ):
176184 self .assertEqual (shape (est .coef_ ), coef_shape )
177185 if fit_cate_intercept :
178186 self .assertEqual (shape (est .intercept_ ), intercept_shape )
@@ -189,10 +197,7 @@ def make_random(n, is_discrete, d):
189197 (2 ,) + const_marginal_effect_shape )
190198 self .assertEqual (shape (est .effect_interval (X , T0 = T0 , T1 = T )),
191199 (2 ,) + effect_shape )
192- if (isinstance (est ,
193- LinearDMLCateEstimator ) or
194- isinstance (est ,
195- SparseLinearDMLCateEstimator )):
200+ if not isinstance (est , KernelDMLCateEstimator ):
196201 self .assertEqual (shape (est .coef__interval ()),
197202 (2 ,) + coef_shape )
198203 if fit_cate_intercept :
@@ -267,10 +272,7 @@ def make_random(n, is_discrete, d):
267272 marg_effect_inf .population_summary ()._repr_html_ ()
268273
269274 # test coef__inference and intercept__inference
270- if (isinstance (est ,
271- LinearDMLCateEstimator ) or
272- isinstance (est ,
273- SparseLinearDMLCateEstimator )):
275+ if not isinstance (est , KernelDMLCateEstimator ):
274276 if X is None :
275277 cm = pytest .raises (AttributeError )
276278 else :
0 commit comments