@@ -858,15 +858,20 @@ void BatchNormGradRawKernel(const Context &ctx,
 // ctx.GetPlace()),
 // epsilon, saved_mean_data, saved_var_data));
 #else
-    // CUDNN only support small batch size
-    // const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
-    const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
-    const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
-    const bool use_native_kernel =
-        ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
-         (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
-    if (use_native_kernel) {
-      if (x_dims.size() == 2) {
+    }
+    // CUDNN only support small batch size
+    // const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
+    const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
+    const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
+    bool use_native_nhwc =
+        d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC)
+            : false;
+    const bool use_native_kernel =
+        ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
+         (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
+    if (use_native_nhwc || (d_x && d_scale && d_bias)) {
+      if (use_native_kernel || use_native_nhwc) {
+        if (x_dims.size() == 2 || use_native_nhwc) {
         dim3 block;
         dim3 grid;
         const int block_size = 512;
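
The hunk above reworks the kernel dispatch. Alongside the existing batch-size thresholds (cuDNN's per-activation mode for 2-D inputs, spatial mode for 3-D inputs), a new `use_native_nhwc` flag sends 4-D NHWC inputs to the native channel-last kernels whenever `d_x` is requested. Below is a standalone sketch of the resulting predicate, collapsing the nested guards; the enum and parameter names are stand-ins for the phi types, not the kernel's actual signature:

```cpp
#include <cstddef>

enum class Layout { kNCHW, kNHWC };

bool UseNativePath(int rank, std::size_t n, Layout layout, bool wants_dx) {
  const std::size_t kPerActivationThreshold = 10240;  // 2-D inputs
  const std::size_t kSpatialThreshold = 880801;       // 3-D inputs
  // New in this change: 4-D NHWC inputs take the native channel-last
  // kernels whenever dx is requested, regardless of batch size.
  const bool use_native_nhwc =
      wants_dx && rank == 4 && layout == Layout::kNHWC;
  // Existing rule: switch to the native kernels once the element count
  // exceeds what cuDNN's per-activation / spatial modes handle.
  const bool over_cudnn_limit =
      (rank == 2 && n >= kPerActivationThreshold) ||
      (rank == 3 && n >= kSpatialThreshold);
  return use_native_nhwc || over_cudnn_limit;
}
```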
@@ -937,6 +942,21 @@ void BatchNormGradRawKernel(const Context &ctx,
               flag_ptr);
         }
         // 2. reduce_sum(x, dy, mean) => dscale, dbias
+        BatchNormParamType<T> *dscale = nullptr;
+        BatchNormParamType<T> *dbias = nullptr;
+        bool with_scale = false;
+        if (d_scale && d_bias) {
+          dscale = ctx.template Alloc<BatchNormParamType<T>>(d_scale);
+          dbias = ctx.template Alloc<BatchNormParamType<T>>(d_bias);
+        } else {
+          DenseTensor dscale_mem =
+              phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
+          DenseTensor dbias_mem =
+              phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
+          dscale = dscale_mem.data<BatchNormParamType<T>>();
+          dbias = dbias_mem.data<BatchNormParamType<T>>();
+        }
+
         BNBackward2DChannelLastStage2<T, block_size>
             <<<grid, block, 0, ctx.stream()>>>(
                 transformed_d_y.template data<T>(),
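
The added block decouples the fused kernel from the requested outputs: `BNBackward2DChannelLastStage2` always writes per-channel `dscale`/`dbias`, so when the caller did not ask for those gradients, a scratch `{C}`-shaped tensor is allocated with `phi::Empty` and its raw pointer is passed instead. A minimal host-side sketch of that bind-or-allocate pattern, with `std::vector` standing in for `DenseTensor` and all names illustrative:

```cpp
#include <vector>

struct ParamGradBuffer {
  std::vector<float> scratch;  // owns storage only in the fallback case
  float *ptr = nullptr;        // what the kernel actually receives
};

ParamGradBuffer BindOrAllocate(float *requested_output, int channels) {
  ParamGradBuffer buf;
  if (requested_output != nullptr) {
    buf.ptr = requested_output;    // write straight into the user's output
  } else {
    buf.scratch.resize(channels);  // throwaway buffer of shape {C}
    buf.ptr = buf.scratch.data();
  }
  return buf;
}
```

Since the launch is asynchronous, whatever owns the fallback storage has to outlive the kernel that writes through `ptr`.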
@@ -948,8 +968,8 @@ void BatchNormGradRawKernel(const Context &ctx,
                 H * W * D,
                 epsilon,
                 block_data_ptr,
-                ctx.template Alloc<BatchNormParamType<T>>(d_scale),
-                ctx.template Alloc<BatchNormParamType<T>>(d_bias),
+                dscale,
+                dbias,
                 flag_ptr);

         // 3. elementwise_mul(scale, mean, inv_var, dy, dscale, dbias) => dx
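
For reference, the stage-2 reduction that now receives `dscale`/`dbias` computes, per channel, `dbias[c] = sum(dy)` and `dscale[c] = sum(dy * (x - mean)) * inv_std`. A CPU sketch of the same math; names are illustrative, and the real kernel is a blocked GPU reduction that hands partial sums across blocks via `block_data_ptr`/`flag_ptr`:

```cpp
#include <cmath>

void BNGradStage2Reference(const float *dy, const float *x, const float *mean,
                           const float *variance, float epsilon, int channels,
                           int inner, float *dscale, float *dbias) {
  for (int c = 0; c < channels; ++c) {
    const float inv_std = 1.0f / std::sqrt(variance[c] + epsilon);
    float sum_dy = 0.0f;
    float sum_dy_xmu = 0.0f;
    for (int i = 0; i < inner; ++i) {
      const int idx = i * channels + c;  // NHWC: channel is innermost
      sum_dy += dy[idx];
      sum_dy_xmu += dy[idx] * (x[idx] - mean[c]);
    }
    dbias[c] = sum_dy;
    dscale[c] = sum_dy_xmu * inv_std;
  }
}
```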
@@ -958,8 +978,8 @@ void BatchNormGradRawKernel(const Context &ctx,
                 transformed_d_y.template data<T>(),
                 transformed_x.template data<T>(),
                 scale.template data<BatchNormParamType<T>>(),
-                d_scale->data<BatchNormParamType<T>>(),
-                d_bias->data<BatchNormParamType<T>>(),
+                dscale,
+                dbias,
                 mean_ptr,
                 variance_ptr,
                 C,
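
Stage 3 then forms `dx` elementwise from those reductions. Assuming the standard batch-norm backward formula over the saved statistics, this is a CPU sketch of what the channel-last kernel computes per element; `inner` stands for `N * H * W * D` and the names are illustrative:

```cpp
#include <cmath>

void BNGradStage3Reference(const float *dy, const float *x, const float *scale,
                           const float *dscale, const float *dbias,
                           const float *mean, const float *variance,
                           float epsilon, int channels, int inner, float *dx) {
  for (int i = 0; i < inner; ++i) {
    for (int c = 0; c < channels; ++c) {
      const int idx = i * channels + c;  // NHWC: channel is innermost
      const float inv_std = 1.0f / std::sqrt(variance[c] + epsilon);
      // dx = scale * inv_std *
      //      (dy - dbias / inner - (x - mean) * inv_std * dscale / inner)
      dx[idx] = scale[c] * inv_std *
                (dy[idx] - dbias[c] / inner -
                 (x[idx] - mean[c]) * inv_std * dscale[c] / inner);
    }
  }
}
```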
@@ -1169,6 +1189,7 @@ void BatchNormGradRawKernel(const Context &ctx,
         paddle::platform::dynload::cudnnDestroyTensorDescriptor(
             bn_param_desc_));
 #endif
+
   } else {
     const auto *running_mean = mean.get_ptr();
     const auto *running_var = variance.get_ptr();
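
The `else` branch entered above handles the running-statistics case. With fixed `running_mean`/`running_var`, the statistics do not depend on `x`, so the data gradient reduces to an elementwise scaling. A rough sketch, illustrative only; the real branch dispatches to dedicated kernels that also produce `dscale`/`dbias`:

```cpp
#include <cmath>

void BNGradInferenceReference(const float *dy, const float *scale,
                              const float *running_var, float epsilon,
                              int channels, int inner, float *dx) {
  for (int i = 0; i < inner; ++i) {
    for (int c = 0; c < channels; ++c) {
      const int idx = i * channels + c;  // NHWC: channel is innermost
      dx[idx] = scale[c] * dy[idx] / std::sqrt(running_var[c] + epsilon);
    }
  }
}
```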