@@ -106,7 +106,16 @@ def make_random(n, is_discrete, d):
106106 all_infs = [None , 'statsmodels' , BootstrapInference (1 )]
107107
108108 for est , multi , infs in \
109- [(LinearDMLCateEstimator (model_y = Lasso (),
109+ [(DMLCateEstimator (model_y = Lasso (),
110+ model_t = model_t ,
111+ model_final = Lasso (alpha = 0.1 , fit_intercept = False ),
112+ featurizer = featurizer ,
113+ fit_cate_intercept = fit_cate_intercept ,
114+ discrete_treatment = is_discrete ),
115+ True ,
116+ [None ] +
117+ ([BootstrapInference (n_bootstrap_samples = 20 )] if not is_discrete else [])),
118+ (LinearDMLCateEstimator (model_y = Lasso (),
110119 model_t = 'auto' ,
111120 featurizer = featurizer ,
112121 fit_cate_intercept = fit_cate_intercept ,
@@ -167,8 +176,7 @@ def make_random(n, is_discrete, d):
167176 eff = est .effect (X , T0 = T0 , T1 = T )
168177 self .assertEqual (shape (eff ), effect_shape )
169178
170- if isinstance (est , LinearDMLCateEstimator ) or \
171- isinstance (est , SparseLinearDMLCateEstimator ):
179+ if not isinstance (est , KernelDMLCateEstimator ):
172180 self .assertEqual (shape (est .coef_ ), coef_shape )
173181 if fit_cate_intercept :
174182 self .assertEqual (shape (est .intercept_ ), intercept_shape )
@@ -185,10 +193,7 @@ def make_random(n, is_discrete, d):
185193 (2 ,) + const_marginal_effect_shape )
186194 self .assertEqual (shape (est .effect_interval (X , T0 = T0 , T1 = T )),
187195 (2 ,) + effect_shape )
188- if (isinstance (est ,
189- LinearDMLCateEstimator ) or
190- isinstance (est ,
191- SparseLinearDMLCateEstimator )):
196+ if not isinstance (est , KernelDMLCateEstimator ):
192197 self .assertEqual (shape (est .coef__interval ()),
193198 (2 ,) + coef_shape )
194199 if fit_cate_intercept :
@@ -263,10 +268,7 @@ def make_random(n, is_discrete, d):
263268 marg_effect_inf .population_summary ()._repr_html_ ()
264269
265270 # test coef__inference and intercept__inference
266- if (isinstance (est ,
267- LinearDMLCateEstimator ) or
268- isinstance (est ,
269- SparseLinearDMLCateEstimator )):
271+ if not isinstance (est , KernelDMLCateEstimator ):
270272 if X is None :
271273 cm = pytest .raises (AttributeError )
272274 else :
0 commit comments