Commit 069e32e

wangxicoding authored and AnnaTrainingG committed
[hybrid] remove scale op in insert_scale_loss_grad_ops (PaddlePaddle#35775)
1 parent 0f9b0fe commit 069e32e

File tree

5 files changed: +125 -116 lines changed


python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py

Lines changed: 7 additions & 8 deletions
@@ -738,14 +738,13 @@ def insert_scale_loss_grad_ops(block, scale=1.0):
     '''
     for idx, op in reversed(list(enumerate(block.ops))):
         if is_loss_grad_op(op):
-            loss_grad_var = block.vars[op.output_arg_names[0]]
-            block._insert_op_without_sync(
-                idx + 1,
-                type='scale',
-                inputs={'X': loss_grad_var},
-                outputs={'Out': loss_grad_var},
-                attrs={'scale': scale,
-                       OP_ROLE_KEY: OpRole.Backward})
+            assert op.type == 'fill_constant', \
+                "loss_grad_op must be fill_constant op, " \
+                "but this op is {}".format(op.type)
+            assert op.has_attr('value')
+            loss_scale = float(op.attr('value'))
+            loss_scale = loss_scale / scale
+            op._set_attr('value', loss_scale)
             break
 
 

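The new utils.py logic drops the inserted `scale` op and instead rewrites the `value` attribute of the `fill_constant` op that seeds the loss gradient. Below is a minimal, framework-free Python sketch of that idea; `ToyOp` and `rescale_loss_grad` are illustrative stand-ins, not Paddle APIs.

```python
# Illustrative sketch only: ToyOp mimics the two fields the pass touches.
class ToyOp:
    def __init__(self, op_type, attrs):
        self.type = op_type
        self.attrs = attrs


def rescale_loss_grad(ops, scale):
    """Fold 1/scale into the fill_constant that seeds loss@GRAD,
    mirroring the new insert_scale_loss_grad_ops behaviour."""
    for op in reversed(ops):
        if op.type == 'fill_constant':  # stand-in for is_loss_grad_op(op)
            op.attrs['value'] = float(op.attrs['value']) / scale
            break


ops = [ToyOp('mean', {}),
       ToyOp('fill_constant', {'value': 1.0}),
       ToyOp('mean_grad', {})]
rescale_loss_grad(ops, scale=8)
print(ops[1].attrs['value'])  # 0.125, with no extra scale op in the graph
```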
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py

Lines changed: 1 addition & 1 deletion
@@ -455,7 +455,7 @@ def _insert_loss_grad_scale_op(self):
         global_dp_degree = self.sharding_degree * self.dp_degree
         assert int(global_dp_degree) == global_dp_degree
         if global_dp_degree > 1:
-            insert_scale_loss_grad_ops(main_block, scale=1.0 / global_dp_degree)
+            insert_scale_loss_grad_ops(main_block, scale=global_dp_degree)
 
         main_block._sync_with_cpp()
 

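Since the helper now divides the existing `fill_constant` value by `scale`, the caller passes the data-parallel degree itself rather than its reciprocal. A quick sanity check with an assumed degree of 4 (not taken from any real config):

```python
# Assumed example values for illustration only.
seed_value = 1.0          # fill_constant value that seeds loss@GRAD
global_dp_degree = 4      # sharding_degree * dp_degree

# Old behaviour: insert a scale op with scale=1.0 / global_dp_degree.
old_result = seed_value * (1.0 / global_dp_degree)
# New behaviour: rewrite the seed value in place, dividing by the degree.
new_result = seed_value / global_dp_degree

assert old_result == new_result == 0.25  # same effective scaling, one op fewer
```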
python/paddle/fluid/optimizer.py

Lines changed: 7 additions & 10 deletions
@@ -5019,16 +5019,13 @@ def _insert_loss_scale(self, block):
         if self._num_microbatches == 1: return
         for index, op in reversed(tuple(enumerate(list(block.ops)))):
             if self._is_loss_grad_op(op):
-                loss_grad_var = block.vars[op.output_arg_names[0]]
-                block._insert_op(
-                    index=index + 1,
-                    type='scale',
-                    inputs={'X': loss_grad_var},
-                    outputs={'Out': loss_grad_var},
-                    attrs={
-                        'scale': 1.0 / self._num_microbatches,
-                        self._op_role_key: self._op_role.Backward
-                    })
+                assert op.type == 'fill_constant', \
+                    "loss_grad_op must be fill_constant op, " \
+                    "but this op is {}".format(op.type)
+                assert op.has_attr('value')
+                loss_scale = float(op.attr('value'))
+                loss_scale = loss_scale / self._num_microbatches
+                op._set_attr('value', loss_scale)
                 break
 
     def _rename_gradient_var_name(self, block):

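The pipeline pass above applies the same in-place rewrite with `self._num_microbatches`, so when both it and the sharding pass touch the op, the divisions compose. A minimal sketch of the combined effect, using assumed values that mirror the check added to the unittest below:

```python
# Assumed example values, chosen only to illustrate how the two passes compose.
num_microbatches = 2      # pipeline accumulate_steps
global_dp_degree = 4      # sharding_degree * dp_degree

value = 1.0                          # initial fill_constant seed for loss@GRAD
value = value / num_microbatches     # PipelineOptimizer._insert_loss_scale
value = value / global_dp_degree     # insert_scale_loss_grad_ops

assert value == 1.0 / (num_microbatches * global_dp_degree)  # 0.125
```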
python/paddle/fluid/tests/unittests/test_fleet_hybrid_meta_optimizer.py

Lines changed: 56 additions & 43 deletions
@@ -18,6 +18,7 @@
 import unittest
 
 from fleet_meta_optimizer_base import TestFleetMetaOptimizer
+from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op
 
 paddle.enable_static()
 
@@ -77,10 +78,10 @@ def test_opt_sharding_with_pp(self):
             'recv_v2', 'mul', 'elementwise_add', 'tanh', 'mul',
             'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', 'mul',
             'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
-            'fill_constant', 'scale', 'scale', 'mean_grad',
-            'cross_entropy_grad2', 'softmax_grad', 'elementwise_add_grad',
-            'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad',
-            'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'fill_constant', 'mean_grad', 'cross_entropy_grad2', 'softmax_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'send_v2',
             'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant',
             'sum', 'fill_constant', 'sum', 'fill_constant', 'sum',
@@ -158,10 +159,10 @@ def test_opt_sharding_with_pp_with_allreduce_fuse(self):
             'recv_v2', 'mul', 'elementwise_add', 'tanh', 'mul',
             'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', 'mul',
             'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
-            'fill_constant', 'scale', 'scale', 'mean_grad',
-            'cross_entropy_grad2', 'softmax_grad', 'elementwise_add_grad',
-            'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad',
-            'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'fill_constant', 'mean_grad', 'cross_entropy_grad2', 'softmax_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'send_v2',
             'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant',
             'sum', 'fill_constant', 'sum', 'fill_constant', 'sum',
@@ -220,8 +221,8 @@ def test_opt_sharding_with_pp_amp_gclip(self):
             'tanh', 'cast', 'cast', 'mul', 'cast', 'elementwise_add', 'cast',
             'tanh', 'cast', 'cast', 'mul', 'cast', 'elementwise_add', 'softmax',
             'cast', 'cross_entropy2', 'mean', 'elementwise_mul',
-            'fill_constant', 'scale', 'scale', 'elementwise_mul_grad',
-            'mean_grad', 'cross_entropy_grad2', 'cast', 'softmax_grad',
+            'fill_constant', 'elementwise_mul_grad', 'mean_grad',
+            'cross_entropy_grad2', 'cast', 'softmax_grad',
             'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
             'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
             'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
@@ -293,23 +294,23 @@ def test_opt_sharding_with_pp_amp_gclip_fuse_gm(self):
             'tanh', 'cast', 'cast', 'mul', 'cast', 'elementwise_add', 'softmax',
             'cast', 'cross_entropy2', 'mean', 'elementwise_mul',
             'coalesce_tensor', 'coalesce_tensor', 'coalesce_tensor',
-            'coalesce_tensor', 'fill_constant', 'scale', 'scale',
-            'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2', 'cast',
-            'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'cast',
-            'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', 'cast',
-            'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', 'cast',
-            'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', 'cast',
-            'c_sync_calc_stream', 'send_v2', 'cast', 'sum', 'cast', 'sum',
-            'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream',
-            'check_finite_and_unscale', 'cast', 'c_allreduce_max',
-            'c_allreduce_max', 'cast', 'update_loss_scaling', 'squared_l2_norm',
-            'squared_l2_norm', 'squared_l2_norm', 'squared_l2_norm',
-            'squared_l2_norm', 'sum', 'c_allreduce_sum', 'c_allreduce_sum',
-            'sqrt', 'fill_constant', 'elementwise_max', 'elementwise_div',
+            'coalesce_tensor', 'fill_constant', 'elementwise_mul_grad',
+            'mean_grad', 'cross_entropy_grad2', 'cast', 'softmax_grad',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'c_sync_calc_stream',
+            'send_v2', 'cast', 'sum', 'cast', 'sum', 'c_reduce_sum',
+            'c_reduce_sum', 'c_sync_comm_stream', 'check_finite_and_unscale',
+            'cast', 'c_allreduce_max', 'c_allreduce_max', 'cast',
+            'update_loss_scaling', 'squared_l2_norm', 'squared_l2_norm',
+            'squared_l2_norm', 'squared_l2_norm', 'squared_l2_norm', 'sum',
+            'c_allreduce_sum', 'c_allreduce_sum', 'sqrt', 'fill_constant',
+            'elementwise_max', 'elementwise_div', 'elementwise_mul',
             'elementwise_mul', 'elementwise_mul', 'elementwise_mul',
-            'elementwise_mul', 'elementwise_mul', 'momentum', 'momentum',
-            'momentum', 'momentum', 'momentum', 'coalesce_tensor',
-            'c_broadcast', 'coalesce_tensor', 'c_broadcast'
+            'elementwise_mul', 'momentum', 'momentum', 'momentum', 'momentum',
+            'momentum', 'coalesce_tensor', 'c_broadcast', 'coalesce_tensor',
+            'c_broadcast'
         ])
 
 
@@ -327,7 +328,10 @@ def setUp(self):
         self._debug = False
 
     def test_opt_sharding_with_pp_amp_gclip_boundary(self):
-        """ test optimizer sharding without parameter """
+        """
+        test optimizer sharding without parameter
+        test loss grad scale value
+        """
         train_prog, startup_prog = static.Program(), static.Program()
         avg_cost, strategy = self.boundary_net(train_prog, startup_prog)
 
@@ -357,6 +361,16 @@ def test_opt_sharding_with_pp_amp_gclip_boundary(self):
         startup_prog_op_types = [op.type for op in startup_prog_ops]
         main_prog_op_types = [op.type for op in main_prog_ops]
 
+        # check loss scale for hybrid
+        for op in main_prog_ops:
+            if is_loss_grad_op(op):
+                self.assertEqual(op.type, 'fill_constant')
+                self.assertTrue(op.has_attr('value'))
+                scale = strategy.pipeline_configs[
+                    'accumulate_steps'] * strategy.sharding_configs['dp_degree']
+                loss_scale = 1.0 / scale
+                self.assertAlmostEqual(float(op.attr('value')), loss_scale)
+
         # global, sharding, pp_send, pp_recv
         self.assertEqual(startup_prog_op_types, [
             'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
@@ -367,14 +381,13 @@ def test_opt_sharding_with_pp_amp_gclip_boundary(self):
 
         self.assertEqual(main_prog_op_types, [
             'recv_v2', 'cast', 'matmul', 'cast', 'reduce_mean',
-            'elementwise_mul', 'fill_constant', 'scale', 'scale',
-            'elementwise_mul_grad', 'reduce_mean_grad', 'cast', 'matmul_grad',
-            'c_sync_calc_stream', 'send_v2', 'fill_constant', 'cast', 'sum',
-            'c_reduce_sum', 'c_sync_comm_stream', 'check_finite_and_unscale',
-            'cast', 'c_allreduce_max', 'c_allreduce_max', 'cast',
-            'update_loss_scaling', 'fill_constant', 'c_allreduce_sum',
-            'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max',
-            'elementwise_div', 'c_broadcast'
+            'elementwise_mul', 'fill_constant', 'elementwise_mul_grad',
+            'reduce_mean_grad', 'cast', 'matmul_grad', 'c_sync_calc_stream',
+            'send_v2', 'fill_constant', 'cast', 'sum', 'c_reduce_sum',
+            'c_sync_comm_stream', 'check_finite_and_unscale', 'cast',
+            'c_allreduce_max', 'c_allreduce_max', 'cast', 'update_loss_scaling',
+            'fill_constant', 'c_allreduce_sum', 'c_allreduce_sum', 'sqrt',
+            'fill_constant', 'elementwise_max', 'elementwise_div', 'c_broadcast'
         ])
 
     def test_opt_sharding_with_pp_amp_gclip_boundary_card1(self):
@@ -419,14 +432,14 @@ def test_opt_sharding_with_pp_amp_gclip_boundary_card1(self):
 
         self.assertEqual(main_prog_op_types, [
             'recv_v2', 'cast', 'matmul', 'cast', 'reduce_mean',
-            'elementwise_mul', 'fill_constant', 'scale', 'scale',
-            'elementwise_mul_grad', 'reduce_mean_grad', 'cast', 'matmul_grad',
-            'c_sync_calc_stream', 'send_v2', 'fill_constant', 'cast', 'sum',
-            'c_reduce_sum', 'c_sync_comm_stream', 'check_finite_and_unscale',
-            'cast', 'c_allreduce_max', 'c_allreduce_max', 'cast',
-            'update_loss_scaling', 'squared_l2_norm', 'sum', 'c_allreduce_sum',
-            'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max',
-            'elementwise_div', 'elementwise_mul', 'momentum', 'c_broadcast'
+            'elementwise_mul', 'fill_constant', 'elementwise_mul_grad',
+            'reduce_mean_grad', 'cast', 'matmul_grad', 'c_sync_calc_stream',
+            'send_v2', 'fill_constant', 'cast', 'sum', 'c_reduce_sum',
+            'c_sync_comm_stream', 'check_finite_and_unscale', 'cast',
+            'c_allreduce_max', 'c_allreduce_max', 'cast', 'update_loss_scaling',
+            'squared_l2_norm', 'sum', 'c_allreduce_sum', 'c_allreduce_sum',
+            'sqrt', 'fill_constant', 'elementwise_max', 'elementwise_div',
+            'elementwise_mul', 'momentum', 'c_broadcast'
         ])
 
 

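For reference, the check added to the boundary test can be applied to any hybrid-parallel main program after the meta optimizer has run; in this hedged sketch, `main_prog` and `strategy` are assumed placeholders built the same way the test builds them.

```python
from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op

# `main_prog` and `strategy` are assumed to exist, e.g. produced by a setup
# like the boundary test above; this simply restates its new assertion.
expected = 1.0 / (strategy.pipeline_configs['accumulate_steps'] *
                  strategy.sharding_configs['dp_degree'])
for op in main_prog.global_block().ops:
    if is_loss_grad_op(op):
        assert op.type == 'fill_constant' and op.has_attr('value')
        assert abs(float(op.attr('value')) - expected) < 1e-6
        break
```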