Skip to content

Commit 50c5613

Browse files
committed
Fix 2sls marginal effect
1 parent 223096b commit 50c5613

File tree

2 files changed

+27
-9
lines changed

2 files changed

+27
-9
lines changed

econml/tests/test_two_stage_least_squares.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ def make_random(d):
6565
sz = (n, d) if d >= 0 else (n,)
6666
return np.random.normal(size=sz)
6767

68-
for d_t in [1, 2]:
68+
for d_t in [-1, 1, 2]:
6969
n_t = d_t if d_t > 0 else 1
70-
for d_y in [1, 2]:
70+
for d_y in [-1, 1, 2]:
7171
for d_x in [1, 5]:
7272
for d_z in [1, 2]:
7373
d_w = 1
@@ -80,9 +80,18 @@ def make_random(d):
8080
dt_featurizer=DPolynomialFeatures())
8181

8282
est.fit(Y, T, X, W, Z)
83+
8384
eff = est.effect(X)
8485
marg_eff = est.marginal_effect(T, X)
8586

87+
effect_shape = (n,) + ((d_y,) if d_y > 0 else ())
88+
marginal_effect_shape = ((n if d_x else 1,) +
89+
((d_y,) if d_y > 0 else ()) +
90+
((d_t,) if d_t > 0 else()))
91+
92+
self.assertEqual(shape(marg_eff), marginal_effect_shape)
93+
self.assertEqual(shape(eff), effect_shape)
94+
8695
def test_marg_eff(self):
8796
X = np.random.normal(size=(5000, 2))
8897
Z = np.random.normal(size=(5000, 2))

econml/two_stage_least_squares.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -235,11 +235,12 @@ def fit(self, Y, T, X, W, Z, inference=None):
235235
W = np.empty((shape(Y)[0], 0))
236236
assert shape(Y)[0] == shape(T)[0] == shape(X)[0] == shape(W)[0] == shape(Z)[0]
237237

238+
# make T 2D if if was a vector
239+
if ndim(T) == 1:
240+
T = reshape(T, (-1, 1))
241+
238242
# store number of columns of W so that we can create correctly shaped zero array in effect and marginal effect
239243
self._d_w = shape(W)[1]
240-
# store number of columns of T so that we can pass scalars to effect
241-
# TODO: support vector T and Y
242-
self._d_t = shape(T)[1]
243244

244245
# two stage approximation
245246
# first, get basis expansions of T, X, and Z
@@ -285,9 +286,13 @@ def effect(self, X=None, T0=0, T1=1):
285286
286287
"""
287288
if ndim(T0) == 0:
288-
T0 = np.full((1 if X is None else shape(X)[0], self._d_t), T0)
289+
T0 = np.full((1 if X is None else shape(X)[0],) + self._d_t, T0)
289290
if ndim(T1) == 0:
290-
T1 = np.full((1 if X is None else shape(X)[0], self._d_t), T1)
291+
T1 = np.full((1 if X is None else shape(X)[0],) + self._d_t, T1)
292+
if ndim(T0) == 1:
293+
T0 = reshape(T0, (-1, 1))
294+
if ndim(T1) == 1:
295+
T1 = reshape(T1, (-1, 1))
291296
if X is None:
292297
X = np.empty((shape(T0)[0], 0))
293298
assert shape(T0) == shape(T1)
@@ -329,7 +334,7 @@ def marginal_effect(self, T, X=None):
329334

330335
ft_X = self._x_featurizer.transform(X)
331336
n = shape(T)[0]
332-
dT = self._dt_featurizer.transform(T)
337+
dT = self._dt_featurizer.transform(T if ndim(T) == 2 else reshape(T, (-1, 1)))
333338
W = np.zeros((size(T), self._d_w))
334339
# dT should be an n×dₜ×fₜ array (but if T was a vector, or if there is only one feature,
335340
# dT may be only 2-dimensional)
@@ -342,4 +347,8 @@ def marginal_effect(self, T, X=None):
342347
features = transpose(features, [0, 1, 3, 2]) # swap last two dims to match cross_product
343348
features = reshape(features, (size(T), -1))
344349
output = self._model_Y.predict(_add_zeros(np.hstack([W, features])))
345-
return reshape(output, shape(T) + (shape(output)[-1],))
350+
output = reshape(output, shape(T) + shape(output)[1:])
351+
if ndim(output) == 3:
352+
return transpose(output, (0, 2, 1)) # transpose trailing T and Y dims
353+
else:
354+
return output

0 commit comments

Comments
 (0)