Skip to content

Conversation

@zhiqiu
Copy link
Contributor

@zhiqiu zhiqiu commented Jun 23, 2021

PR types

Bug fixes

PR changes

OPs

Describe

fix bug when the cuda kernel config exceeds dims max
close PaddlePaddle/PaddleClas#685, #33107

When use 3D dim3 config for the blocks of cuda kernel, the limit of each dim is (1024, 1024, 64). For example

dim3 num_blocks(x,y,z), num_threads(32, 8, 1);
kernel<<<num_blocks, num_threads>>>(param)

The condition shouble be: x <=1024, y <=1024, z <=64, otherwise, "cudaErrorInvalidConfiguration" will be raised.

In layer_norm_grad, it uses 3D dim3 (1, batch_size, 1) as blocks config. While, batch_size may > 1024.
This PR change 3D dim3 (1, batch_size, 1) to 1D dim3 (batch_size) to solve the problem.

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@zhiqiu zhiqiu force-pushed the dev/fix_layer_norm_grad branch from 0b5ea5e to 617e3ed Compare June 23, 2021 14:54
Copy link
Contributor

@pangyoki pangyoki left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhiqiu zhiqiu merged commit 56692f6 into PaddlePaddle:develop Jun 24, 2021
zhiqiu added a commit to zhiqiu/Paddle that referenced this pull request Jul 1, 2021
lanxianghit pushed a commit that referenced this pull request Jul 1, 2021
…33748) (#33893)

fix bug when the cuda kernel config exceeds dims max
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants