@@ -237,7 +237,6 @@ struct KronGradElemFunctor<platform::complex<T>> {
237237 const int ndims_;
238238};
239239
240- template <typename T>
241240struct IdentityFunctor {
242241 HOSTDEVICE explicit inline IdentityFunctor () {}
243242
@@ -315,13 +314,13 @@ struct KronGradOpFunctor {
315314#if defined(__NVCC__) || defined(__HIPCC__)
316315 auto stream = dev_ctx.stream (); // it is a cuda device_context
317316 if (dx) {
318- TensorReduce<T, T, cub::Sum, IdentityFunctor<T> >(
319- dout_x, dx, {1 }, static_cast <T>(0 ), cub::Sum (), IdentityFunctor<T> (),
317+ TensorReduce<T, T, cub::Sum, IdentityFunctor>(
318+ dout_x, dx, {1 }, static_cast <T>(0 ), cub::Sum (), IdentityFunctor (),
320319 stream);
321320 }
322321 if (dy) {
323- TensorReduce<T, T, cub::Sum, IdentityFunctor<T> >(
324- dout_y, dy, {1 }, static_cast <T>(0 ), cub::Sum (), IdentityFunctor<T> (),
322+ TensorReduce<T, T, cub::Sum, IdentityFunctor>(
323+ dout_y, dy, {1 }, static_cast <T>(0 ), cub::Sum (), IdentityFunctor (),
325324 stream);
326325 }
327326#else
0 commit comments