@@ -18,6 +18,7 @@
 import unittest
 
 from fleet_meta_optimizer_base import TestFleetMetaOptimizer
+from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op
 
 paddle.enable_static()
 
@@ -77,10 +78,10 @@ def test_opt_sharding_with_pp(self):
             'recv_v2', 'mul', 'elementwise_add', 'tanh', 'mul',
             'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', 'mul',
             'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
-            'fill_constant', 'scale', 'scale', 'mean_grad',
-            'cross_entropy_grad2', 'softmax_grad', 'elementwise_add_grad',
-            'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad',
-            'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'fill_constant', 'mean_grad', 'cross_entropy_grad2', 'softmax_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'send_v2',
             'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant',
             'sum', 'fill_constant', 'sum', 'fill_constant', 'sum',
@@ -158,10 +159,10 @@ def test_opt_sharding_with_pp_with_allreduce_fuse(self):
             'recv_v2', 'mul', 'elementwise_add', 'tanh', 'mul',
             'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', 'mul',
             'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
-            'fill_constant', 'scale', 'scale', 'mean_grad',
-            'cross_entropy_grad2', 'softmax_grad', 'elementwise_add_grad',
-            'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad',
-            'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'fill_constant', 'mean_grad', 'cross_entropy_grad2', 'softmax_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'send_v2',
             'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant',
             'sum', 'fill_constant', 'sum', 'fill_constant', 'sum',
@@ -220,8 +221,8 @@ def test_opt_sharding_with_pp_amp_gclip(self):
             'tanh', 'cast', 'cast', 'mul', 'cast', 'elementwise_add', 'cast',
             'tanh', 'cast', 'cast', 'mul', 'cast', 'elementwise_add', 'softmax',
             'cast', 'cross_entropy2', 'mean', 'elementwise_mul',
-            'fill_constant', 'scale', 'scale', 'elementwise_mul_grad',
-            'mean_grad', 'cross_entropy_grad2', 'cast', 'softmax_grad',
+            'fill_constant', 'elementwise_mul_grad', 'mean_grad',
+            'cross_entropy_grad2', 'cast', 'softmax_grad',
             'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
             'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
             'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
@@ -293,23 +294,23 @@ def test_opt_sharding_with_pp_amp_gclip_fuse_gm(self):
             'tanh', 'cast', 'cast', 'mul', 'cast', 'elementwise_add', 'softmax',
             'cast', 'cross_entropy2', 'mean', 'elementwise_mul',
             'coalesce_tensor', 'coalesce_tensor', 'coalesce_tensor',
-            'coalesce_tensor', 'fill_constant', 'scale', 'scale',
-            'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2', 'cast',
-            'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'cast',
-            'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', 'cast',
-            'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', 'cast',
-            'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', 'cast',
-            'c_sync_calc_stream', 'send_v2', 'cast', 'sum', 'cast', 'sum',
-            'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream',
-            'check_finite_and_unscale', 'cast', 'c_allreduce_max',
-            'c_allreduce_max', 'cast', 'update_loss_scaling', 'squared_l2_norm',
-            'squared_l2_norm', 'squared_l2_norm', 'squared_l2_norm',
-            'squared_l2_norm', 'sum', 'c_allreduce_sum', 'c_allreduce_sum',
-            'sqrt', 'fill_constant', 'elementwise_max', 'elementwise_div',
+            'coalesce_tensor', 'fill_constant', 'elementwise_mul_grad',
+            'mean_grad', 'cross_entropy_grad2', 'cast', 'softmax_grad',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'c_sync_calc_stream',
+            'send_v2', 'cast', 'sum', 'cast', 'sum', 'c_reduce_sum',
+            'c_reduce_sum', 'c_sync_comm_stream', 'check_finite_and_unscale',
+            'cast', 'c_allreduce_max', 'c_allreduce_max', 'cast',
+            'update_loss_scaling', 'squared_l2_norm', 'squared_l2_norm',
+            'squared_l2_norm', 'squared_l2_norm', 'squared_l2_norm', 'sum',
+            'c_allreduce_sum', 'c_allreduce_sum', 'sqrt', 'fill_constant',
+            'elementwise_max', 'elementwise_div', 'elementwise_mul',
             'elementwise_mul', 'elementwise_mul', 'elementwise_mul',
-            'elementwise_mul', 'elementwise_mul', 'momentum', 'momentum',
-            'momentum', 'momentum', 'momentum', 'coalesce_tensor',
-            'c_broadcast', 'coalesce_tensor', 'c_broadcast'
+            'elementwise_mul', 'momentum', 'momentum', 'momentum', 'momentum',
+            'momentum', 'coalesce_tensor', 'c_broadcast', 'coalesce_tensor',
+            'c_broadcast'
         ])
 
 
@@ -327,7 +328,10 @@ def setUp(self):
         self._debug = False
 
     def test_opt_sharding_with_pp_amp_gclip_boundary(self):
-        """ test optimizer sharding without parameter """
+        """
+        test optimizer sharding without parameter
+        test loss grad scale value
+        """
         train_prog, startup_prog = static.Program(), static.Program()
         avg_cost, strategy = self.boundary_net(train_prog, startup_prog)
 
@@ -357,6 +361,16 @@ def test_opt_sharding_with_pp_amp_gclip_boundary(self):
         startup_prog_op_types = [op.type for op in startup_prog_ops]
         main_prog_op_types = [op.type for op in main_prog_ops]
 
+        # check loss scale for hybrid
+        for op in main_prog_ops:
+            if is_loss_grad_op(op):
+                self.assertEqual(op.type, 'fill_constant')
+                self.assertTrue(op.has_attr('value'))
+                scale = strategy.pipeline_configs[
+                    'accumulate_steps'] * strategy.sharding_configs['dp_degree']
+                loss_scale = 1.0 / scale
+                self.assertAlmostEqual(float(op.attr('value')), loss_scale)
+
         # global, sharding, pp_send, pp_recv
         self.assertEqual(startup_prog_op_types, [
             'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
@@ -367,14 +381,13 @@ def test_opt_sharding_with_pp_amp_gclip_boundary(self):
 
         self.assertEqual(main_prog_op_types, [
             'recv_v2', 'cast', 'matmul', 'cast', 'reduce_mean',
-            'elementwise_mul', 'fill_constant', 'scale', 'scale',
-            'elementwise_mul_grad', 'reduce_mean_grad', 'cast', 'matmul_grad',
-            'c_sync_calc_stream', 'send_v2', 'fill_constant', 'cast', 'sum',
-            'c_reduce_sum', 'c_sync_comm_stream', 'check_finite_and_unscale',
-            'cast', 'c_allreduce_max', 'c_allreduce_max', 'cast',
-            'update_loss_scaling', 'fill_constant', 'c_allreduce_sum',
-            'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max',
-            'elementwise_div', 'c_broadcast'
+            'elementwise_mul', 'fill_constant', 'elementwise_mul_grad',
+            'reduce_mean_grad', 'cast', 'matmul_grad', 'c_sync_calc_stream',
+            'send_v2', 'fill_constant', 'cast', 'sum', 'c_reduce_sum',
+            'c_sync_comm_stream', 'check_finite_and_unscale', 'cast',
+            'c_allreduce_max', 'c_allreduce_max', 'cast', 'update_loss_scaling',
+            'fill_constant', 'c_allreduce_sum', 'c_allreduce_sum', 'sqrt',
+            'fill_constant', 'elementwise_max', 'elementwise_div', 'c_broadcast'
         ])
 
     def test_opt_sharding_with_pp_amp_gclip_boundary_card1(self):
@@ -419,14 +432,14 @@ def test_opt_sharding_with_pp_amp_gclip_boundary_card1(self):
 
         self.assertEqual(main_prog_op_types, [
             'recv_v2', 'cast', 'matmul', 'cast', 'reduce_mean',
-            'elementwise_mul', 'fill_constant', 'scale', 'scale',
-            'elementwise_mul_grad', 'reduce_mean_grad', 'cast', 'matmul_grad',
-            'c_sync_calc_stream', 'send_v2', 'fill_constant', 'cast', 'sum',
-            'c_reduce_sum', 'c_sync_comm_stream', 'check_finite_and_unscale',
-            'cast', 'c_allreduce_max', 'c_allreduce_max', 'cast',
-            'update_loss_scaling', 'squared_l2_norm', 'sum', 'c_allreduce_sum',
-            'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max',
-            'elementwise_div', 'elementwise_mul', 'momentum', 'c_broadcast'
+            'elementwise_mul', 'fill_constant', 'elementwise_mul_grad',
+            'reduce_mean_grad', 'cast', 'matmul_grad', 'c_sync_calc_stream',
+            'send_v2', 'fill_constant', 'cast', 'sum', 'c_reduce_sum',
+            'c_sync_comm_stream', 'check_finite_and_unscale', 'cast',
+            'c_allreduce_max', 'c_allreduce_max', 'cast', 'update_loss_scaling',
+            'squared_l2_norm', 'sum', 'c_allreduce_sum', 'c_allreduce_sum',
+            'sqrt', 'fill_constant', 'elementwise_max', 'elementwise_div',
+            'elementwise_mul', 'momentum', 'c_broadcast'
         ])
 
 
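Note on the new assertion: the standalone `scale` ops disappear from the expected op lists above because the loss-gradient scaling is now folded into the `fill_constant` op that seeds the loss gradient, presumably the reciprocal of `accumulate_steps * dp_degree` since gradients are summed over that many micro-batches and data-parallel replicas. A minimal sketch of the same check outside the test harness, assuming a `main_prog`/`strategy` pair produced by the fleet meta optimizer as in these tests:

```python
# Sketch only: verify the folded loss-grad scale on a fleet-compiled program.
from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op


def check_loss_grad_scale(main_prog, strategy):
    # Hypothetical helper mirroring the assertion added in this commit.
    # Gradients are accumulated over `accumulate_steps` micro-batches and
    # averaged across `dp_degree` replicas, so the loss grad is seeded with
    # the reciprocal of that product.
    scale = strategy.pipeline_configs['accumulate_steps'] \
        * strategy.sharding_configs['dp_degree']
    expected = 1.0 / scale
    for op in main_prog.global_block().ops:
        if is_loss_grad_op(op):
            # The scaling is baked into fill_constant's 'value' attribute
            # rather than applied by separate 'scale' ops.
            assert op.type == 'fill_constant'
            assert abs(float(op.attr('value')) - expected) < 1e-6
```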