@@ -235,11 +235,12 @@ def fit(self, Y, T, X, W, Z, inference=None):
235235 W = np .empty ((shape (Y )[0 ], 0 ))
236236 assert shape (Y )[0 ] == shape (T )[0 ] == shape (X )[0 ] == shape (W )[0 ] == shape (Z )[0 ]
237237
238+ # make T 2D if if was a vector
239+ if ndim (T ) == 1 :
240+ T = reshape (T , (- 1 , 1 ))
241+
238242 # store number of columns of W so that we can create correctly shaped zero array in effect and marginal effect
239243 self ._d_w = shape (W )[1 ]
240- # store number of columns of T so that we can pass scalars to effect
241- # TODO: support vector T and Y
242- self ._d_t = shape (T )[1 ]
243244
244245 # two stage approximation
245246 # first, get basis expansions of T, X, and Z
@@ -285,9 +286,13 @@ def effect(self, X=None, T0=0, T1=1):
285286
286287 """
287288 if ndim (T0 ) == 0 :
288- T0 = np .full ((1 if X is None else shape (X )[0 ], self ._d_t ) , T0 )
289+ T0 = np .full ((1 if X is None else shape (X )[0 ],) + self ._d_t , T0 )
289290 if ndim (T1 ) == 0 :
290- T1 = np .full ((1 if X is None else shape (X )[0 ], self ._d_t ), T1 )
291+ T1 = np .full ((1 if X is None else shape (X )[0 ],) + self ._d_t , T1 )
292+ if ndim (T0 ) == 1 :
293+ T0 = reshape (T0 , (- 1 , 1 ))
294+ if ndim (T1 ) == 1 :
295+ T1 = reshape (T1 , (- 1 , 1 ))
291296 if X is None :
292297 X = np .empty ((shape (T0 )[0 ], 0 ))
293298 assert shape (T0 ) == shape (T1 )
@@ -329,7 +334,7 @@ def marginal_effect(self, T, X=None):
329334
330335 ft_X = self ._x_featurizer .transform (X )
331336 n = shape (T )[0 ]
332- dT = self ._dt_featurizer .transform (T )
337+ dT = self ._dt_featurizer .transform (T if ndim ( T ) == 2 else reshape ( T , ( - 1 , 1 )) )
333338 W = np .zeros ((size (T ), self ._d_w ))
334339 # dT should be an n×dₜ×fₜ array (but if T was a vector, or if there is only one feature,
335340 # dT may be only 2-dimensional)
@@ -342,4 +347,8 @@ def marginal_effect(self, T, X=None):
342347 features = transpose (features , [0 , 1 , 3 , 2 ]) # swap last two dims to match cross_product
343348 features = reshape (features , (size (T ), - 1 ))
344349 output = self ._model_Y .predict (_add_zeros (np .hstack ([W , features ])))
345- return reshape (output , shape (T ) + (shape (output )[- 1 ],))
350+ output = reshape (output , shape (T ) + shape (output )[1 :])
351+ if ndim (output ) == 3 :
352+ return transpose (output , (0 , 2 , 1 )) # transpose trailing T and Y dims
353+ else :
354+ return output
0 commit comments