
Commit a4eadd1

[hybrid] Fix mp multi gradient clip prob (#35713)
1 parent: 4b68388

File tree: 4 files changed, +127 -23 lines changed

python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py

Lines changed: 90 additions & 19 deletions
@@ -142,32 +142,103 @@ def prune_gradient_clip(self, block, shard, ring_ids):
         return
 
     # TODO (JZ-LIANG) revise this for uniform mixed parallelism
-    def sync_global_norm(self, block, ring_ids):
+    def sync_global_norm(self, block, ring_ids, mp_rank):
         """
         prune gradient_clip related ops for params that not belong to cur shard
         prune: square, reduce_sum, elementwise_mul
         keep: sum, sqrt, elementwise_max, elementwise_div
         """
-        # FIXME(wangxi): mp should prune duplicated param_grads
+        is_clip_grad_by_global_norm = False
+        for idx, op in list(enumerate(block.ops)):
+            if not self._is_gradient_clip_op(op):
+                continue
+            if op.type == 'sum':
+                is_clip_grad_by_global_norm = True
+                break
+        if not is_clip_grad_by_global_norm:
+            # TODO(Yuang Liu): need some extra handles when clip_grad_norm for mp
+            return
+
+        removed_op_idx = set()
+        removed_tmp_var = set()
+        for idx, op in list(enumerate(block.ops)):
+            if not self._is_gradient_clip_op(op):
+                continue
+            if op.type == 'sum':
+                break
+            for input_name in op.input_arg_names:
+                input_var = block.var(input_name)
+                # NOTE: when mp_degree > 1, some vars will be split into each mp rank.
+                # However, there still some vars such as Scale, Bias are not split.
+                # Those not be split vars should only be counted once during grad clip
+                # by global norm. Those vars either doesn't have is_distributed attr
+                # or the is_distributed attr has been set as False.
+                # Therefore, we prune those duplicated vars for grad clip.
+                if mp_rank >= 1 and (not (hasattr(input_var, 'is_distributed')
+                                          and input_var.is_distributed)):
+                    removed_op_idx.add(idx)
+                    for output_name in op.output_arg_names:
+                        removed_tmp_var.add(output_name)
+
         for idx, op in reversed(list(enumerate(block.ops))):
             if not self._is_gradient_clip_op(op):
                 continue
+            if idx in removed_op_idx:
+                block._remove_op(idx, sync=False)
 
-            if op.type == "sum":
-                sum_res = op.desc.output_arg_names()[0]
-                for ring_id in ring_ids:
-                    if ring_id == -1: continue
+        for var_name in removed_tmp_var:
+            block._remove_var(var_name, sync=False)
 
-                    idx = idx + 1
-                    block._insert_op_without_sync(
-                        idx,
-                        type='c_allreduce_sum',
-                        inputs={'X': sum_res},
-                        outputs={'Out': sum_res},
-                        attrs={
-                            'ring_id': ring_id,
-                            'op_namescope': "/gradient_clip_model_parallelism",
-                            'use_calc_stream': True,
-                            OP_ROLE_KEY: OpRole.Optimize,
-                        })
-                return
+        for idx, op in list(enumerate(block.ops)):
+            if not self._is_gradient_clip_op(op):
+                continue
+            if op.type == 'sum':
+                # If mp_rank == 0, no extra handles, just allreduce
+                # If mp_rank >= 1, some extra handles is needed
+                sum_rst_var = block.var(op.output_arg_names[0])
+                if mp_rank >= 1:
+                    reserved_vars = []
+                    for input_name in op.input_arg_names:
+                        if input_name not in removed_tmp_var:
+                            reserved_vars.append(input_name)
+
+                    if len(reserved_vars) > 0:
+                        op.desc.set_input("X", reserved_vars)
+                    else:
+                        # If all input of sum op should be removed, then remove the sum op.
+                        # And set the output's value of sum to 0.
+                        namescope = op.attr("op_namescope")
+                        block._remove_op(idx, sync=False)
+                        fill_constant_op = block._insert_op_without_sync(
+                            idx,
+                            type='fill_constant',
+                            inputs={},
+                            outputs={'Out': sum_rst_var},
+                            attrs={
+                                'shape': sum_rst_var.shape,
+                                'dtype': sum_rst_var.dtype,
+                                'value': 0.0,
+                                OP_ROLE_KEY: OpRole.Optimize
+                            })
+                        fill_constant_op._set_attr('op_namescope', namescope)
+                self._insert_allreduce(block, ring_ids, idx, sum_rst_var)
+                break
+
+    @staticmethod
+    def _insert_allreduce(block, ring_ids, idx, var):
+        for ring_id in ring_ids:
+            if ring_id == -1:
+                continue
+
+            idx = idx + 1
+            block._insert_op_without_sync(
+                idx,
+                type='c_allreduce_sum',
+                inputs={'X': var},
+                outputs={'Out': var},
+                attrs={
+                    'ring_id': ring_id,
+                    'op_namescope': "/gradient_clip_model_parallelism",
+                    'use_calc_stream': True,
+                    OP_ROLE_KEY: OpRole.Optimize,
+                })
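A minimal sketch of the problem this change addresses (plain Python for illustration only, not part of the commit and not Paddle API). Under tensor parallelism, vars that are not marked is_distributed, such as LayerNorm scale and bias, exist as full copies on every mp rank, so summing the per-rank squared norms over the mp ring counts them mp_degree times. The new mp_rank argument lets every rank other than rank 0 drop those duplicated contributions before the allreduce. The two-rank setup and the values below are made up:

# Illustrative sketch only (plain Python, not Paddle API): why ranks with
# mp_rank >= 1 drop vars that are not marked is_distributed before the
# cross-rank allreduce of the squared global norm.
import math

mp_degree = 2
w_shards = [[1.0, 2.0], [3.0, 4.0]]   # split across mp ranks (is_distributed=True)
b = [0.5, 0.5]                        # replicated on every rank (not distributed)

def local_sq_norm(rank, prune_duplicates):
    sq = sum(x * x for x in w_shards[rank])
    # Only rank 0 keeps the replicated var when pruning is enabled, mirroring
    # the mp_rank >= 1 branch added to sync_global_norm.
    if rank == 0 or not prune_duplicates:
        sq += sum(x * x for x in b)
    return sq

# The c_allreduce_sum across ranks is simulated here by a plain Python sum.
naive = math.sqrt(sum(local_sq_norm(r, False) for r in range(mp_degree)))
fixed = math.sqrt(sum(local_sq_norm(r, True) for r in range(mp_degree)))
print(naive, fixed)   # naive counts b twice; fixed counts it once

If pruning removes every input of the sum op on a rank, the patch replaces the sum with a fill_constant of 0.0 so the following c_allreduce_sum still has a valid operand.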

python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py

Lines changed: 1 addition & 2 deletions
@@ -435,7 +435,6 @@ def _adapt_amp_clip_without_sharding(self):
         main_block = self._main_program.global_block()
         startup_block = self._startup_program.global_block()
 
-        # FIXME(wangxi): mp should prune duplicated param_grads when calc
         # amp inf_var & clip global_norm_var
 
         rings = [self.mp_ring_id, self.pp_ring_id]
@@ -446,7 +445,7 @@ def _adapt_amp_clip_without_sharding(self):
 
             gradientclip_helper = GradientClipHelper(None)
             gradientclip_helper.sync_global_norm(
-                main_block, [self.mp_ring_id, self.pp_ring_id])
+                main_block, [self.mp_ring_id, self.pp_ring_id], self.mp_rank)
 
     def _insert_loss_grad_scale_op(self):
         main_block = self._main_program.global_block()

python/paddle/fluid/optimizer.py

Lines changed: 9 additions & 2 deletions
@@ -4381,7 +4381,7 @@ def _create_vars(self, block, ori_block):
                     persistable=source_var.persistable)
             else:
                 dest_var = block._clone_variable(source_var, False)
-            dest_var.stop_gradient = source_var.stop_gradient
+            self._clone_var_attr(dest_var, source_var)
             # When use with sharding, allreduce_sum and allreduce_max
             # used for global gradient clip and amp will be added by sharding.
             op_idx += 1
@@ -4547,9 +4547,14 @@ def _create_var(self, block, ref_var, name, dtype=None):
            persistable=ref_var.persistable,
            is_data=ref_var.is_data,
            need_check_feed=ref_var.desc.need_check_feed())
-        new_var.stop_gradient = ref_var.stop_gradient
+        self._clone_var_attr(new_var, ref_var)
         return new_var
 
+    def _clone_var_attr(self, dest, src):
+        dest.stop_gradient = src.stop_gradient
+        if hasattr(src, 'is_distributed'):
+            dest.is_distributed = src.is_distributed
+
     def _strip_grad_suffix(self, name):
         """
         Strip the grad suffix from the given variable name
@@ -5209,6 +5214,8 @@ def _insert_accumulate_gradients_with_fuse(self, main_block, fp16,
                 persistable=True,
                 stop_gradient=False)
             real_param = main_block.var(param)
+            if hasattr(real_param, 'is_distributed'):
+                merged_grad_var.is_distributed = real_param.is_distributed
             tmp_size = self._get_var_size(real_grad)
             # two strategies for splitting the grad
             # 1. the current segment's size reach the user defined grad_size_in_MB
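A minimal sketch of the attribute propagation this file adds, using hypothetical FakeVar and clone_var_attr stand-ins rather than Paddle's Variable class: cloned pipeline vars and fused merged-gradient vars now carry is_distributed forward, which is what lets the gradient-clip pruning above tell split vars from replicated ones.

# Illustrative sketch only: guarded copy of stop_gradient / is_distributed,
# mirroring the new _clone_var_attr helper (only some vars carry the
# is_distributed attribute at all).
class FakeVar:
    def __init__(self, name, stop_gradient=False, is_distributed=None):
        self.name = name
        self.stop_gradient = stop_gradient
        if is_distributed is not None:
            self.is_distributed = is_distributed

def clone_var_attr(dest, src):
    dest.stop_gradient = src.stop_gradient
    if hasattr(src, 'is_distributed'):
        dest.is_distributed = src.is_distributed

param = FakeVar('embedding_0.w_0', is_distributed=True)
merged_grad = FakeVar('embedding_0.w_0@GRAD@MERGED')
clone_var_attr(merged_grad, param)
assert merged_grad.is_distributed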

python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py

Lines changed: 27 additions & 0 deletions
@@ -658,6 +658,33 @@ def test_hybrid_with_mp_pp_amp_gclip(self):
             'c_gen_nccl_id', 'c_comm_init'
         ])
 
+        self.assertEqual(main_prog_op_types, [
+            'partial_recv', 'partial_allgather', 'cast', 'cast', 'mul', 'cast',
+            'elementwise_add', 'cast', 'tanh', 'cast', 'cast', 'mul', 'cast',
+            'elementwise_add', 'cast', 'tanh', 'cast', 'cast', 'mul', 'cast',
+            'elementwise_add', 'cast', 'tanh', 'cast', 'cast', 'mul', 'cast',
+            'elementwise_add', 'softmax', 'cast', 'cross_entropy2', 'mean',
+            'elementwise_mul', 'fill_constant', 'elementwise_mul_grad',
+            'mean_grad', 'cross_entropy_grad2', 'cast', 'softmax_grad',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'c_sync_calc_stream',
+            'partial_send', 'fill_constant', 'cast', 'sum', 'fill_constant',
+            'cast', 'sum', 'fill_constant', 'cast', 'sum', 'fill_constant',
+            'cast', 'sum', 'fill_constant', 'cast', 'sum', 'fill_constant',
+            'cast', 'sum', 'fill_constant', 'cast', 'sum', 'fill_constant',
+            'cast', 'sum', 'c_sync_comm_stream', 'check_finite_and_unscale',
+            'cast', 'c_allreduce_max', 'c_allreduce_max', 'cast',
+            'update_loss_scaling', 'fill_constant', 'c_allreduce_sum',
+            'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max',
+            'elementwise_div', 'elementwise_mul', 'elementwise_mul',
+            'elementwise_mul', 'elementwise_mul', 'elementwise_mul',
+            'elementwise_mul', 'elementwise_mul', 'elementwise_mul', 'momentum',
+            'momentum', 'momentum', 'momentum', 'momentum', 'momentum',
+            'momentum', 'momentum'
+        ])
+
         # pp + mp, partial send recv
         self.assertIn('partial_recv', main_prog_op_types)
         self.assertIn('partial_allgather', main_prog_op_types)
