
Commit 992250b

Modified the kernel policy when the compute format is NHWC (#48563)
1 parent: 5c64d84

File tree: 1 file changed (+34, -13 lines)


paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu

Lines changed: 34 additions & 13 deletions
@@ -858,15 +858,20 @@ void BatchNormGradRawKernel(const Context &ctx,
     //     ctx.GetPlace()),
     //     epsilon, saved_mean_data, saved_var_data));
 #else
-    // CUDNN only support small batch size
-    // const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
-    const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
-    const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
-    const bool use_native_kernel =
-        ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
-         (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
-    if (use_native_kernel) {
-      if (x_dims.size() == 2) {
+    }
+    // CUDNN only support small batch size
+    // const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
+    const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
+    const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
+    bool use_native_nhwc =
+        d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC)
+            : false;
+    const bool use_native_kernel =
+        ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
+         (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
+    if (use_native_nhwc || (d_x && d_scale && d_bias)) {
+      if (use_native_kernel || use_native_nhwc) {
+        if (x_dims.size() == 2 || use_native_nhwc) {
         dim3 block;
         dim3 grid;
         const int block_size = 512;
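
For orientation (a reading aid, not part of the commit): the new dispatch policy in the hunk above can be sketched in isolation as follows, assuming d_x, d_scale, d_bias, x_dims, compute_format, and N carry the same meanings as in the kernel.

    // Minimal sketch of the revised policy. The native channel-last path is
    // taken only when dx is requested, the input is 4-D, and the compute
    // layout is NHWC; otherwise the existing size thresholds choose between
    // the native kernels and cuDNN.
    const bool use_native_nhwc = d_x != nullptr && x_dims.size() == 4 &&
                                 compute_format == DataLayout::kNHWC;
    const bool use_native_kernel =
        (x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
        (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD);
    if (use_native_nhwc || (d_x && d_scale && d_bias)) {
      if (use_native_kernel || use_native_nhwc) {
        // native CUDA kernels (the channel-last path when use_native_nhwc)
      } else {
        // cuDNN batch-norm backward path
      }
    }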
@@ -937,6 +942,21 @@ void BatchNormGradRawKernel(const Context &ctx,
                 flag_ptr);
         }
         // 2. reduce_sum(x, dy, mean) => dscale, dbias
+        BatchNormParamType<T> *dscale = nullptr;
+        BatchNormParamType<T> *dbias = nullptr;
+        bool with_scale = false;
+        if (d_scale && d_bias) {
+          dscale = ctx.template Alloc<BatchNormParamType<T>>(d_scale);
+          dbias = ctx.template Alloc<BatchNormParamType<T>>(d_bias);
+        } else {
+          DenseTensor dscale_mem =
+              phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
+          DenseTensor dbias_mem =
+              phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
+          dscale = dscale_mem.data<BatchNormParamType<T>>();
+          dbias = dbias_mem.data<BatchNormParamType<T>>();
+        }
+
         BNBackward2DChannelLastStage2<T, block_size>
             <<<grid, block, 0, ctx.stream()>>>(
                 transformed_d_y.template data<T>(),
@@ -948,8 +968,8 @@ void BatchNormGradRawKernel(const Context &ctx,
                 H * W * D,
                 epsilon,
                 block_data_ptr,
-                ctx.template Alloc<BatchNormParamType<T>>(d_scale),
-                ctx.template Alloc<BatchNormParamType<T>>(d_bias),
+                dscale,
+                dbias,
                 flag_ptr);

         // 3. elementwise_mul(scale, mean, inv_var, dy, dscale, dbias) => dx
@@ -958,8 +978,8 @@ void BatchNormGradRawKernel(const Context &ctx,
                 transformed_d_y.template data<T>(),
                 transformed_x.template data<T>(),
                 scale.template data<BatchNormParamType<T>>(),
-                d_scale->data<BatchNormParamType<T>>(),
-                d_bias->data<BatchNormParamType<T>>(),
+                dscale,
+                dbias,
                 mean_ptr,
                 variance_ptr,
                 C,
@@ -1169,6 +1189,7 @@ void BatchNormGradRawKernel(const Context &ctx,
     paddle::platform::dynload::cudnnDestroyTensorDescriptor(
         bn_param_desc_));
 #endif
+
   } else {
     const auto *running_mean = mean.get_ptr();
     const auto *running_var = variance.get_ptr();
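
A note on the second hunk (a sketch under the assumptions visible in the diff, not a verbatim excerpt): the stage-2 reduction always needs per-channel output buffers, so when the caller does not request d_scale/d_bias the change falls back to {C}-shaped scratch tensors. The pattern, with the temporaries hoisted out of the branch so they outlive the subsequent kernel launches:

    // Pick caller-provided gradient tensors when present; otherwise allocate
    // scratch buffers of shape {C}. The DenseTensor temporaries live at this
    // scope so the raw pointers remain valid for the later kernel launches
    // (in the committed diff they are scoped to the else-branch).
    BatchNormParamType<T> *dscale = nullptr;
    BatchNormParamType<T> *dbias = nullptr;
    DenseTensor dscale_mem, dbias_mem;
    if (d_scale && d_bias) {
      dscale = ctx.template Alloc<BatchNormParamType<T>>(d_scale);
      dbias = ctx.template Alloc<BatchNormParamType<T>>(d_bias);
    } else {
      dscale_mem = phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
      dbias_mem = phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
      dscale = dscale_mem.data<BatchNormParamType<T>>();
      dbias = dbias_mem.data<BatchNormParamType<T>>();
    }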
