@@ -293,19 +293,17 @@ def _dygraph_clip(self, params_grads):
             params_grads, global_norm_var_dist, global_norm_var_not_dist
         )
 
-    def _comm_and_clip(
-        self, params_grads, global_norm_var_dist, global_norm_var_not_dist
-    ):
-        # sharding first
-        sharding_flag = (
-            self._hcg.get_sharding_parallel_world_size() > 1
-            and self._hcg.get_data_parallel_world_size() == 1
-        )
+    def _global_norm(self, global_norm_var_dist, global_norm_var_not_dist):
+        sharding_flag = self._hcg.get_sharding_parallel_world_size() > 1
+        dp_flag = self._hcg.get_data_parallel_world_size() > 1
         mp_flag = self._hcg.get_model_parallel_world_size() > 1
-        # add all reduce to get global norm of distributed params_and_grads
+        pp_flag = self._hcg.get_pipe_parallel_world_size() > 1
+
+        # when g_shard_norm_align_dp is off, grads are sharded among the sharding group
         if sharding_flag and not g_shard_norm_align_dp:
             # norm of mp distributed variable
             if mp_flag:
+                # the dist norm is reduced among the sharding group here and among the mp/pp groups later
                 paddle.distributed.all_reduce(
                     global_norm_var_dist,
                     group=self._hcg.get_sharding_parallel_group(),
@@ -315,21 +313,40 @@ def _comm_and_clip(
                 global_norm_var_not_dist,
                 group=self._hcg.get_sharding_parallel_group(),
             )
+
         # norm of mp distributed variable
         if mp_flag:
-            # dist should reduce among sharding group, mp group and pp group
-            paddle.distributed.all_reduce(
-                global_norm_var_dist,
-                group=self._hcg.get_check_parallel_group(sharding_flag),
-            )
+            # the else branch alone would suffice; this branch is kept for numerical-precision backward compatibility
+            if not (dp_flag and sharding_flag):
+                paddle.distributed.all_reduce(
+                    global_norm_var_dist,
+                    group=self._hcg.get_check_parallel_group(sharding_flag),
+                )
+            else:
+                # global_norm_var_dist should be all-reduced among the model parallel group and the pp group
+                paddle.distributed.all_reduce(
+                    global_norm_var_dist,
+                    group=self._hcg.get_model_parallel_group(),
+                )
+                if pp_flag:
+                    paddle.distributed.all_reduce(
+                        global_norm_var_dist,
+                        group=self._hcg.get_pipe_parallel_group(),
+                    )
 
         # add all reduce to get global norm of non-distributed params_and_grads in groups of pp
-        if self._hcg.get_pipe_parallel_world_size() > 1:
+        if pp_flag:
             paddle.distributed.all_reduce(
                 global_norm_var_not_dist,
                 group=self._hcg.get_pipe_parallel_group(),
             )
 
+    def _comm_and_clip(
+        self, params_grads, global_norm_var_dist, global_norm_var_not_dist
+    ):
+
+        self._global_norm(global_norm_var_dist, global_norm_var_not_dist)
+
         global_norm_var_fp32 = paddle.sqrt(
             global_norm_var_dist + global_norm_var_not_dist
         )
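
As a side note on what `_comm_and_clip` does with these two accumulators once `_global_norm` returns: below is a minimal pure-Python sketch of the clip coefficient, assuming the standard ClipGradByGlobalNorm rule `scale = clip_norm / max(global_norm, clip_norm)`. The names are illustrative, not Paddle API.

```python
import math

def clip_scale(sq_norm_dist, sq_norm_not_dist, clip_norm=1.0):
    # After the all-reduces above, every rank holds the same two partial sums
    # of squared gradient entries: one over mp-distributed params, one over the rest.
    global_norm = math.sqrt(sq_norm_dist + sq_norm_not_dist)
    # Global-norm clipping only shrinks gradients, never amplifies them.
    return clip_norm / max(global_norm, clip_norm)

# Partial squared norms 9.0 and 16.0 give a global norm of 5.0, so with
# clip_norm=1.0 every gradient is scaled by 0.2.
assert abs(clip_scale(9.0, 16.0) - 0.2) < 1e-12
```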
@@ -554,15 +571,21 @@ def _step(self, parameters_list):
     @no_grad()
     @framework.dygraph_only
     def step(self):
-        parameters_list = obtain_optimizer_parameters_list(self._inner_opt)
+        parameter_list = list(obtain_optimizer_parameters_list(self._inner_opt))
+        dp_parameter_list = parameter_list
         if self._sharding_enable:
             assert isinstance(self._inner_opt, DygraphShardingOptimizer)
-            self._inner_opt.reduce_gradients(list(parameters_list), self._hcg)
+            self._inner_opt.reduce_gradients(parameter_list, self._hcg)
+            # the later dp sync does not need the global parameter list
+            if not g_shard_norm_align_dp:
+                dp_parameter_list = self._inner_opt.filter_parameters(
+                    parameter_list, self._hcg
+                )
 
         if self._dp_enable:
-            fused_allreduce_gradients(list(parameters_list), self._hcg)
+            fused_allreduce_gradients(dp_parameter_list, self._hcg)
 
-        self._step(parameters_list)
+        self._step(parameter_list)
 
     @no_grad()
     def minimize(
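
Why the dp sync can use a filtered list when grads are sharded (`not g_shard_norm_align_dp`): after `reduce_gradients`, each sharding rank only holds the gradients of its own parameter shard, so only that slice needs the later dp all-reduce. A rough sketch, with a made-up round-robin ownership rule standing in for `DygraphShardingOptimizer.filter_parameters`:

```python
# Hypothetical stand-in for DygraphShardingOptimizer.filter_parameters: assume a
# round-robin assignment of parameters to sharding ranks, so after reduce_gradients
# each rank holds valid grads only for its own slice, and only that slice needs
# the later dp all-reduce.
def filter_parameters_sketch(parameter_list, sharding_rank, sharding_degree):
    return [
        p
        for i, p in enumerate(parameter_list)
        if i % sharding_degree == sharding_rank
    ]

params = ["w0", "w1", "w2", "w3", "w4"]
# Rank 1 of a 2-way sharding group would dp-sync only its shard, not the full list.
assert filter_parameters_sketch(params, 1, 2) == ["w1", "w3"]
```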
@@ -574,14 +597,20 @@ def minimize(
         parameter_list = (
             parameters if parameters else self._inner_opt._parameter_list
         )
-
+        parameter_list = list(parameter_list)
+        dp_parameter_list = parameter_list
         # Here sharding should use global parameter list
         if self._sharding_enable:
             assert isinstance(self._inner_opt, DygraphShardingOptimizer)
-            self._inner_opt.reduce_gradients(list(parameter_list), self._hcg)
+            self._inner_opt.reduce_gradients(parameter_list, self._hcg)
+            # the later dp sync does not need the global parameter list
+            if not g_shard_norm_align_dp:
+                dp_parameter_list = self._inner_opt.filter_parameters(
+                    parameter_list, self._hcg
+                )
 
         if self._dp_enable:
-            fused_allreduce_gradients(list(parameter_list), self._hcg)
+            fused_allreduce_gradients(dp_parameter_list, self._hcg)
 
         return self._inner_opt.minimize(
             loss, startup_program, parameter_list, no_grad_set