[3D-Parallel:Sharding] Optimizations for supporting ERNIE 3.0 training#31884
Conversation
Thanks for your contribution!
message ShardingConfig {
  optional float fuse_broadcast_MB = 1 [ default = 32.0 ];
  optional float segment_broadcast_MB = 1 [ default = 32.0 ];
These args decide how sharding segments the program, which affects the comm-calc overlap logic in sharding.
For now, we support two segment strategies:
"segment_broadcast_MB": segment by broadcast volume
"segment_anchors": segment by user-defined anchors (an op's output)
I will add a detailed explanation in FluidDoc.
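To make the "segment_broadcast_MB" idea concrete, here is a minimal sketch (not the actual pass): walk the ops in order and close a segment whenever the accumulated broadcast volume reaches the threshold, so each segment's communication can overlap with the next segment's computation. The `(name, broadcast_mb)` op representation and the function name are invented for illustration.

```python
# Hypothetical illustration of the "segment_broadcast_MB" strategy: a new
# segment boundary is placed once the accumulated volume of parameters to
# broadcast reaches the configured threshold.

def segment_by_broadcast_mb(ops, segment_broadcast_mb=32.0):
    """ops: list of (op_name, broadcast_mb). Returns a list of segments."""
    segments, current, acc_mb = [], [], 0.0
    for name, broadcast_mb in ops:
        current.append(name)
        acc_mb += broadcast_mb
        if acc_mb >= segment_broadcast_mb:
            segments.append(current)   # close this segment; its broadcasts
            current, acc_mb = [], 0.0  # can overlap with the next segment
    if current:
        segments.append(current)
    return segments

ops = [("fc_0", 20.0), ("fc_1", 20.0), ("fc_2", 5.0), ("fc_3", 40.0)]
print(segment_by_broadcast_mb(ops))
# fc_0 + fc_1 reach 40 MB >= 32 MB, so a boundary lands after fc_1
```

The "segment_anchors" strategy would instead place boundaries at the user-named op outputs rather than by volume.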
Also, the next PR can add comments or a link in this code.
# computation by split the check_finite_and_unscale op.
is_distributed = self.role_maker._worker_num() > 1
if self.user_defined_strategy.sharding:
# if self.user_defined_strategy.sharding or self.user_defined_strategy.model_parallel:
Sharding as well as sharding-megatron does not support the pure_fp16 allreduce logic introduced in the amp & pure_fp16 meta-optimizers, so we need to set is_distributed = False while sharding or megatron is enabled.
This PR adds logic for sharding-megatron support, but the megatron meta-optimizer will be added in another PR, so I will remove this line for now.
comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3,
                                  ring_id, [inf_var_fp32])
# comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3,
'elementwise_max', 'elementwise_div', 'elementwise_mul',
'elementwise_mul', 'elementwise_mul', 'momentum', 'momentum',
'momentum'
'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
python/paddle/fluid/backward.py
# we should create the rename var in subprog, otherwise its VarType will be BOOL
block.create_var(
    name=var_name_dict[name],
    shape=block.program.global_block().var(name).shape,
block.program.global_block().var(name) is called four times; consider caching it in a local variable.
return True
return False

def is_gradient_merge_vars(var):
Have we recorded these kinds of hard-coded rules anywhere?
Yes.
This naive rule should be updated later.
The problem is that grad@GradientMerge vars should be persistable in the global scope but should not be saved. We need to design a method to distinguish persistable-non-savable vars from persistable-savable vars.
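The naive name-based rule being discussed might look like the following sketch. The "@GradientMerge" suffix is assumed from the var name in the discussion above; the helper names are hypothetical, not Paddle's actual API.

```python
# A minimal sketch of the name-based rule: gradient-merge accumulators are
# persistable (they must survive across micro-steps in the global scope)
# but should be excluded when saving a model.

GRADIENT_MERGE_SUFFIX = "@GradientMerge"  # assumed naming convention

def is_gradient_merge_var(var_name):
    return GRADIENT_MERGE_SUFFIX in var_name

def savable_persistables(persistable_var_names):
    """Filter out persistable-but-non-savable accumulator vars."""
    return [n for n in persistable_var_names if not is_gradient_merge_var(n)]

names = ["fc_0.w_0", "fc_0.w_0@GradientMerge", "learning_rate_0"]
print(savable_persistables(names))  # ['fc_0.w_0', 'learning_rate_0']
```

A more robust design would mark such vars with an explicit attribute instead of matching on the name.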
if var_name in vars_status:
    vars_status[var_name] = 2
elif op.type == "c_allreduce_sum" or op.type == "c_reduce_sum":
    if op.all_attrs()["use_calc_stream"] == False:
Add some simple comments here?
done~
We should ensure that all sharding-related grad communication (reduce / allreduce) is synced before the grads are used in the optimizer.
But we should ignore and skip the allreduce ops of Megatron, since they are scheduled on the calc stream and will not have a non-sync problem before the next use.
ring_id = op.desc.attr("ring_id")
var_name = op.desc.input_arg_names()[0]
param = var_name.split("@")[0]
if op.type == "c_allreduce_sum" or op.type == "c_reduce_sum":
Add some simple comments here?
done~
This problem was introduced because we want to overlap the sharding grad communication with the backward computation.
Sharding uses both allreduce and reduce to sync grads; we should ensure that all sharding-related grad communication (reduce / allreduce) is synced before the grads are used in the optimizer.
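The invariant being checked can be illustrated with a toy op-stream walker (hypothetical, not the actual `check_allreduce_sum` pass): any grad reduced on the comm stream (`use_calc_stream=False`) must be followed by a `c_sync_comm_stream` before an optimizer op consumes it, while calc-stream allreduces (e.g. Megatron's) need no such sync.

```python
# Toy checker over a linear op stream. Ops are modeled as
# (op_type, var_name, use_calc_stream) tuples purely for illustration.

def check_grad_comm_synced(ops):
    pending = set()  # grads with in-flight comm on the comm stream
    for op_type, var, use_calc_stream in ops:
        if op_type in ("c_allreduce_sum", "c_reduce_sum"):
            if not use_calc_stream:
                pending.add(var)  # comm-stream reduce: must sync before use
            # calc-stream allreduce (Megatron style): nothing to track
        elif op_type == "c_sync_comm_stream":
            pending.clear()       # all outstanding comm is now complete
        elif op_type == "momentum" and var in pending:
            return False          # grad consumed before its comm was synced
    return True

good = [("c_reduce_sum", "x@GRAD", False),
        ("c_sync_comm_stream", None, True),
        ("momentum", "x@GRAD", True)]
bad = [("c_reduce_sum", "x@GRAD", False),
       ("momentum", "x@GRAD", True)]
print(check_grad_comm_synced(good), check_grad_comm_synced(bad))  # True False
```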
# dp here is the pure dp as the outest parallelism
self.dp_degree = int(self.role_maker._worker_num() // self.mp_degree //
                     self.sharding_degree)
assert self.role_maker._worker_num(
Could you give an explanation of why the assert fails?
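Presumably the assert checks that the world size factors exactly into the three parallelism degrees. A hedged sketch of that bookkeeping, with a message that explains the failure (names mirror the snippet above; the exact message is an illustration):

```python
# dp is the pure data parallelism as the outermost level, so the world size
# must equal mp_degree * sharding_degree * dp_degree exactly.

def compute_dp_degree(worker_num, mp_degree, sharding_degree):
    dp_degree = worker_num // (mp_degree * sharding_degree)
    assert worker_num == mp_degree * sharding_degree * dp_degree, (
        "worker_num [{}] must be divisible by mp_degree [{}] * "
        "sharding_degree [{}]; got leftover workers that belong to no "
        "dp group".format(worker_num, mp_degree, sharding_degree))
    return dp_degree

print(compute_dp_degree(16, 2, 4))  # 2
```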
self._startup_program, self.current_endpoint,
self.mp_group_endpoints, self.mp_rank, self.mp_ring_id, False)
append_naive_sync(startup_block, self.startup_prog_sync_var,
                  self.global_ring_id)
Do we need to handle the else condition?
check_broadcast(main_block)
check_allreduce_sum(main_block, self._shard, self.dp_ring_id)
# # check op dependecy
# check_broadcast(main_block)
optional float segment_broadcast_MB = 1 [ default = 32.0 ];
optional bool hybrid_dp = 2 [ default = false ];
optional int32 sharding_group_size = 3 [ default = 8 ];
optional int32 sharding_degree = 3 [ default = 8 ];
Why set default = 8? Maybe it can be set to -1 or 0, meaning the value is taken from world_size.
Recorded; will update in the next PR.
optional int32 sharding_degree = 3 [ default = 8 ];
optional int32 mp_degree = 4 [ default = 1 ];
optional string sharding_segment_strategy = 5
    [ default = 'segment_broadcast_MB' ];
Add all enum values in comments.
Can the strategy names be simplified? One could be called by_broadcast (or by_param_size, or another better name) and the other by_anchors (or another better name).
Great suggestion! Recorded.
set([param for param, worker_idx in shard.global_param2device.items() \
    if worker_idx == shard.worker_idx]))
assert to_check_param == should_check_param, "amp \
check_finite_and_unscale checking miss [{}] and got unexpected [{}]".format(
check_finite_and_unscale --> {op.type}
Recorded; will update in the next PR.
def __init__(self, sharding_ring_id):
    self.sharding_ring_id = sharding_ring_id
def __init__(self, mp_ring_id):
    self.mp_ring_id = mp_ring_id
Need a comment that the mp ring here is (sharding + mp).
if pure_dp_degree > 1:
    block._insert_op_without_sync(
        idx + 2,
        type='scale',
Maybe it is better to scale before c_allreduce_sum, since c_allreduce_sum may overflow to inf.
Since grads tend to be small decimal values, I think scaling after the allreduce is better, to avoid arithmetic underflow.
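The underflow argument can be demonstrated with a contrived NumPy experiment (the values are illustrative, not from the PR): with fp16 grads near the subnormal limit, dividing by dp_degree on each worker before the allreduce can flush every shard's grad to zero, while summing first keeps the result representable.

```python
import numpy as np

dp_degree = 8
grad = np.float16(1e-7)  # stored as ~1.19e-07, near the fp16 subnormal limit

# scale before allreduce: each worker's grad / 8 underflows to 0 in fp16
scale_first = np.float16(0)
for _ in range(dp_degree):
    scale_first = np.float16(scale_first + grad / np.float16(dp_degree))

# allreduce then scale: the sum stays representable, then divide once
summed = np.float16(0)
for _ in range(dp_degree):
    summed = np.float16(summed + grad)
scale_last = summed / np.float16(dp_degree)

print(float(scale_first), float(scale_last))  # 0.0 vs ~1.19e-07
```

The symmetric concern (overflow to inf when summing first, as raised above) matters for large grads; which order is safer depends on the expected grad magnitude.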
outputs={'Out': sync_var},
attrs={
    'ring_id': ring_id,
    'use_calc_stream': True,
Maybe need sync_calc_stream
See the next comment.
False,
global_ring_id=self.global_ring_id,
sync=False)
append_naive_sync(startup_block, self.startup_prog_sync_var,
Can move into _init_communicator
Yes, there will be an update that moves all comm-init-related syncs into the _init_communicator function in the next PR. Recorded.
type='conditional_block',
inputs={
    'Cond': cond,
    'Input': [],
In the gradient merge scenario, there is no need to bring the temp vars in the optimizer block scope back to the global block scope, since those temp vars will only be used in the optimize procedure.
class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
    def setUp(self):
        os.environ["PADDLE_TRAINER_ID"] = "3"
In the mp-sharding or sharding-hybrid-dp setting, we need at least 4 workers to set up the parallelism logic.

PR types
New features
PR changes
APIs
Describe
This PR is a major update to sharding and changes the sharding APIs.
It contains all the new features and performance optimizations developed for 100B ERNIE 3.0 training.
The major updates are the following:
Performance optimizations:
New features:
Feature optimizations:
The new API:
example:
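The PR's original example is not preserved in this page. As a stand-in, here is a hedged sketch of what configuring the new sharding options might look like, assembled from the ShardingConfig fields shown in the diffs above; the surrounding fleet calls follow the usual fleet meta-optimizer usage and may differ from the PR's final API.

```python
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    # segment strategy and threshold (field names from the proto above)
    "sharding_segment_strategy": "segment_broadcast_MB",
    "segment_broadcast_MB": 32,
    "sharding_degree": 8,   # size of each sharding group
    "mp_degree": 1,         # megatron-style model parallelism degree
    "hybrid_dp": True,      # outer pure data parallelism on top of sharding
}
```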