diff --git a/python/paddle/nn/clip.py b/python/paddle/nn/clip.py
index b3fe014b27a350..0d650d8fed519e 100644
--- a/python/paddle/nn/clip.py
+++ b/python/paddle/nn/clip.py
@@ -717,6 +717,7 @@ def _dygraph_clip(self, params_grads):
         sum_square_list = []
         sum_square_list_fp16 = []
         sum_square_list_fp32 = []
+        flag_auto_hybrid_pp = True  # Determine whether to use the new dynamic graph semi-automatic parallel pp framework
         if len(params_grads) > 0 and len(params_grads[0]) > 0:
             src_mesh = params_grads[0][0].process_mesh
         else:
@@ -742,6 +743,7 @@ def _dygraph_clip(self, params_grads):
             # if the gradient mesh is not equal to src mesh
             # do reshard to get the result of squared_l2 from other pp stage mesh
             if src_mesh is not None and g.process_mesh != src_mesh:
+                flag_auto_hybrid_pp = False
                 pp_mesh = get_complete_pp_mesh(g.process_mesh)
                 if set(g.process_mesh.process_ids) < set(pp_mesh.process_ids):
                     sum_square = dist.reshard(
@@ -791,6 +793,37 @@ def async_add_n(var_list):
 
         global_norm_var = async_add_n(global_norm_var)
 
+        # NOTE(zhengtianyu): Fix grad_clip in auto_hybrid_pp mode.
+        # Reason: In auto_hybrid_pp mode, each rank only keeps local parameters and gradient information,
+        # so global_norm_var is in a partial state, leading to incorrect calculation.
+        # Reference dynamic manual-parallel: Each rank computes local global_norm_var,
+        # then performs pp group communication reduce(sum) to get correct global_norm_var.
+        # For complete alignment with old dygraph semi-auto parallel PP logic,
+        # refer to NOTE: align ClipGradByGlobalNorm in auto_parallel_align_mode
+        if flag_auto_hybrid_pp and src_mesh is not None:
+            g_mesh = dist.get_mesh()
+            if (
+                g_mesh
+                and "pp" in g_mesh.dim_names
+                and g_mesh.get_dim_size("pp") > 1
+            ):
+                # Get the pipeline parallelism subgroup for communication
+                pp_group = g_mesh.get_submesh_with_dim("pp").get_group("pp")
+
+                # Perform all-reduce on the local tensor value across the PP group
+                global_norm_var_local = global_norm_var._local_value()
+                dist.all_reduce(
+                    global_norm_var_local,
+                    op=dist.ReduceOp.SUM,
+                    group=pp_group,
+                )
+
+                global_norm_var = dist.shard_tensor(
+                    global_norm_var_local,
+                    global_norm_var.process_mesh,
+                    global_norm_var.placements,
+                )
+
         if self.should_comm_on_shard_dim and hasattr(self, 'sharding_group'):
             paddle.distributed.all_reduce(
                 global_norm_var._local_value(), group=self.sharding_group
diff --git a/test/auto_parallel/PP_Schedules_demo.py b/test/auto_parallel/PP_Schedules_demo.py
index 6ac055410fbf0a..be8963356d0661 100644
--- a/test/auto_parallel/PP_Schedules_demo.py
+++ b/test/auto_parallel/PP_Schedules_demo.py
@@ -414,6 +414,67 @@ def test_dp_pp(self):
             opt.clear_grad()
         return losses_by_step, all_losses_in_one_step_md5sum
 
+    def test_pp_model_with_ClipGradByGlobalNorm(self):
+        """Test pipeline parallel model with ClipGradByGlobalNorm using PPMyModel as the baseline"""
+        fix_seeds()
+        pp_model = PPMyModel()
+        opt = paddle.optimizer.AdamW(
+            learning_rate=0.001,
+            parameters=pp_model.parameters(),
+            grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),
+        )
+        loss_fn = nn.MSELoss()
+        dataset = RandomDataset(image_size=8, output_size=8, num_samples=8)
+        loader = DataLoader(dataset, batch_size=1)
+        pp_losses_step = []
+        num_iterations = 20
+
+        for iter_idx in range(num_iterations):
+            pp_losses_micro_batch = []
+            for i, (data, label) in enumerate(loader):
+                output = pp_model(data)
+                loss = loss_fn(output, label)
+                pp_losses_micro_batch.append(loss.item())
+                loss.backward()
+            pp_losses_step.append(
+                np.array(pp_losses_micro_batch, dtype=np.float32).mean()
+            )
+            opt.step()
+            opt.clear_grad()
+        return pp_losses_step
+
+    def test_ScheduleFThenB_with_ClipGradByGlobalNorm(self):
+        fix_seeds()
+        self.model = PPMyModel_SingleStage()
+        self.micro_batches = 8
+        self.stage = PipelineStage(self.model, self.rank, 4, group=self.group)
+        self.stage.has_backward = True
+        loss_fn_ = nn.MSELoss()
+        schedule = ScheduleFThenB(
+            self.stage, self.micro_batches, loss_fn=loss_fn_
+        )
+        opt = paddle.optimizer.AdamW(
+            learning_rate=0.001,
+            parameters=self.model.parameters(),
+            grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),
+        )
+        dataset = RandomDataset(image_size=8, output_size=8, num_samples=8)
+        loader = DataLoader(dataset, batch_size=8)
+        losses_by_step = []
+        num_iterations = 20
+
+        for iter_idx in range(num_iterations):
+            losses_by_micro_batch = []
+            for i, (data, label) in enumerate(loader):
+                schedule.step(data, target=label, losses=losses_by_micro_batch)
+            if self.rank == 3:
+                losses_by_step.append(
+                    np.array(losses_by_micro_batch, dtype=np.float32).mean()
+                )
+            opt.step()
+            opt.clear_grad()
+        return losses_by_step
+
     def test_dp_pp_align_mode(self):
         fix_seeds()
         paddle.set_flags({'FLAGS_enable_auto_parallel_align_mode': True})
@@ -490,6 +551,12 @@ def run_test(self):
         scheduleFThenB_losses = self.test_ScheduleFThenB()
         schedule1f1b_losses = self.test_Schedule1F1B()
         schedulevpp_losses = self.test_ScheduleVPP()
+        pp_model_with_ClipGradByGlobalNorm_losses = (
+            self.test_pp_model_with_ClipGradByGlobalNorm()
+        )
+        scheduleFThenB_with_ClipGradByGlobalNorm_losses = (
+            self.test_ScheduleFThenB_with_ClipGradByGlobalNorm()
+        )
         dp_pp_losses, dp_pp_losses_md5sum = self.test_dp_pp()
         dp_pp_align_mode_losses, dp_pp_align_mode_losses_md5sum = (
             self.test_dp_pp_align_mode()
@@ -520,6 +587,12 @@ def run_test(self):
             rtol=1e-5,
         )
 
+        np.testing.assert_allclose(
+            pp_model_with_ClipGradByGlobalNorm_losses,
+            scheduleFThenB_with_ClipGradByGlobalNorm_losses,
+            rtol=1e-5,
+        )
+
         np.testing.assert_allclose(
             dp_pp_align_mode_losses,
             dp_pp_losses,