Skip to content

Commit 6f1ec93

Browse files
author
zhangkaihuo
authored
Fix bn performance degradation (#50287)
* fix bn performance degradation
1 parent 8c14b02 commit 6f1ec93

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -783,7 +783,8 @@ void BatchNormGradRawKernel(const Context &ctx,
783783
}
784784
// CUDNN only support small batch size
785785
bool use_native_nhwc =
786-
d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC)
786+
d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC &&
787+
H * W >= CUDNN_SPATIAL_THRESHOLD_EVAL)
787788
: false;
788789
const bool use_native_kernel =
789790
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||

0 commit comments

Comments
 (0)