Skip to content

Commit 0294ab4

Browse files
author
zhangkaihuo
authored
Update threshold of bn1d (#49734)
1 parent 609b50a commit 0294ab4

File tree

3 files changed

+7
-9
lines changed

3 files changed

+7
-9
lines changed

paddle/phi/kernels/funcs/norm_utils.h

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,10 @@ limitations under the License. */
1818

1919
namespace phi {
2020
namespace funcs {
21+
// Batch-size cutoff for 2-D inputs. The batch-norm GPU kernels set
// `use_native_kernel` when `x_dims.size() == 2 && N >=` this value.
// NOTE(review): presumably cuDNN's per-activation mode is only used below
// this batch size -- confirm against the cuDNN call sites.
#define CUDNN_PER_ACTIVATION_THRESHOLD 10240
22+
// Batch-size cutoff for 3-D inputs. BatchNormKernel and
// BatchNormGradRawKernel set `use_native_kernel` when
// `x_dims.size() == 3 && N >=` this value.
// NOTE(review): the name suggests this applies to the training path only --
// confirm at the call sites.
#define CUDNN_SPATIAL_THRESHOLD_TRAIN 880801
23+
// Batch-size cutoff for 3-D inputs. BatchNormKernel sets `use_native_kernel`
// when `x_dims.size() == 3 && N >=` this value (2-D inputs at that call site
// always take the native kernel).
// NOTE(review): the name suggests this applies to the inference/eval path --
// confirm at the call site.
#define CUDNN_SPATIAL_THRESHOLD_EVAL 65535
24+
2125
inline void ExtractNCWHD(const phi::DDim &dims,
2226
const DataLayout &data_layout,
2327
int *N,

paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -907,15 +907,12 @@ void BatchNormGradRawKernel(const Context &ctx,
907907
#else
908908
}
909909
// cuDNN only supports small batch sizes
910-
// const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
911-
const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
912-
const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
913910
bool use_native_nhwc =
914911
d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC)
915912
: false;
916913
const bool use_native_kernel =
917914
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
918-
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
915+
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_TRAIN));
919916
if (use_native_nhwc || (d_x && d_scale && d_bias)) {
920917
if (use_native_kernel || use_native_nhwc) {
921918
if (x_dims.size() == 2 || use_native_nhwc) {

paddle/phi/kernels/gpu/batch_norm_kernel.cu

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -722,9 +722,6 @@ void BatchNormKernel(const Context &ctx,
722722

723723
auto handle = ctx.cudnn_handle();
724724

725-
const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
726-
const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
727-
728725
// Now, depending on whether we are running test or not, we have two paths.
729726
// It is training mode when it's not reference AND not using pre-trained
730727
// model.
@@ -829,7 +826,7 @@ void BatchNormKernel(const Context &ctx,
829826
#else
830827
const bool use_native_kernel =
831828
(x_dims.size() == 2 ||
832-
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
829+
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_EVAL));
833830
if (use_native_kernel) {
834831
const int block_size = 256;
835832
const int grid_size = (N * C * H * W * D + block_size - 1) / block_size;
@@ -1005,7 +1002,7 @@ void BatchNormKernel(const Context &ctx,
10051002
// const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
10061003
const bool use_native_kernel =
10071004
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
1008-
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
1005+
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_TRAIN));
10091006
if (use_native_kernel) {
10101007
dim3 block;
10111008
dim3 grid;

0 commit comments

Comments
 (0)