Skip to content

Commit 8084c0a

Browse files
committed
Allow specifying n_rows in const_marginal_effect
1 parent 48613ea commit 8084c0a

File tree

5 files changed

+83
-36
lines changed

5 files changed

+83
-36
lines changed

econml/_ortho_learner.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -613,22 +613,25 @@ def _fit_final(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight
613613
sample_weight=sample_weight,
614614
sample_var=sample_var))
615615

616-
def const_marginal_effect(self, X=None):
616+
def const_marginal_effect(self, X=None, *, n_rows=None):
617617
self._check_fitted_dims(X)
618618
if X is None:
619-
return self._model_final.predict()
619+
pred = self._model_final.predict()
620+
return pred if n_rows is None else np.repeat(pred, n_rows, axis=0)
620621
else:
622+
if n_rows is not None:
623+
assert shape(X)[0] == n_rows
621624
return self._model_final.predict(X)
622625
const_marginal_effect.__doc__ = LinearCateEstimator.const_marginal_effect.__doc__
623626

624-
def const_marginal_effect_interval(self, X=None, *, alpha=0.1):
627+
def const_marginal_effect_interval(self, X=None, *, alpha=0.1, n_rows=None):
625628
self._check_fitted_dims(X)
626-
return super().const_marginal_effect_interval(X, alpha=alpha)
629+
return super().const_marginal_effect_interval(X, alpha=alpha, n_rows=n_rows)
627630
const_marginal_effect_interval.__doc__ = LinearCateEstimator.const_marginal_effect_interval.__doc__
628631

629-
def const_marginal_effect_inference(self, X=None):
632+
def const_marginal_effect_inference(self, X=None, *, n_rows=None):
630633
self._check_fitted_dims(X)
631-
return super().const_marginal_effect_inference(X)
634+
return super().const_marginal_effect_inference(X, n_rows=n_rows)
632635
const_marginal_effect_inference.__doc__ = LinearCateEstimator.const_marginal_effect_inference.__doc__
633636

634637
def effect_interval(self, X=None, *, T0=0, T1=1, alpha=0.1):

econml/cate_estimator.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ class LinearCateEstimator(BaseCateEstimator):
286286
"""Base class for all CATE estimators with linear treatment effects in this package."""
287287

288288
@abc.abstractmethod
289-
def const_marginal_effect(self, X=None):
289+
def const_marginal_effect(self, X=None, *, n_rows=None):
290290
"""
291291
Calculate the constant marginal CATE :math:`\\theta(·)`.
292292
@@ -297,6 +297,10 @@ def const_marginal_effect(self, X=None):
297297
----------
298298
X: optional (m, d_x) matrix or None (Default=None)
299299
Features for each sample.
300+
n_rows: optional int
301+
Number of rows to return if X is None; if no number of rows is specified and X is None
302+
then 1 row will be returned. If X is not None and a value is provided for the number of rows,
303+
that value must agree with the number of rows in X.
300304
301305
Returns
302306
-------
@@ -379,28 +383,22 @@ def marginal_effect(self, T, X=None):
379383

380384
def marginal_effect_interval(self, T, X=None, *, alpha=0.1):
381385
X, T = self._expand_treatments(X, T)
382-
effs = self.const_marginal_effect_interval(X=X, alpha=alpha)
383-
return tuple(np.repeat(eff, shape(T)[0], axis=0) if X is None else eff
384-
for eff in effs)
386+
if X is not None:
387+
return self.const_marginal_effect_interval(X=X, alpha=alpha)
388+
else: # need to pass the number of rows of T to ensure the right shape
389+
return self.const_marginal_effect_interval(X=X, alpha=alpha, n_rows=shape(T)[0])
385390
marginal_effect_interval.__doc__ = BaseCateEstimator.marginal_effect_interval.__doc__
386391

387392
def marginal_effect_inference(self, T, X=None):
388393
X, T = self._expand_treatments(X, T)
389-
cme_inf = self.const_marginal_effect_inference(X=X)
390-
pred = cme_inf.point_estimate
391-
pred_stderr = cme_inf.stderr
392-
if X is None:
393-
pred = np.repeat(pred, shape(T)[0], axis=0)
394-
pred_stderr = np.repeat(pred_stderr, shape(T)[0], axis=0)
395-
# TODO: It seems wrong to return inference results based on a normal approximation
396-
# even in the case where const_marginal_effect_inference was actually generated
397-
# using bootstrap
398-
return NormalInferenceResults(d_t=cme_inf.d_t, d_y=cme_inf.d_y, pred=pred,
399-
pred_stderr=pred_stderr, inf_type='effect', fname_transformer=None)
394+
if X is not None:
395+
return self.const_marginal_effect_inference(X=X)
396+
else: # need to pass the number of rows of T to ensure the right shape
397+
return self.const_marginal_effect_inference(X=X, n_rows=shape(T)[0])
400398
marginal_effect_inference.__doc__ = BaseCateEstimator.marginal_effect_inference.__doc__
401399

402400
@BaseCateEstimator._defer_to_inference
403-
def const_marginal_effect_interval(self, X=None, *, alpha=0.1):
401+
def const_marginal_effect_interval(self, X=None, *, alpha=0.1, n_rows=None):
404402
""" Confidence intervals for the quantities :math:`\\theta(X)` produced
405403
by the model. Available only when ``inference`` is not ``None``, when
406404
calling the fit method.
@@ -412,6 +410,10 @@ def const_marginal_effect_interval(self, X=None, *, alpha=0.1):
412410
alpha: optional float in [0, 1] (Default=0.1)
413411
The overall level of confidence of the reported interval.
414412
The alpha/2, 1-alpha/2 confidence interval is reported.
413+
n_rows: optional int
414+
Number of rows to return if X is None; if no number of rows is specified and X is None
415+
then 1 row will be returned. If X is not None and a value is provided for the number of rows,
416+
that value must agree with the number of rows in X.
415417
416418
Returns
417419
-------
@@ -422,7 +424,7 @@ def const_marginal_effect_interval(self, X=None, *, alpha=0.1):
422424
pass
423425

424426
@BaseCateEstimator._defer_to_inference
425-
def const_marginal_effect_inference(self, X=None):
427+
def const_marginal_effect_inference(self, X=None, *, n_rows=None):
426428
""" Inference results for the quantities :math:`\\theta(X)` produced
427429
by the model. Available only when ``inference`` is not ``None``, when
428430
calling the fit method.
@@ -431,6 +433,10 @@ def const_marginal_effect_inference(self, X=None):
431433
----------
432434
X: optional (m, d_x) matrix or None (Default=None)
433435
Features for each sample
436+
n_rows: optional int
437+
Number of rows to return if X is None; if no number of rows is specified and X is None
438+
then 1 row will be returned. If X is not None and a value is provided for the number of rows,
439+
that value must agree with the number of rows in X.
434440
435441
Returns
436442
-------

econml/inference.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -133,17 +133,26 @@ def fit(self, estimator, *args, **kwargs):
133133
self.d_t = self._d_t[0] if self._d_t else 1
134134
self.d_y = self._d_y[0] if self._d_y else 1
135135

136-
def const_marginal_effect_interval(self, X, *, alpha=0.1):
136+
def const_marginal_effect_interval(self, X, *, alpha=0.1, n_rows=None):
137+
assert X is None or n_rows is None or n_rows == shape(X)[0]
138+
repeat_X = X is None and n_rows is not None
137139
if X is None:
138140
X = np.ones((1, 1))
139141
elif self.featurizer is not None:
140142
X = self.featurizer.transform(X)
141143
X, T = broadcast_unit_treatments(X, self.d_t)
142144
preds = self._predict_interval(cross_product(X, T), alpha=alpha)
143-
return tuple(reshape_treatmentwise_effects(pred, self._d_t, self._d_y)
144-
for pred in preds)
145+
preds = tuple(reshape_treatmentwise_effects(pred, self._d_t, self._d_y)
146+
for pred in preds)
147+
if repeat_X:
148+
preds = tuple(np.repeat(pred, n_rows, axis=0)
149+
for pred in preds)
145150

146-
def const_marginal_effect_inference(self, X):
151+
return preds
152+
153+
def const_marginal_effect_inference(self, X, *, n_rows=None):
154+
assert X is None or n_rows is None or n_rows == shape(X)[0]
155+
repeat_X = X is None and n_rows is not None
147156
if X is None:
148157
X = np.ones((1, 1))
149158
elif self.featurizer is not None:
@@ -154,6 +163,9 @@ def const_marginal_effect_inference(self, X):
154163
raise AttributeError("Final model doesn't support prediction standard eror, "
155164
"please call const_marginal_effect_interval to get confidence interval.")
156165
pred_stderr = reshape_treatmentwise_effects(self._prediction_stderr(cross_product(X, T)), self._d_t, self._d_y)
166+
if repeat_X:
167+
pred = np.repeat(pred, n_rows, axis=0)
168+
pred_stderr = np.repeat(pred_stderr, n_rows, axis=0)
157169
return NormalInferenceResults(d_t=self.d_t, d_y=self.d_y, pred=pred,
158170
pred_stderr=pred_stderr, inf_type='effect', fname_transformer=None)
159171

@@ -370,23 +382,37 @@ def fit(self, estimator, *args, **kwargs):
370382
if hasattr(estimator, 'fit_cate_intercept'):
371383
self.fit_cate_intercept = estimator.fit_cate_intercept
372384

373-
def const_marginal_effect_interval(self, X, *, alpha=0.1):
385+
def const_marginal_effect_interval(self, X, *, alpha=0.1, n_rows=None):
386+
assert X is None or n_rows is None or n_rows == shape(X)[0]
387+
repeat_X = X is None and n_rows is not None
388+
374389
if (X is not None) and (self.featurizer is not None):
375390
X = self.featurizer.transform(X)
376391
preds = np.array([mdl.predict_interval(X, alpha=alpha) for mdl in self.fitted_models_final])
377-
return tuple(np.moveaxis(preds, [0, 1], [-1, 0])) # send treatment to the end, pull bounds to the front
392+
preds = tuple(np.moveaxis(preds, [0, 1], [-1, 0])) # send treatment to the end, pull bounds to the front
393+
if repeat_X:
394+
preds = tuple(np.repeat(pred, n_rows, axis=0) for pred in preds)
395+
return preds
396+
397+
def const_marginal_effect_inference(self, X, *, n_rows=None):
398+
assert X is None or n_rows is None or n_rows == shape(X)[0]
399+
repeat_X = X is None and n_rows is not None
378400

379-
def const_marginal_effect_inference(self, X):
380401
if (X is not None) and (self.featurizer is not None):
381402
X = self.featurizer.transform(X)
382403
pred = np.array([mdl.predict(X) for mdl in self.fitted_models_final])
404+
pred = np.moveaxis(pred, 0, -1) # send treatment to the end, pull bounds to the front
383405
if not hasattr(self.fitted_models_final[0], 'prediction_stderr'):
384406
raise AttributeError("Final model doesn't support prediction standard eror, "
385407
"please call const_marginal_effect_interval to get confidence interval.")
386408
pred_stderr = np.array([mdl.prediction_stderr(X) for mdl in self.fitted_models_final])
387-
return NormalInferenceResults(d_t=self.d_t, d_y=self.d_y, pred=np.moveaxis(pred, 0, -1),
388-
# send treatment to the end, pull bounds to the front
389-
pred_stderr=np.moveaxis(pred_stderr, 0, -1), inf_type='effect',
409+
pred_stderr = np.moveaxis(pred_stderr, 0, -1) # send treatment to the end, pull bounds to the front
410+
411+
if repeat_X:
412+
pred = np.repeat(pred, n_rows, axis=0)
413+
pred_stderr = np.repeat(pred_stderr, n_rows, axis=0)
414+
return NormalInferenceResults(d_t=self.d_t, d_y=self.d_y, pred=pred,
415+
pred_stderr=pred_stderr, inf_type='effect',
390416
fname_transformer=None)
391417

392418
def effect_interval(self, X, *, T0, T1, alpha=0.1):

econml/metalearners.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def fit(self, Y, T, X=None, *, inference=None):
165165
feat_arr = np.concatenate((X, T), axis=1)
166166
self.overall_model.fit(feat_arr, Y)
167167

168-
def const_marginal_effect(self, X=None):
168+
def const_marginal_effect(self, X=None, *, n_rows=None):
169169
"""Calculate the constant marginal treatment effect on a vector of features for each sample.
170170
171171
Parameters
@@ -180,6 +180,9 @@ def const_marginal_effect(self, X=None):
180180
Note that when Y is a vector rather than a 2-dimensional array,
181181
the corresponding singleton dimensions in the output will be collapsed
182182
"""
183+
assert X is None or n_rows is None or n_rows == shape(X)[0]
184+
repeat_X = X is None and n_rows is not None
185+
183186
# Check inputs
184187
if X is None:
185188
X = np.zeros((1, 1))
@@ -192,6 +195,8 @@ def const_marginal_effect(self, X=None):
192195
taus = (prediction - np.repeat(prediction[:, :, 0], self._d_t[0] + 1).reshape(prediction.shape))[:, :, 1:]
193196
else:
194197
taus = (prediction - np.repeat(prediction[:, 0], self._d_t[0] + 1).reshape(prediction.shape))[:, 1:]
198+
if repeat_X:
199+
taus = np.repeat(taus, n_rows, axis=0)
195200
return taus
196201

197202

econml/ortho_forest.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -969,7 +969,7 @@ def fit(self, estimator, *args, **kwargs):
969969
self._T_vec = (T0.ndim == 1)
970970
return self
971971

972-
def const_marginal_effect_interval(self, X=None, *, alpha=0.1):
972+
def const_marginal_effect_interval(self, X=None, *, alpha=0.1, n_rows=None):
973973
""" Confidence intervals for the quantities :math:`\\theta(X)` produced
974974
by the model. Available only when ``inference`` is ``blb``, when
975975
calling the fit method.
@@ -989,6 +989,9 @@ def const_marginal_effect_interval(self, X=None, *, alpha=0.1):
989989
type of :meth:`const_marginal_effect(X)<const_marginal_effect>` )
990990
The lower and the upper bounds of the confidence interval for each quantity.
991991
"""
992+
assert X is None or n_rows is None or n_rows == shape(X)[0]
993+
repeat_X = X is None and n_rows is not None
994+
992995
params_and_cov = self._predict_wrapper(X)
993996
# Calculate confidence intervals for the parameter (marginal effect)
994997
lower = alpha / 2
@@ -1000,7 +1003,11 @@ def const_marginal_effect_interval(self, X=None, *, alpha=0.1):
10001003
param_lower, param_upper = np.asarray(param_lower), np.asarray(param_upper)
10011004
if self._T_vec:
10021005
# If T is a vector, preserve shape of the effect interval
1003-
return param_lower.flatten(), param_upper.flatten()
1006+
param_lower = param_lower.flatten()
1007+
param_upper = param_upper.flatten()
1008+
if repeat_X:
1009+
param_lower = np.repeat(param_lower, n_rows, axis=0)
1010+
param_upper = np.repeat(param_upper, n_rows, axis=0)
10041011
return param_lower, param_upper
10051012

10061013
def effect_interval(self, X=None, *, T0=0, T1=1, alpha=0.1):

0 commit comments

Comments
 (0)