@@ -92,8 +92,6 @@ def predict(self, Y, T, X=None, W=None, sample_weight=None, groups=None):
9292 return Y_res , T_res
9393
9494 def score (self , Y , T , X = None , W = None , sample_weight = None , groups = None ):
95- # TODO: implement scores
96- # TODO: fix correctness?
9795 assert Y .shape [0 ] % self .n_periods == 0 , \
9896 "Length of training data should be an integer multiple of time periods."
9997 inds_score = np .arange (Y .shape [0 ])[np .arange (Y .shape [0 ]) % self .n_periods == 0 ]
@@ -147,7 +145,7 @@ def __init__(self, model_final, n_periods):
147145 self ._model_final_trained = {k : clone (self ._model_final , safe = False ) for k in np .arange (n_periods )}
148146
149147 def fit (self , Y , T , X = None , W = None , Z = None , nuisances = None , sample_weight = None , sample_var = None ):
150- # TODO: handle sample weight, sample var
148+ # NOTE: sample weight, sample var are not passed in
151149 Y_res , T_res = nuisances
152150 self ._d_y = Y .shape [1 :]
153151 for kappa in np .arange (self .n_periods ):
@@ -186,8 +184,29 @@ def predict(self, X=None):
186184 return preds
187185
def score(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None, sample_var=None):
    """
    Compute one mean squared error per fitted final model.

    For every ``kappa`` in ``0 .. n_periods - 1`` the residual outcomes of period
    ``n_periods - 1 - kappa`` are reduced by the predictions of the final models
    ``0 .. kappa - 1`` and the remainder is compared with the prediction of final
    model ``kappa``.

    Parameters
    ----------
    Y : array
        Only ``Y.shape[0]`` is read; it must be an integer multiple of ``n_periods``.
    T : array
        Not used by this method.
    X : array or None
        Optional features; rows are selected with the period masks before being
        handed to ``predict_with_res``.
    W, Z, sample_var :
        Not used by this method.
    nuisances : tuple
        The pair ``(Y_res, T_res)``; the last axis of ``T_res`` is indexed by period.
    sample_weight : 1-D array or None
        Optional weights for the squared errors. Weights with one entry per row of
        ``Y`` are restricted to the rows of the period being scored; weights of any
        other length are passed to ``np.average`` unchanged.

    Returns
    -------
    scores : ndarray of shape (n_periods,)
        ``scores[kappa]`` is the (weighted) mean squared error of final model ``kappa``.

    Raises
    ------
    AssertionError
        If ``Y.shape[0]`` is not a multiple of ``n_periods``.
    """
    n_rows = Y.shape[0]
    assert n_rows % self.n_periods == 0, \
        "Length of training data should be an integer multiple of time periods."
    Y_res, T_res = nuisances

    scores = np.full((self.n_periods, ), np.nan)
    for kappa in np.arange(self.n_periods):
        # Final model `kappa` is paired with period `n_periods - 1 - kappa`.
        period = self.n_periods - 1 - kappa
        period_filter = self.period_filter_gen(period, n_rows)
        Y_adj = Y_res[period_filter].copy()
        if kappa > 0:
            # Remove what the previously indexed final models already explain.
            Y_adj -= np.sum(
                [self._model_final_trained[tau].predict_with_res(
                    X[self.period_filter_gen(self.n_periods - 1 - tau, n_rows)] if X is not None else None,
                    T_res[period_filter, ..., self.n_periods - 1 - tau]
                ) for tau in np.arange(kappa)], axis=0)
        Y_adj_pred = self._model_final_trained[kappa].predict_with_res(
            X[period_filter] if X is not None else None,
            T_res[period_filter, ..., period])
        squared_error = (Y_adj - Y_adj_pred) ** 2
        if sample_weight is not None:
            weights = np.asarray(sample_weight)
            # The squared errors only cover the rows of this period, so weights that
            # are aligned with the rows of Y must be filtered the same way; otherwise
            # np.average rejects them for having the wrong length.
            if weights.ndim == 1 and weights.shape[0] == n_rows:
                weights = weights[period_filter]
            scores[kappa] = np.mean(np.average(squared_error, weights=weights, axis=0))
        else:
            scores[kappa] = np.mean(squared_error)
    return scores
191210
def period_filter_gen(self, p, n):
    """
    Build a boolean mask of length `n` selecting the positions that belong to period `p`.

    Position ``i`` is selected when ``i % self.n_periods == p``.
    """
    positions = np.arange(n)
    period_of_position = np.mod(positions, self.n_periods)
    return np.equal(period_of_position, p)
@@ -548,12 +567,39 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
548567 warn ("This CATE estimator does not yet support sample weights and sample variance. "
549568 "These inputs will be ignored during fitting." ,
550569 UserWarning )
551- # TODO: support sample_weight, sample_var?
552570 return super ().fit (Y , T , X = X , W = W ,
553571 sample_weight = None , sample_var = None , groups = groups ,
554572 cache_values = cache_values ,
555573 inference = inference )
556574
def score(self, Y, T, X=None, W=None):
    """
    Score the fitted CATE model on a new data set.

    Nuisance parameters for the new data are generated from the residual nuisance
    models fitted at fit time, using the mean prediction of the models fitted on the
    different crossfit folds. The MSE of the final residual Y on residual T
    regression is then calculated.

    If model_final does not have a score method, then it raises an :exc:`.AttributeError`

    Parameters
    ----------
    Y: (n, d_y) matrix or vector of length n
        Outcomes for each sample (required: n = n_groups * n_periods)
    T: (n, d_t) matrix or vector of length n
        Treatments for each sample (required: n = n_groups * n_periods)
    X: optional(n, d_x) matrix or None (Default=None)
        Features for each sample (Required: n = n_groups * n_periods)
    W: optional(n, d_w) matrix or None (Default=None)
        Controls for each sample (Required: n = n_groups * n_periods)

    Returns
    -------
    score: float
        The MSE of the final CATE model on the new data.
    """
    # This override exists so that only Y, T, X and W reach the parent's `score`
    # (no Z is forwarded) and so that the docstring matches this estimator.
    final_model_score = super().score(Y, T, X=X, W=W)
    return final_model_score
602+
557603 def cate_treatment_names (self , treatment_names = None ):
558604 """
559605 Get treatment names for each time period.
@@ -658,3 +704,34 @@ def model_final(self):
def model_final(self, model):
    """
    Reject any attempt to replace the final model.

    Assigning ``None`` is accepted and does nothing; any other value raises
    :exc:`ValueError`.
    """
    if model is None:
        return
    raise ValueError("Parameter `model_final` cannot be altered for this estimator!")
707+
@property
def models_y(self):
    """
    The `_model_y` attribute of every fitted nuisance model.

    Returns a nested list with the same two-level layout as the parent class's
    ``models_nuisance_`` (presumably outer = fitting repetitions, inner = crossfit
    folds -- confirm against the parent class).
    """
    outcome_models = []
    for fitted_group in super().models_nuisance_:
        outcome_models.append([fitted._model_y for fitted in fitted_group])
    return outcome_models
711+
@property
def models_t(self):
    """
    The `_model_t` attribute of every fitted nuisance model.

    Returns a nested list with the same two-level layout as the parent class's
    ``models_nuisance_`` (presumably outer = fitting repetitions, inner = crossfit
    folds -- confirm against the parent class).
    """
    treatment_models = []
    for fitted_group in super().models_nuisance_:
        treatment_models.append([fitted._model_t for fitted in fitted_group])
    return treatment_models
715+
@property
def nuisance_scores_y(self):
    """The first entry (index 0) of ``nuisance_scores_``, i.e. the scores kept for the Y nuisance model."""
    all_nuisance_scores = self.nuisance_scores_
    return all_nuisance_scores[0]
719+
@property
def nuisance_scores_t(self):
    """The second entry (index 1) of ``nuisance_scores_``, i.e. the scores kept for the T nuisance model."""
    all_nuisance_scores = self.nuisance_scores_
    return all_nuisance_scores[1]
723+
@property
def residuals_(self):
    """
    A tuple (y_res, T_res, X, W), of the residuals from the first stage estimation
    along with the associated X and W. Samples are not guaranteed to be in the same
    order as the input order.

    Raises :exc:`AttributeError` when the estimator has no ``_cached_values``
    attribute yet, or when that attribute is ``None``.
    """
    if not hasattr(self, '_cached_values'):
        raise AttributeError("Estimator is not fitted yet!")
    cache = self._cached_values
    if cache is None:
        raise AttributeError("`fit` was called with `cache_values=False`. "
                             "Set to `True` to enable residual storage.")
    # The cached nuisances are the pair of first-stage residuals.
    outcome_residuals, treatment_residuals = cache.nuisances
    return (outcome_residuals, treatment_residuals, cache.X, cache.W)
0 commit comments