[NPU] flatten params and grads, fuse grad_clip and optimizer op #33461

Merged
zhiqiu merged 10 commits into PaddlePaddle:develop from zhiqiu:dev/fuse_all_opt
Jun 21, 2021

Conversation

@zhiqiu (Contributor) commented Jun 9, 2021

PR types

Performance optimization

PR changes

OPs

Describe

[NPU] flatten params and grads, fuse grad_clip and optimizer op

For example, the ernie-3.0 model has 300+ parameters, and thus 300+ corresponding gradients.

At each training step, the program has to clip each gradient and update each parameter, so the program contains 300+ grad_clip operators and 300+ optimizer operators.

This PR flattens all the parameters into one contiguous memory space, and likewise flattens the gradients. After that, part of the gradient clipping and optimizer computation can be performed once on the flattened parameter/gradient instead of once per tensor.

Currently, Adam + ClipByGlobalNorm is supported.
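The idea above can be sketched as follows. This is a hypothetical NumPy illustration, not Paddle's actual implementation: per-parameter gradients are coalesced into one contiguous buffer, after which a single global-norm clip and a single Adam update replace the per-tensor operators. The function and variable names (`flatten`, `unflatten`, `fused_clip_and_adam_step`) are invented for this sketch.

```python
import numpy as np

def flatten(tensors):
    # Record each tensor's shape, then copy all of them into one flat buffer.
    # (The real fusion would make the original tensors views into this buffer
    # so no copy is needed per step.)
    shapes = [t.shape for t in tensors]
    flat = np.concatenate([t.ravel() for t in tensors])
    return flat, shapes

def unflatten(flat, shapes):
    # Recover the individual tensors from the flat buffer.
    out, offset = [], 0
    for s in shapes:
        n = int(np.prod(s))
        out.append(flat[offset:offset + n].reshape(s))
        offset += n
    return out

def fused_clip_and_adam_step(params, grads, m, v, t, lr=1e-3,
                             beta1=0.9, beta2=0.999, eps=1e-8, clip_norm=1.0):
    # One ClipByGlobalNorm over the single flattened gradient...
    gnorm = np.linalg.norm(grads)
    if gnorm > clip_norm:
        grads = grads * (clip_norm / gnorm)
    # ...and one Adam update over the single flattened parameter,
    # instead of 300+ clip ops and 300+ optimizer ops.
    m[:] = beta1 * m + (1 - beta1) * grads
    v[:] = beta2 * v + (1 - beta2) * grads * grads
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    params -= lr * m_hat / (np.sqrt(v_hat) + eps)
    return params
```

With the parameters and gradients flattened once at startup, each training step launches a constant number of kernels regardless of how many parameter tensors the model has, which is where the speedup on NPU comes from.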

Performance

ernie-3.0, bs=20480, training speed: 20684->23583 tokens/s, +13.8%

paddle-bot-old (bot) commented Jun 9, 2021

Thanks for your contribution!
Please wait for the CI result first. See the Paddle CI Manual for details.

@zhiqiu zhiqiu requested review from phlrain and zhangting2020 June 17, 2021 05:41
@zhangting2020 (Contributor) commented:

Does paddle/optimzier/optimizer.py need to be updated accordingly as well?

"""
Args:
flatten_param_grads (bool, optional): Whether to flatten all the parameters and grads.
If true, the parameters and gradients will be coalesce to continue mempry,
Review comment (Contributor): continue mempry -> contiguous memory

zhiqiu (Author) replied: done, thanks
