
[3D-Parallel:Sharding] Optimizations for supporting ERNIE 3.0 training#31884

Merged
wangxicoding merged 25 commits into PaddlePaddle:develop from JZ-LIANG:sharding-ERNIE160B-updates
Apr 2, 2021

Conversation

@JZ-LIANG
Contributor

@JZ-LIANG JZ-LIANG commented Mar 26, 2021

PR types

New features

PR changes

APIs

Describe

This PR is a major update to sharding and changes the sharding APIs.
It contains all the new features and performance optimizations developed for 100B-parameter ERNIE 3.0 training.

The major updates are as follows:
Performance optimizations:

  1. fixed a bug in the Recompute optimizer: the Recompute-Sharding related broadcast now uses FP16 instead of FP32
  2. sharding allreduce --> reduce: saves 1/2 of the algorithmic bandwidth needed for sharding gradient synchronization
  3. optimized the sharding initialization procedure
  4. support two sharding segment strategies (by broadcast size or by anchors)
  5. removed unnecessary syncs in the sharding logic supporting AMP and ClipByGlobalNorm

New features:

  1. Megatron-Sharding 2D parallelism
  2. Sharding Gradient Merge

Feature optimizations:

  1. added a sync in the startup program to avoid a potential hang when creating multiple NCCL comms
  2. a uniform switch to select among the different parallelism modes

The new APIs:

  • sharding_segment_strategy: choose from "segment_broadcast_MB" and "segment_anchors"
    • a segment is the unit sharding uses to overlap communication and computation.
    • segment_anchors: segment the program by user-defined anchors
    • segment_broadcast_MB: segment the program by broadcast volume (MB)
  • sharding_degree:
    • the number of ways of sharding parallelism
    • turn off sharding parallelism by setting sharding_degree = 1.
  • mp_degree:
    • the number of ways of mp (Megatron) parallelism
    • turn off mp parallelism by setting mp_degree = 1.
  • hybrid_dp:
    • the data parallelism (distinct from sharding) used as the outermost parallelism to scale up training throughput
    • when hybrid_dp = True, the user should ensure global_world_size = N * mp_degree * sharding_degree (N >= 2), where N is the data parallelism degree.
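The degree arithmetic behind hybrid_dp can be sketched in plain Python (a hypothetical helper for illustration only, not part of the Paddle API):

```python
def infer_dp_degree(world_size, mp_degree, sharding_degree):
    """Derive the outer data-parallel degree N from the global world size.

    With hybrid_dp = True, world_size must equal
    N * mp_degree * sharding_degree with N >= 2.
    """
    assert world_size % (mp_degree * sharding_degree) == 0, \
        "world_size must be divisible by mp_degree * sharding_degree"
    dp_degree = world_size // (mp_degree * sharding_degree)
    assert dp_degree >= 2, "hybrid_dp needs at least 2 data-parallel replicas"
    return dp_degree

# 4 nodes x 8 GPUs, 8-way sharding per node -> 4-way outer data parallelism
print(infer_dp_degree(32, mp_degree=1, sharding_degree=8))  # 4
```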

Examples:

  • assume we have 4 nodes with 8 GPUs on each node:
  • pure sharding across all 32 GPUs:
    dist_strategy.sharding = True
    dist_strategy.sharding_configs = {
        "sharding_segment_strategy": "segment_broadcast_MB",
        "segment_broadcast_MB": 32,
        "segment_anchors": None,
        "sharding_degree": 32,
        "mp_degree": 1,
        "hybrid_dp": False,
        "gradient_merge_acc_step": 1,
    }
  • sharding-hybrid-dp, which shards parameters within the 8 GPUs of each node and uses 4-way data parallelism to scale up training throughput, with gradient merge enabled at 4 accumulation steps:
    dist_strategy.sharding = True
    dist_strategy.sharding_configs = {
        "sharding_segment_strategy": "segment_broadcast_MB",
        "segment_broadcast_MB": 32,
        "segment_anchors": None,
        "sharding_degree": 8,
        "mp_degree": 1,
        "hybrid_dp": True,
        "gradient_merge_acc_step": 4,
    }
  • 2D Megatron-sharding, which splits parameters Megatron-style within the 8 GPUs of each node and uses 4-way sharding parallelism to further distribute the parameters into 4 shards:
    dist_strategy.sharding = True
    dist_strategy.sharding_configs = {
        "sharding_segment_strategy": "segment_broadcast_MB",
        "segment_broadcast_MB": 32,
        "segment_anchors": None,
        "sharding_degree": 4,
        "mp_degree": 8,
        "hybrid_dp": False,
        "gradient_merge_acc_step": 1,
    }
  • Megatron-sharding hybrid-dp mode: 2-way Megatron and 4-way sharding parallelism within each node, with the 4 nodes duplicated for data parallelism:
    dist_strategy.sharding = True
    dist_strategy.sharding_configs = {
        "sharding_segment_strategy": "segment_broadcast_MB",
        "segment_broadcast_MB": 32,
        "segment_anchors": None,
        "sharding_degree": 4,
        "mp_degree": 2,
        "hybrid_dp": True,
        "gradient_merge_acc_step": 1,
    }
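The "segment_broadcast_MB" strategy used in all of the configs above can be illustrated with a simplified greedy pass (a sketch under assumed semantics; the real segmentation lives inside Paddle's sharding optimizer):

```python
def segment_by_broadcast_mb(param_sizes_mb, broadcast_mb=32.0):
    """Greedily cut a list of parameter sizes (MB) into segments whose
    broadcast volume reaches roughly broadcast_mb each; sharding can then
    overlap one segment's broadcast with another segment's computation."""
    segments, current, volume = [], [], 0.0
    for size in param_sizes_mb:
        current.append(size)
        volume += size
        if volume >= broadcast_mb:
            segments.append(current)
            current, volume = [], 0.0
    if current:  # tail segment below the threshold
        segments.append(current)
    return segments

print(segment_by_broadcast_mb([10, 20, 5, 40, 3]))
# [[10, 20, 5], [40], [3]]
```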

@paddle-bot-old

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@JZ-LIANG JZ-LIANG changed the title [Sharding] ALL Optimizations for supporting 16B ERNIE 3.0 training [Sharding] Optimizations for supporting ERNIE 3.0 training Mar 26, 2021
@JZ-LIANG JZ-LIANG force-pushed the sharding-ERNIE160B-updates branch from bc35a69 to 0abe6e9 Compare March 29, 2021 03:03
@JZ-LIANG JZ-LIANG force-pushed the sharding-ERNIE160B-updates branch from ac47cea to 7659235 Compare March 29, 2021 07:31
@JZ-LIANG JZ-LIANG force-pushed the sharding-ERNIE160B-updates branch from 7659235 to 726525c Compare March 29, 2021 07:33
@JZ-LIANG JZ-LIANG force-pushed the sharding-ERNIE160B-updates branch from 190a067 to cf5b1c9 Compare March 30, 2021 03:05

message ShardingConfig {
optional float fuse_broadcast_MB = 1 [ default = 32.0 ];
optional float segment_broadcast_MB = 1 [ default = 32.0 ];
Contributor

Add simple comments?

Contributor Author

Those args decide how sharding segments the program, which affects the comm-calc overlap logic in sharding.
For now, we support two segment strategies:
"segment_broadcast_MB": segment by broadcast volume
"segment_anchors": segment by user-defined anchors (an op's outputs)
I will add a detailed explanation in FluidDoc.

Contributor

Also, next PR can add comments or link in this code

Contributor Author

recorded.

# computation by split the check_finite_and_unscale op.
is_distributed = self.role_maker._worker_num() > 1
if self.user_defined_strategy.sharding:
# if self.user_defined_strategy.sharding or self.user_defined_strategy.model_parallel:
Contributor

I don't get it!

Contributor Author

Sharding, as well as sharding-Megatron, does not support the pure_fp16 allreduce logic introduced in the amp & pure_fp16 meta-optimizers, so we need to set "is_distributed" = False while sharding or Megatron is enabled.
This PR adds logic for sharding-Megatron support, but the Megatron meta-optimizer will be added in another PR, so I will remove this line for now.


comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3,
ring_id, [inf_var_fp32])
# comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3,
Contributor

Clean up ?

Contributor Author

done

'elementwise_max', 'elementwise_div', 'elementwise_mul',
'elementwise_mul', 'elementwise_mul', 'momentum', 'momentum',
'momentum'
'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
Contributor

👍🏻

# we should create the rename var in subprog, otherwise its VarType will be BOOL
block.create_var(
name=var_name_dict[name],
shape=block.program.global_block().var(name).shape,
Contributor

block.program.global_block().var(name) is called four times.

Contributor Author

revised ~

return True
return False

def is_gradient_merge_vars(var):
Contributor

Have we recorded these kinds of hard-coded rules anywhere?

Contributor Author

Yes.
This naive rule should be updated later.
The problem is that grad@GradientMerge should be persistable in the global scope but not saved; we need to design a method to distinguish persistable-non-savable vars from persistable-savable vars.

if var_name in vars_status:
vars_status[var_name] = 2
elif op.type == "c_allreduce_sum" or op.type == "c_reduce_sum":
if op.all_attrs()["use_calc_stream"] == False:
Contributor

Add some simple comments here?

Contributor Author
@JZ-LIANG JZ-LIANG Mar 31, 2021

done~
We should ensure all sharding-related grad communication (reduce / allreduce) is synced before the grads are used in the optimizers.

But we should ignore and skip the allreduce ops of Megatron, since they are scheduled on the calc stream and cannot have a non-sync problem before their next usage.
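The invariant described here can be sketched as a toy checker over a simplified op list (a hypothetical op representation, not Paddle's real IR; a single sync is assumed to cover all rings):

```python
def grad_comm_synced(ops):
    """Check that every gradient produced by a sharding reduce/allreduce
    on a separate comm stream is synced (c_sync_comm_stream) before any
    later op consumes it. Allreduce ops running on the calc stream
    (Megatron's case) are skipped, as they cannot race with later usage."""
    pending = set()  # grads reduced but not yet synced
    for op in ops:
        if op["type"] in ("c_reduce_sum", "c_allreduce_sum"):
            if not op.get("use_calc_stream", False):
                pending.update(op.get("outputs", []))
            continue
        if op["type"] == "c_sync_comm_stream":
            pending.clear()  # simplification: one sync covers all rings
            continue
        if pending & set(op.get("inputs", [])):
            return False  # grad consumed before its comm finished
    return True

ops = [
    {"type": "c_reduce_sum", "outputs": ["w@GRAD"], "use_calc_stream": False},
    {"type": "c_sync_comm_stream"},
    {"type": "momentum", "inputs": ["w@GRAD"]},
]
print(grad_comm_synced(ops))  # True
```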

ring_id = op.desc.attr("ring_id")
var_name = op.desc.input_arg_names()[0]
param = var_name.split("@")[0]
if op.type == "c_allreduce_sum" or op.type == "c_reduce_sum":
Contributor

Add some simple comments here?

Contributor Author
@JZ-LIANG JZ-LIANG Mar 31, 2021

done~
This problem was introduced because we want to overlap the sharding grad communication with the backward computation.
Sharding uses both allreduce and reduce to sync grads; we should ensure all sharding-related grad communication (reduce / allreduce) is synced before the grads are used in the optimizers.

@JZ-LIANG JZ-LIANG force-pushed the sharding-ERNIE160B-updates branch from ed5e936 to dde7d24 Compare March 31, 2021 06:50
gongweibao previously approved these changes Mar 31, 2021
Contributor
@gongweibao gongweibao left a comment

LGTM

# dp here is the pure dp as the outest parallelism
self.dp_degree = int(self.role_maker._worker_num() // self.mp_degree //
self.sharding_degree)
assert self.role_maker._worker_num(
Contributor

Add an explanation of why the assert can fail?

self._startup_program, self.current_endpoint,
self.mp_group_endpoints, self.mp_rank, self.mp_ring_id, False)
append_naive_sync(startup_block, self.startup_prog_sync_var,
self.global_ring_id)
Contributor

Do we need to process the else condition?

check_broadcast(main_block)
check_allreduce_sum(main_block, self._shard, self.dp_ring_id)
# # check op dependecy
# check_broadcast(main_block)
Contributor

# FIXME

@JZ-LIANG JZ-LIANG force-pushed the sharding-ERNIE160B-updates branch from c0aa6b6 to 60be6ec Compare March 31, 2021 08:09
@JZ-LIANG JZ-LIANG requested a review from gongweibao March 31, 2021 08:26
@JZ-LIANG JZ-LIANG changed the title [Sharding] Optimizations for supporting ERNIE 3.0 training [3D-Parallel:Sharding] Optimizations for supporting ERNIE 3.0 training Mar 31, 2021

message ShardingConfig {
optional float fuse_broadcast_MB = 1 [ default = 32.0 ];
optional float segment_broadcast_MB = 1 [ default = 32.0 ];
Contributor

Also, next PR can add comments or link in this code

optional float segment_broadcast_MB = 1 [ default = 32.0 ];
optional bool hybrid_dp = 2 [ default = false ];
optional int32 sharding_group_size = 3 [ default = 8 ];
optional int32 sharding_degree = 3 [ default = 8 ];
Contributor

Why set default = 8? Maybe set it to -1 or 0, meaning the value is taken from world_size.

Contributor Author

Recorded. Will update in the next PR.

optional int32 sharding_degree = 3 [ default = 8 ];
optional int32 mp_degree = 4 [ default = 1 ];
optional string sharding_segment_strategy = 5
[ default = 'segment_broadcast_MB' ];
Contributor

Add all enum values in the comments.
Also, can the strategy names be simplified? One could be called by_broadcast (or by_param_size, or another better name) and the other by_anchors (or another better name).

Contributor Author

Great suggestion! Recorded.

set([param for param, worker_idx in shard.global_param2device.items() \
if worker_idx == shard.worker_idx]))
assert to_check_param == should_check_param, "amp \
check_finite_and_unscale checking miss [{}] and got unexpected [{}]".format(
Contributor

check_finite_and_unscale --> {op.type}

Contributor Author

Recorded. Will update in the next PR.

def __init__(self, sharding_ring_id):
self.sharding_ring_id = sharding_ring_id
def __init__(self, mp_ring_id):
self.mp_ring_id = mp_ring_id
Contributor

Needs a comment that mp_ring is (sharding + mp).

Contributor Author

recorded

if pure_dp_degree > 1:
block._insert_op_without_sync(
idx + 2,
type='scale',
Contributor

Maybe it's better to scale before c_allreduce_sum, since c_allreduce_sum may produce inf.

Contributor Author

Since grads tend to be small decimal values, I think scaling after allreduce is better to avoid arithmetic underflow.

outputs={'Out': sync_var},
attrs={
'ring_id': ring_id,
'use_calc_stream': True,
Contributor

Maybe need sync_calc_stream

Contributor Author

See the next comment.

False,
global_ring_id=self.global_ring_id,
sync=False)
append_naive_sync(startup_block, self.startup_prog_sync_var,
Contributor

Can move into _init_communicator

Contributor Author

Yes, in the next PR there will be an update that moves all comm-init-related syncs into the _init_communicator function. Recorded.

type='conditional_block',
inputs={
'Cond': cond,
'Input': [],
Contributor

666

Contributor Author

In the gradient merge scenario, there is no need to bring the temp vars in the optimizer block scope back to the global block scope, since those temp vars are only used in the optimize procedure.
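The gradient-merge behavior discussed here can be sketched with a toy accumulator (illustrative only; the buffer plays the role of grad@GradientMerge from the discussion above, and the class itself is hypothetical, not a Paddle API):

```python
class GradientMerge:
    """Toy model of sharding gradient merge: micro-batch gradients are
    accumulated into a persistent buffer, and the optimizer runs only
    once every acc_steps steps, inside a conditional block, after which
    the buffer is cleared."""

    def __init__(self, acc_steps):
        self.acc_steps = acc_steps
        self.step = 0
        self.acc_grad = 0.0  # persistable across micro-batches, not saved

    def backward(self, micro_batch_grad):
        """Accumulate one micro-batch gradient; return the cond flag."""
        self.acc_grad += micro_batch_grad
        self.step += 1
        return self.step % self.acc_steps == 0

    def optimize(self, param, lr):
        """Body of the conditional block: apply the averaged gradient."""
        param -= lr * (self.acc_grad / self.acc_steps)
        self.acc_grad = 0.0  # temp state never leaves the optimize scope
        return param

gm = GradientMerge(acc_steps=4)
flags = [gm.backward(g) for g in [1.0, 2.0, 3.0, 4.0]]
print(flags)                      # [False, False, False, True]
print(gm.optimize(10.0, lr=1.0))  # 7.5
```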


class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
def setUp(self):
os.environ["PADDLE_TRAINER_ID"] = "3"
Contributor

666

Contributor Author

In the mp-sharding or sharding-hybrid-dp settings, we need at least 4 workers to set up the parallelism logic.

Contributor
@wangxicoding wangxicoding left a comment

LGTM

@wangxicoding wangxicoding merged commit 69c874f into PaddlePaddle:develop Apr 2, 2021
@wangxicoding
Contributor

[image]
Need fix.
