File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -382,8 +382,8 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
382382 }
383383
384384 // Run training mode.
385- // obtain running mean and running inv var, and see if we need to
386- // initialize them.
385+ // obtain running mean and running inv var, and there is no need
386+ // to initialize them.
387387
388388 auto *mean_out = ctx.Output <Tensor>(" MeanOut" );
389389 auto *variance_out = ctx.Output <Tensor>(" VarianceOut" );
@@ -394,10 +394,6 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
394394 auto *saved_variance = ctx.Output <Tensor>(" SavedVariance" );
395395 saved_mean->mutable_data <BatchNormParamType<T>>(ctx.GetPlace ());
396396 saved_variance->mutable_data <BatchNormParamType<T>>(ctx.GetPlace ());
397- math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
398- functor;
399- functor (dev_ctx, saved_mean, static_cast <BatchNormParamType<T>>(0 ));
400- functor (dev_ctx, saved_variance, static_cast <BatchNormParamType<T>>(0 ));
401397
402398 if ((N * H * W * D) == 1 ) {
403399 // Only 1 element in normalization dimension,
You can’t perform that action at this time.
0 commit comments