Contributor

@wangxicoding wangxicoding commented Jul 30, 2021

PR types

Performance optimization

PR changes

Others

Describe

Optimize pipeline performance when recompute and AMP are enabled.
Main optimization points

  1. Recompute does not set the op_device attr on the backward ops it creates, which leads Pipeline to send/recv sub grad.
    After adding the op_device attr to the recompute backward ops, Pipeline send/recv is correct.
  2. Add concat and split to the amp gray_list. Before this change, concat ran only in fp32.
    With concat in the amp gray_list, concat can run in fp16 if its inputs are fp16.
  3. With Pipeline, an op whose device is all is computed in every PipelineStage. With AMP enabled, AMP inserts a cast op after such a device:all op, but the cast op's device is not all, so Pipeline sends the cast var between stages.
    To optimize this, the cast op is set to device:all if the previous op is device:all, so Pipeline no longer sends this var.
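Points 2 and 3 can be illustrated with a toy model. This is only a sketch and not Paddle's implementation: the `Op` class, the `"all"` marker, and the helpers `gray_op_dtype`, `cross_stage_vars`, and `mark_casts_after_all_ops` are hypothetical names invented for this example, and the gray-list rule is simplified to "fp16 only when every input is fp16".

```python
from dataclasses import dataclass

ALL = "all"  # hypothetical marker: the op is computed on every pipeline stage


@dataclass
class Op:
    type: str
    inputs: list
    outputs: list
    device: str


def gray_op_dtype(input_dtypes):
    """Simplified gray-list rule (assumption): follow the inputs, fp16 only if all are fp16."""
    if input_dtypes and all(d == "float16" for d in input_dtypes):
        return "float16"
    return "float32"


def _producers(ops):
    """Map each variable name to the op that writes it."""
    return {name: op for op in ops for name in op.outputs}


def cross_stage_vars(ops):
    """Variables read on a stage other than the one that produced them.

    Outputs of device:all ops exist on every stage, so they never need a send.
    """
    made_by = _producers(ops)
    sent = []
    for op in ops:
        for name in op.inputs:
            src = made_by.get(name)
            if src is None or src.device == ALL:
                continue
            if op.device != src.device and name not in sent:
                sent.append(name)
    return sent


def mark_casts_after_all_ops(ops):
    """Set a cast op to device:all when every input comes from a device:all op."""
    made_by = _producers(ops)
    for op in ops:
        if op.type != "cast":
            continue
        srcs = [made_by.get(n) for n in op.inputs]
        if srcs and all(s is not None and s.device == ALL for s in srcs):
            op.device = ALL
    return ops


def example_ops():
    """A device:all op followed by an AMP-inserted cast placed on stage 0 only."""
    return [
        Op("fill_constant", [], ["mask"], ALL),
        Op("cast", ["mask"], ["mask.cast_fp16"], "gpu:0"),
        Op("matmul", ["x0", "mask.cast_fp16"], ["y0"], "gpu:0"),
        Op("matmul", ["x1", "mask.cast_fp16"], ["y1"], "gpu:1"),
    ]


print(cross_stage_vars(example_ops()))                            # ['mask.cast_fp16']
print(cross_stage_vars(mark_casts_after_all_ops(example_ops())))  # []
```

In this toy setting the cast output is the only variable that crosses stages, and it stops crossing once the cast inherits device:all from the op before it.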

Performance test

Tested with Ernie3.0 ("hidden_size": 4096, "num_attention_heads": 128, "num_hidden_layers": 76, "num_sharing_layers": 64)
on 16 V100 cards (32G each), with mp=8, pp=2, amp, recompute, gbs=32, micro_bs=2.

| develop (tokens/s) | PR (tokens/s) | improve |
| --- | --- | --- |
| 2164.5 | 2445.2 | 12.96% |
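The test setup can be sanity-checked with a little arithmetic. The sketch below assumes the usual hybrid-parallel conventions, which the PR does not state explicitly: cards = mp × pp × dp, and gbs = micro_bs × (micro-batches per step) × dp. The variable names are illustrative only.

```python
# Values taken from the test description above.
cards = 16
mp, pp = 8, 2
gbs, micro_bs = 32, 2
develop_tps, pr_tps = 2164.5, 2445.2

# Assumption: cards = mp * pp * dp, so the data-parallel degree is what remains.
dp = cards // (mp * pp)

# Assumption: gbs = micro_bs * micro_batches * dp.
micro_batches = gbs // (micro_bs * dp)

# Relative throughput gain of the PR over develop, in percent.
improvement_pct = (pr_tps - develop_tps) / develop_tps * 100

print(dp, micro_batches, round(improvement_pct, 2))
```

Under these assumptions the run uses no data parallelism (dp = 1) and 16 micro-batches per step, and the throughput gain comes out at roughly 12.97%, in line with the reported figure.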

@paddle-bot-old

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

Contributor

@gongweibao gongweibao left a comment

Please describe which changes were made. Why do they improve performance, and by how much?

@wangxicoding wangxicoding changed the title from "optimize pipeline performance with recompute and amp" to "[hybrid] optimize pipeline performance with recompute and amp" Aug 4, 2021

@sandyhouse sandyhouse left a comment

LGTM

@wangxicoding wangxicoding requested a review from zhiqiu August 5, 2021 06:18

Contributor

@JZ-LIANG JZ-LIANG left a comment

LGTM

Contributor

@zhiqiu zhiqiu left a comment

LGTM for backward.py

@wangxicoding wangxicoding merged commit 911c859 into PaddlePaddle:develop Aug 5, 2021
@wangxicoding wangxicoding deleted the optimize_pipeline_recompute_amp branch August 5, 2021 07:33