-
Notifications
You must be signed in to change notification settings - Fork 6k
[3D-Parallel:Sharding] Optimizations for supporting ERNIE 3.0 training #31884
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
14b07ba
e6b489d
c8b0f92
5444c3b
527bc96
7e35d31
2cc1b7c
f273420
6a18b38
98baf20
9ece14f
add91b7
e01e22a
ffb492b
0abe6e9
726525c
cb788cf
cf5b1c9
ceae74b
dde7d24
60be6ec
b414ddf
620f138
609859a
0732442
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,9 +29,14 @@ message RecomputeConfig { | |
| } | ||
|
|
||
| message ShardingConfig { | ||
| optional float fuse_broadcast_MB = 1 [ default = 32.0 ]; | ||
| optional float segment_broadcast_MB = 1 [ default = 32.0 ]; | ||
| optional bool hybrid_dp = 2 [ default = false ]; | ||
| optional int32 sharding_group_size = 3 [ default = 8 ]; | ||
| optional int32 sharding_degree = 3 [ default = 8 ]; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why set default = 8? Maybe it could be set to -1 or 0, which would mean taking the value from world_size.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. recorded. update in next pr |
||
| optional int32 mp_degree = 4 [ default = 1 ]; | ||
| optional string sharding_segment_strategy = 5 | ||
| [ default = 'segment_broadcast_MB' ]; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Add all enum values in comments.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. great suggestion! recorded |
||
| repeated string segment_anchors = 6; | ||
| optional int32 gradient_merge_acc_step = 7 [ default = 1 ]; | ||
| } | ||
|
|
||
| message AMPConfig { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -73,7 +73,7 @@ def remove_cast_op(block, params, segment, offset): | |
| @staticmethod | ||
| def prune_fp16(block, shard, reduced_grads_to_param, ring_id): | ||
| """ | ||
| 1. prune all cast_fp32_to_fp16 ops if the param not belongs to this shard | ||
| 1. prune all cast_fp16_to_fp32 ops if the param not belongs to this shard | ||
| 2. revise amp inifine grad checking for sharding | ||
| """ | ||
| # remove cast | ||
|
|
@@ -103,6 +103,7 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_id): | |
| op._rename_input(inf_var_name, inf_var_name + "@sharding") | ||
| if op.type in ["check_finite_and_unscale", "update_loss_scaling"]: | ||
| reversed_x = [] | ||
| reversed_x_paramname = [] | ||
| for input_name in op.desc.input('X'): | ||
| param_name = input_name.strip("@GRAD") | ||
| if param_name not in shard.global_params: | ||
|
|
@@ -111,12 +112,24 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_id): | |
| "be grads, but {} is not a grad".format(input_name)) | ||
| if shard.has_param(param_name): | ||
| reversed_x.append(input_name) | ||
| reversed_x_paramname.append(param_name) | ||
| op.desc.set_input('X', reversed_x) | ||
| op.desc.set_output('Out', reversed_x) | ||
|
|
||
| # the grad checking should take the all and only param in the current shard | ||
| to_check_param = set(reversed_x_paramname) | ||
| should_check_param = set(shard.global_params).intersection( | ||
| set([param for param, worker_idx in shard.global_param2device.items() \ | ||
| if worker_idx == shard.worker_idx])) | ||
| assert to_check_param == should_check_param, "amp \ | ||
| check_finite_and_unscale checking miss [{}] and got unexpected [{}]".format( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. check_finite_and_unscale --> {op.type}
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Recorded; will be updated in the next PR. |
||
| should_check_param - to_check_param, | ||
| to_check_param - should_check_param) | ||
|
|
||
| if update_loss_scaling_op_idx == -1: | ||
| return | ||
| inf_var = block.var(inf_var_name) | ||
| inf_var_fp32 = block.create_var( | ||
| inf_var_int32 = block.create_var( | ||
| name=inf_var_name + "@cast_int32", | ||
| shape=inf_var.shape, | ||
| dtype=core.VarDesc.VarType.INT32) | ||
|
|
@@ -128,32 +141,30 @@ def prune_fp16(block, shard, reduced_grads_to_param, ring_id): | |
| update_loss_scaling_op_idx, | ||
| type='cast', | ||
| inputs={'X': inf_var}, | ||
| outputs={'Out': inf_var_fp32}, | ||
| outputs={'Out': inf_var_int32}, | ||
| attrs={ | ||
| "in_dtype": inf_var.dtype, | ||
| "out_dtype": inf_var_fp32.dtype, | ||
| "out_dtype": inf_var_int32.dtype, | ||
| OP_ROLE_KEY: OpRole.Optimize | ||
| }) | ||
| insert_sync_calc_op(block, update_loss_scaling_op_idx + 1, | ||
| [inf_var_fp32]) | ||
| # this allreduce communication should not overlap with calc | ||
| block._insert_op_without_sync( | ||
| update_loss_scaling_op_idx + 2, | ||
| update_loss_scaling_op_idx + 1, | ||
| type='c_allreduce_max', | ||
| inputs={'X': inf_var_fp32}, | ||
| outputs={'Out': inf_var_fp32}, | ||
| attrs={'ring_id': ring_id, | ||
| OP_ROLE_KEY: OpRole.Optimize}) | ||
|
|
||
| comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3, | ||
| ring_id, [inf_var_fp32]) | ||
|
|
||
| inputs={'X': inf_var_int32}, | ||
| outputs={'Out': inf_var_int32}, | ||
| attrs={ | ||
| 'ring_id': ring_id, | ||
| 'use_calc_stream': True, | ||
| OP_ROLE_KEY: OpRole.Optimize | ||
| }) | ||
| block._insert_op_without_sync( | ||
| update_loss_scaling_op_idx + 3 + comm_op_num, | ||
| update_loss_scaling_op_idx + 2, | ||
| type='cast', | ||
| inputs={'X': inf_var_fp32}, | ||
| inputs={'X': inf_var_int32}, | ||
| outputs={'Out': inf_var_sharding}, | ||
| attrs={ | ||
| "in_dtype": inf_var_fp32.dtype, | ||
| "in_dtype": inf_var_int32.dtype, | ||
| "out_dtype": inf_var_sharding.dtype, | ||
| OP_ROLE_KEY: OpRole.Optimize | ||
| }) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,21 +16,22 @@ | |
|
|
||
|
|
||
| class GradientClipHelper(object): | ||
| def __init__(self, sharding_ring_id): | ||
| self.sharding_ring_id = sharding_ring_id | ||
| def __init__(self, mp_ring_id): | ||
| self.mp_ring_id = mp_ring_id | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Needs a comment: mp_ring here is the combined (sharding + mp) group.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. recorded |
||
|
|
||
| def _is_gradient_clip_op(self, op): | ||
| return op.desc.has_attr("op_namescope") \ | ||
| and op.desc.attr("op_namescope").startswith("/gradient_clip") | ||
|
|
||
| def prune_gradient_clip(self, block, shard): | ||
| def prune_gradient_clip(self, block, shard, pure_dp_degree=1): | ||
| """ | ||
| prune gradient_clip related ops for params that not belong to cur shard | ||
| prune: square, reduce_sum, elementwise_mul | ||
| keep: sum, sqrt, elementwise_max, elementwise_div | ||
| """ | ||
| deperated_vars = set() | ||
| deperate_op_idx = set() | ||
| reversed_x_paramname = [] | ||
| for idx, op in enumerate(block.ops): | ||
| if not self._is_gradient_clip_op(op): | ||
| continue | ||
|
|
@@ -44,6 +45,8 @@ def prune_gradient_clip(self, block, shard): | |
| if shard.is_param(param_name) and \ | ||
| not shard.has_param(param_name): | ||
| deperate_op = True | ||
| elif shard.is_param(param_name): | ||
| reversed_x_paramname.append(param_name) | ||
|
|
||
| if deperate_op: | ||
| deperate_op_idx.add(idx) | ||
|
|
@@ -65,31 +68,48 @@ def prune_gradient_clip(self, block, shard): | |
| for input_name in op.desc.input_arg_names(): | ||
| if input_name not in deperated_vars: | ||
| reversed_inputs.append(input_name) | ||
|
|
||
| op.desc.set_input("X", reversed_inputs) | ||
| assert (len(op.desc.output_arg_names()) == 1) | ||
| sum_res = op.desc.output_arg_names()[0] | ||
| block._insert_op_without_sync( | ||
| idx + 1, | ||
| type='c_sync_comm_stream', | ||
| inputs={'X': sum_res}, | ||
| outputs={'Out': sum_res}, | ||
| attrs={'ring_id': 0, | ||
| OP_ROLE_KEY: OpRole.Optimize}) | ||
|
|
||
| # this allreduce should not overlap with calc and should be scheduled in calc stream | ||
| block._insert_op_without_sync( | ||
| idx + 1, | ||
| type='c_allreduce_sum', | ||
| inputs={'X': sum_res}, | ||
| outputs={'Out': sum_res}, | ||
| attrs={ | ||
| 'ring_id': self.sharding_ring_id, | ||
| OP_ROLE_KEY: OpRole.Optimize | ||
| 'ring_id': self.mp_ring_id, | ||
| 'op_namescope': "/gradient_clip_model_parallelism", | ||
| 'use_calc_stream': True, | ||
| OP_ROLE_KEY: OpRole.Optimize, | ||
| }) | ||
| block._insert_op_without_sync( | ||
| idx + 1, | ||
| type='c_sync_calc_stream', | ||
| inputs={'X': sum_res}, | ||
| outputs={'Out': sum_res}, | ||
| attrs={OP_ROLE_KEY: OpRole.Optimize}) | ||
|
|
||
| # global norm should only be sum within each model parallelism word size when use global group | ||
| if pure_dp_degree > 1: | ||
| block._insert_op_without_sync( | ||
| idx + 2, | ||
| type='scale', | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe it would be better to scale before the allreduce?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since grads tend to be small decimal values, I think scaling after the allreduce would be better, to avoid arithmetic underflow. |
||
| inputs={'X': sum_res}, | ||
| outputs={'Out': sum_res}, | ||
| attrs={ | ||
| 'scale': 1.0 / float(pure_dp_degree), | ||
| 'op_namescope': "/gradient_clip_model_parallelism", | ||
| 'bias': 0.0, | ||
| 'bias_after_scale': False, | ||
| OP_ROLE_KEY: OpRole.Optimize | ||
| }) | ||
|
|
||
| # the grad sum here should take the all and only param in the current shard | ||
| to_check_param = set(reversed_x_paramname) | ||
| should_check_param = set(shard.global_params).intersection(set( | ||
| [param for param, worker_idx in shard.global_param2device.items() \ | ||
| if worker_idx == shard.worker_idx])) | ||
| assert to_check_param == should_check_param, "amp check_finite_and_unscale \ | ||
| checking miss [{}] and got unexpected [{}]".format( | ||
| should_check_param - to_check_param, | ||
| to_check_param - should_check_param) | ||
|
|
||
| for var_name in deperated_vars: | ||
| block._remove_var(var_name, sync=False) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add simple comments?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Those args decide how sharding segments the program, which will affect the comm-calc overlap logic in sharding.
For now, we support two segment strategies:
"segment_broadcast_MB": segment by broadcast volume
"segment_anchors": segment by user-defined anchors (an op's output)
I will add detail explanation in fluidDoc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, next PR can add comments or link in this code
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
recorded.