@@ -32,7 +32,37 @@ def fit(self, estimator, *args, **kwargs):
3232 pass
3333
3434
35- class BootstrapInference (Inference ):
35+ class _SummaryMixin :
36+ def summary (self , alpha = 0.1 , value = 0 , decimals = 3 , feat_name = None ):
37+ smry = Summary ()
38+ try :
39+ coef_table = self .coef__inference ().summary_frame (alpha = alpha ,
40+ value = value , decimals = decimals , feat_name = feat_name )
41+ coef_array = coef_table .values
42+ coef_headers = [i + '\n ' +
43+ j for (i , j ) in coef_table .columns ] if self .d_t > 1 else coef_table .columns .tolist ()
44+ coef_stubs = [i + ' | ' + j for (i , j ) in coef_table .index ] if self .d_y > 1 else coef_table .index .tolist ()
45+ coef_title = 'Coefficient Results'
46+ smry .add_table (coef_array , coef_headers , coef_stubs , coef_title )
47+ except Exception as e :
48+ print ("Coefficient Results: " , str (e ))
49+ try :
50+ intercept_table = self .intercept__inference ().summary_frame (alpha = alpha ,
51+ value = value , decimals = decimals , feat_name = None )
52+ intercept_array = intercept_table .values
53+ intercept_headers = [i + '\n ' + j for (i , j )
54+ in intercept_table .columns ] if self .d_t > 1 else intercept_table .columns .tolist ()
55+ intercept_stubs = [i + ' | ' + j for (i , j )
56+ in intercept_table .index ] if self .d_y > 1 else intercept_table .index .tolist ()
57+ intercept_title = 'Intercept Results'
58+ smry .add_table (intercept_array , intercept_headers , intercept_stubs , intercept_title )
59+ except Exception as e :
60+ print ("Intercept Results: " , str (e ))
61+ if len (smry .tables ) > 0 :
62+ return smry
63+
64+
65+ class BootstrapInference (_SummaryMixin , Inference ):
3666 """
3767 Inference instance to perform bootstrapping.
3868
@@ -46,21 +76,20 @@ class BootstrapInference(Inference):
4676 n_jobs: int, optional (default -1)
4777 The maximum number of concurrently running jobs, as in joblib.Parallel.
4878
49- bootstrap_type: 'percentile', 'pivot', or 'normal', default 'percentile '
79+ bootstrap_type: 'percentile', 'pivot', or 'normal', default 'pivot '
5080 Bootstrap method used to compute results.
5181 'percentile' will result in using the empiracal CDF of the replicated computations of the statistics.
5282 'pivot' will also use the replicates but create a pivot interval that also relies on the estimate
5383 over the entire dataset.
5484 'normal' will instead compute a pivot interval assuming the replicates are normally distributed.
5585 """
5686
57- def __init__ (self , n_bootstrap_samples = 100 , n_jobs = - 1 , bootstrap_type = 'percentile ' ):
87+ def __init__ (self , n_bootstrap_samples = 100 , n_jobs = - 1 , bootstrap_type = 'pivot ' ):
5888 self ._n_bootstrap_samples = n_bootstrap_samples
5989 self ._n_jobs = n_jobs
6090 self ._bootstrap_type = bootstrap_type
6191
6292 def fit (self , estimator , * args , ** kwargs ):
63- discrete_treatment = estimator ._discrete_treatment if hasattr (estimator , '_discrete_treatment' ) else False
6493 est = BootstrapEstimator (estimator , self ._n_bootstrap_samples , self ._n_jobs , compute_means = False ,
6594 bootstrap_type = self ._bootstrap_type )
6695 est .fit (* args , ** kwargs )
@@ -82,34 +111,6 @@ def wrapped(*args, alpha=0.1, **kwargs):
82111 else :
83112 return m
84113
85- def summary (self , alpha = 0.1 , value = 0 , decimals = 3 , feat_name = None ):
86- smry = Summary ()
87- try :
88- coef_table = self .coef__inference ().summary_frame (alpha = alpha ,
89- value = value , decimals = decimals , feat_name = feat_name )
90- coef_array = coef_table .values
91- coef_headers = [i + '\n ' +
92- j for (i , j ) in coef_table .columns ] if self .d_t > 1 else coef_table .columns .tolist ()
93- coef_stubs = [i + ' | ' + j for (i , j ) in coef_table .index ] if self .d_y > 1 else coef_table .index .tolist ()
94- coef_title = 'Coefficient Results'
95- smry .add_table (coef_array , coef_headers , coef_stubs , coef_title )
96- except Exception as e :
97- print ("Coefficient Results: " , str (e ))
98- try :
99- intercept_table = self .intercept__inference ().summary_frame (alpha = alpha ,
100- value = value , decimals = decimals , feat_name = None )
101- intercept_array = intercept_table .values
102- intercept_headers = [i + '\n ' + j for (i , j )
103- in intercept_table .columns ] if self .d_t > 1 else intercept_table .columns .tolist ()
104- intercept_stubs = [i + ' | ' + j for (i , j )
105- in intercept_table .index ] if self .d_y > 1 else intercept_table .index .tolist ()
106- intercept_title = 'Intercept Results'
107- smry .add_table (intercept_array , intercept_headers , intercept_stubs , intercept_title )
108- except Exception as e :
109- print ("Intercept Results: " , str (e ))
110- if len (smry .tables ) > 0 :
111- return smry
112-
113114
114115class GenericModelFinalInference (Inference ):
115116 """
@@ -215,7 +216,7 @@ def effect_inference(self, X, *, T0, T1):
215216 pred_stderr = e_stderr , inf_type = 'effect' , fname_transformer = None )
216217
217218
218- class LinearModelFinalInference (GenericModelFinalInference ):
219+ class LinearModelFinalInference (_SummaryMixin , GenericModelFinalInference ):
219220 """
220221 Inference based on predict_interval of the model_final model. Assumes that estimator
221222 class has a model_final method and that model is linear. Thus, the predict(cross_product(X, T1 - T0)) gives
@@ -318,34 +319,6 @@ def intercept__inference(self):
318319 return NormalInferenceResults (d_t = self .d_t , d_y = self .d_y , pred = intercept , pred_stderr = intercept_stderr ,
319320 inf_type = 'intercept' , fname_transformer = None )
320321
321- def summary (self , alpha = 0.1 , value = 0 , decimals = 3 , feat_name = None ):
322- smry = Summary ()
323- try :
324- coef_table = self .coef__inference ().summary_frame (alpha = alpha ,
325- value = value , decimals = decimals , feat_name = feat_name )
326- coef_array = coef_table .values
327- coef_headers = [i + '\n ' +
328- j for (i , j ) in coef_table .columns ] if self .d_t > 1 else coef_table .columns .tolist ()
329- coef_stubs = [i + ' | ' + j for (i , j ) in coef_table .index ] if self .d_y > 1 else coef_table .index .tolist ()
330- coef_title = 'Coefficient Results'
331- smry .add_table (coef_array , coef_headers , coef_stubs , coef_title )
332- except Exception as e :
333- print ("Coefficient Results: " , str (e ))
334- try :
335- intercept_table = self .intercept__inference ().summary_frame (alpha = alpha ,
336- value = value , decimals = decimals , feat_name = None )
337- intercept_array = intercept_table .values
338- intercept_headers = [i + '\n ' + j for (i , j )
339- in intercept_table .columns ] if self .d_t > 1 else intercept_table .columns .tolist ()
340- intercept_stubs = [i + ' | ' + j for (i , j )
341- in intercept_table .index ] if self .d_y > 1 else intercept_table .index .tolist ()
342- intercept_title = 'Intercept Results'
343- smry .add_table (intercept_array , intercept_headers , intercept_stubs , intercept_title )
344- except Exception as e :
345- print ("Intercept Results: " , str (e ))
346- if len (smry .tables ) > 0 :
347- return smry
348-
349322
350323class StatsModelsInference (LinearModelFinalInference ):
351324 """Stores statsmodels covariance options.
0 commit comments