Skip to content

Commit 5f6da40

Browse files
author
Miruna Oprescu
committed
Add scores.
1 parent 1070aea commit 5f6da40

File tree

2 files changed

+89
-16
lines changed

2 files changed

+89
-16
lines changed

econml/dml/dynamic_dml.py

Lines changed: 83 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,6 @@ def predict(self, Y, T, X=None, W=None, sample_weight=None, groups=None):
9292
return Y_res, T_res
9393

9494
def score(self, Y, T, X=None, W=None, sample_weight=None, groups=None):
95-
# TODO: implement scores
96-
# TODO: fix correctness?
9795
assert Y.shape[0] % self.n_periods == 0, \
9896
"Length of training data should be an integer multiple of time periods."
9997
inds_score = np.arange(Y.shape[0])[np.arange(Y.shape[0]) % self.n_periods == 0]
@@ -147,7 +145,7 @@ def __init__(self, model_final, n_periods):
147145
self._model_final_trained = {k: clone(self._model_final, safe=False) for k in np.arange(n_periods)}
148146

149147
def fit(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None, sample_var=None):
150-
# TODO: handle sample weight, sample var
148+
# NOTE: sample weight, sample var are not passed in
151149
Y_res, T_res = nuisances
152150
self._d_y = Y.shape[1:]
153151
for kappa in np.arange(self.n_periods):
@@ -186,8 +184,29 @@ def predict(self, X=None):
186184
return preds
187185

188186
def score(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None, sample_var=None):
189-
# TODO: implement score
190-
return None
187+
assert Y.shape[0] % self.n_periods == 0, \
188+
"Length of training data should be an integer multiple of time periods."
189+
Y_res, T_res = nuisances
190+
191+
scores = np.full((self.n_periods, ), np.nan)
192+
for kappa in np.arange(self.n_periods):
193+
period = self.n_periods - 1 - kappa
194+
period_filter = self.period_filter_gen(period, Y.shape[0])
195+
Y_adj = Y_res[period_filter].copy()
196+
if kappa > 0:
197+
Y_adj -= np.sum(
198+
[self._model_final_trained[tau].predict_with_res(
199+
X[self.period_filter_gen(self.n_periods - 1 - tau, Y.shape[0])] if X is not None else None,
200+
T_res[period_filter, ..., self.n_periods - 1 - tau]
201+
) for tau in np.arange(kappa)], axis=0)
202+
Y_adj_pred = self._model_final_trained[kappa].predict_with_res(
203+
X[period_filter] if X is not None else None,
204+
T_res[period_filter, ..., period])
205+
if sample_weight is not None:
206+
scores[kappa] = np.mean(np.average((Y_adj - Y_adj_pred)**2, weights=sample_weight, axis=0))
207+
else:
208+
scores[kappa] = np.mean((Y_adj - Y_adj_pred) ** 2)
209+
return scores
191210

192211
def period_filter_gen(self, p, n):
193212
return (np.arange(n) % self.n_periods == p)
@@ -548,12 +567,39 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
548567
warn("This CATE estimator does not yet support sample weights and sample variance. "
549568
"These inputs will be ignored during fitting.",
550569
UserWarning)
551-
# TODO: support sample_weight, sample_var?
552570
return super().fit(Y, T, X=X, W=W,
553571
sample_weight=None, sample_var=None, groups=groups,
554572
cache_values=cache_values,
555573
inference=inference)
556574

575+
def score(self, Y, T, X=None, W=None):
576+
"""
577+
Score the fitted CATE model on a new data set. Generates nuisance parameters
578+
for the new data set based on the fitted residual nuisance models created at fit time.
579+
It uses the mean prediction of the models fitted by the different crossfit folds.
580+
Then calculates the MSE of the final residual Y on residual T regression.
581+
582+
If model_final does not have a score method, then it raises an :exc:`.AttributeError`
583+
584+
Parameters
585+
----------
586+
Y: (n, d_y) matrix or vector of length n
587+
Outcomes for each sample (required: n = n_groups * n_periods)
588+
T: (n, d_t) matrix or vector of length n
589+
Treatments for each sample (required: n = n_groups * n_periods)
590+
X: optional(n, d_x) matrix or None (Default=None)
591+
Features for each sample (Required: n = n_groups * n_periods)
592+
W: optional(n, d_w) matrix or None (Default=None)
593+
Controls for each sample (Required: n = n_groups * n_periods)
594+
595+
Returns
596+
-------
597+
score: float
598+
The MSE of the final CATE model on the new data.
599+
"""
600+
# Replacing score from _OrthoLearner, to enforce Z=None and improve the docstring
601+
return super().score(Y, T, X=X, W=W)
602+
557603
def cate_treatment_names(self, treatment_names=None):
558604
"""
559605
Get treatment names for each time period.
@@ -658,3 +704,34 @@ def model_final(self):
658704
def model_final(self, model):
659705
if model is not None:
660706
raise ValueError("Parameter `model_final` cannot be altered for this estimator!")
707+
708+
@property
709+
def models_y(self):
710+
return [[mdl._model_y for mdl in mdls] for mdls in super().models_nuisance_]
711+
712+
@property
713+
def models_t(self):
714+
return [[mdl._model_t for mdl in mdls] for mdls in super().models_nuisance_]
715+
716+
@property
717+
def nuisance_scores_y(self):
718+
return self.nuisance_scores_[0]
719+
720+
@property
721+
def nuisance_scores_t(self):
722+
return self.nuisance_scores_[1]
723+
724+
@property
725+
def residuals_(self):
726+
"""
727+
A tuple (y_res, T_res, X, W), of the residuals from the first stage estimation
728+
along with the associated X and W. Samples are not guaranteed to be in the same
729+
order as the input order.
730+
"""
731+
if not hasattr(self, '_cached_values'):
732+
raise AttributeError("Estimator is not fitted yet!")
733+
if self._cached_values is None:
734+
raise AttributeError("`fit` was called with `cache_values=False`. "
735+
"Set to `True` to enable residual storage.")
736+
Y_res, T_res = self._cached_values.nuisances
737+
return Y_res, T_res, self._cached_values.X, self._cached_values.W

econml/tests/test_dynamic_dml.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,11 @@ def make_random(n, is_discrete, d):
113113
np.testing.assert_allclose(
114114
marg_eff if d_x else marg_eff[0:1], const_marg_eff)
115115

116-
# TODO: add score and nuisance scores
117-
"""
118-
assert isinstance(est.score_, float)
119-
for score in est.nuisance_scores_y:
120-
assert isinstance(score, float)
121-
for score in est.nuisance_scores_t:
122-
assert isinstance(score, float)
123-
"""
116+
assert len(est.score_) == n_periods
117+
for score in est.nuisance_scores_y[0]:
118+
assert score.shape == (n_periods, )
119+
for score in est.nuisance_scores_t[0]:
120+
assert score.shape == (n_periods, n_periods)
124121

125122
T0 = np.full_like(T_test, 'a') if is_discrete else np.zeros_like(T_test)
126123
eff = est.effect(X, T0=T0, T1=T_test)
@@ -238,8 +235,7 @@ def make_random(n, is_discrete, d):
238235
[0], est.intercept__interval()[0], decimal=5)
239236

240237
est.summary()
241-
242-
# TODO: add score to estimator
238+
# TODO: fix score
243239
"""
244240
est.score(Y, T, X, W)
245241
"""

0 commit comments

Comments
 (0)