33 changes: 33 additions & 0 deletions python/paddle/nn/clip.py
@@ -717,6 +717,7 @@ def _dygraph_clip(self, params_grads):
sum_square_list = []
sum_square_list_fp16 = []
sum_square_list_fp32 = []
flag_auto_hybrid_pp = True  # whether the new dygraph semi-auto parallel PP framework is in use
if len(params_grads) > 0 and len(params_grads[0]) > 0:
    src_mesh = params_grads[0][0].process_mesh
else:
@@ -742,6 +743,7 @@ def _dygraph_clip(self, params_grads):
# if the gradient mesh is not equal to src mesh
# do reshard to get the result of squared_l2 from other pp stage mesh
if src_mesh is not None and g.process_mesh != src_mesh:
    flag_auto_hybrid_pp = False
    pp_mesh = get_complete_pp_mesh(g.process_mesh)
    if set(g.process_mesh.process_ids) < set(pp_mesh.process_ids):
        sum_square = dist.reshard(
@@ -791,6 +793,37 @@ def async_add_n(var_list):

global_norm_var = async_add_n(global_norm_var)

# NOTE(zhengtianyu): Fix grad_clip in auto_hybrid_pp mode.
# Reason: in auto_hybrid_pp mode, each rank only keeps its local parameters and gradients,
# so global_norm_var is in a partial state and the global norm would be computed incorrectly.
# Following dynamic manual parallelism: each rank computes its local global_norm_var,
# then performs a reduce(sum) over the pp communication group to obtain the correct global_norm_var.
# For complete alignment with the old dygraph semi-auto parallel PP logic,
# refer to NOTE: align ClipGradByGlobalNorm in auto_parallel_align_mode
if flag_auto_hybrid_pp and src_mesh is not None:
    g_mesh = dist.get_mesh()
    if (
        g_mesh
        and "pp" in g_mesh.dim_names
        and g_mesh.get_dim_size("pp") > 1
    ):
        # Get the pipeline parallelism subgroup for communication
        pp_group = g_mesh.get_submesh_with_dim("pp").get_group("pp")

        # Perform all-reduce on the local tensor value across the PP group
        global_norm_var_local = global_norm_var._local_value()
        dist.all_reduce(
            global_norm_var_local,
            op=dist.ReduceOp.SUM,
            group=pp_group,
        )

        global_norm_var = dist.shard_tensor(
            global_norm_var_local,
            global_norm_var.process_mesh,
            global_norm_var.placements,
        )

@xuxinyi389 (Contributor) commented on Aug 6, 2025:
The is_pp_enable logic feels a bit verbose; it could be simplified to the code below:

# Check for auto hybrid pipeline parallelism and source mesh existence
if flag_auto_hybrid_pp and src_mesh is not None:
    g_mesh = dist.get_mesh()
    
    # Check if mesh exists and pipeline parallelism is enabled ("pp" dim size > 1)
    if g_mesh and "pp" in g_mesh.dim_names and g_mesh.get_dim_size("pp") > 1:
        
        # Get the pipeline parallelism subgroup for communication
        pp_group = g_mesh.get_submesh_with_dim("pp").get_group("pp")
        
        # Perform all-reduce on the local tensor value across the PP group
        global_norm_var_local = global_norm_var._local_value()
        dist.all_reduce(
            global_norm_var_local,
            op=dist.ReduceOp.SUM,
            group=pp_group,
        )
        
        # Re-shard the tensor with the reduced value
        global_norm_var = dist.shard_tensor(
            global_norm_var_local,
            global_norm_var.process_mesh,
            global_norm_var.placements,
        )

if self.should_comm_on_shard_dim and hasattr(self, 'sharding_group'):
    paddle.distributed.all_reduce(
        global_norm_var._local_value(), group=self.sharding_group
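
For intuition, here is a minimal single-process sketch of the arithmetic described in the NOTE above; the per-stage values are made up and no distributed runtime or Paddle API is involved. Each PP stage only owns the squared-gradient sum of its local parameters, so its local norm is partial; summing the partials across stages (what the all_reduce with ReduceOp.SUM does) recovers the true global norm.

import math

# Hypothetical squared-gradient sums, one per pipeline stage.
# Each PP rank only owns its local parameters, so each value is a partial sum.
partial_sum_sq_per_stage = [4.0, 9.0, 1.0, 2.0]

# Without the PP all-reduce, every stage would report only its own (wrong) norm.
local_norms = [math.sqrt(s) for s in partial_sum_sq_per_stage]  # [2.0, 3.0, 1.0, 1.41...]

# With the fix: reduce(sum) of the partials across the pp group, then a single sqrt.
global_norm = math.sqrt(sum(partial_sum_sq_per_stage))          # sqrt(16.0) == 4.0

Every stage then derives the same clipping coefficient from this shared norm, which is what the new tests below verify by comparing losses against a single-process baseline.
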
73 changes: 73 additions & 0 deletions test/auto_parallel/PP_Schedules_demo.py
@@ -414,6 +414,67 @@ def test_dp_pp(self):
        opt.clear_grad()
    return losses_by_step, all_losses_in_one_step_md5sum

def test_pp_model_with_ClipGradByGlobalNorm(self):
    """Test pipeline parallel model with ClipGradByGlobalNorm using PPMyModel as the baseline"""
    fix_seeds()
    pp_model = PPMyModel()
    opt = paddle.optimizer.AdamW(
        learning_rate=0.001,
        parameters=pp_model.parameters(),
        grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),
    )
    loss_fn = nn.MSELoss()
    dataset = RandomDataset(image_size=8, output_size=8, num_samples=8)
    loader = DataLoader(dataset, batch_size=1)
    pp_losses_step = []
    num_iterations = 20

    for iter_idx in range(num_iterations):
        pp_losses_micro_batch = []
        for i, (data, label) in enumerate(loader):
            output = pp_model(data)
            loss = loss_fn(output, label)
            pp_losses_micro_batch.append(loss.item())
            loss.backward()
        pp_losses_step.append(
            np.array(pp_losses_micro_batch, dtype=np.float32).mean()
        )
        opt.step()
        opt.clear_grad()
    return pp_losses_step

def test_ScheduleFThenB_with_ClipGradByGlobalNorm(self):
    fix_seeds()
    self.model = PPMyModel_SingleStage()
    self.micro_batches = 8
    self.stage = PipelineStage(self.model, self.rank, 4, group=self.group)
    self.stage.has_backward = True
    loss_fn_ = nn.MSELoss()
    schedule = ScheduleFThenB(
        self.stage, self.micro_batches, loss_fn=loss_fn_
    )
    opt = paddle.optimizer.AdamW(
        learning_rate=0.001,
        parameters=self.model.parameters(),
        grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),
    )
    dataset = RandomDataset(image_size=8, output_size=8, num_samples=8)
    loader = DataLoader(dataset, batch_size=8)
    losses_by_step = []
    num_iterations = 20

    for iter_idx in range(num_iterations):
        losses_by_micro_batch = []
        for i, (data, label) in enumerate(loader):
            schedule.step(data, target=label, losses=losses_by_micro_batch)
        if self.rank == 3:
            losses_by_step.append(
                np.array(losses_by_micro_batch, dtype=np.float32).mean()
            )
        opt.step()
        opt.clear_grad()
    return losses_by_step

def test_dp_pp_align_mode(self):
    fix_seeds()
    paddle.set_flags({'FLAGS_enable_auto_parallel_align_mode': True})
@@ -490,6 +551,12 @@ def run_test(self):
scheduleFThenB_losses = self.test_ScheduleFThenB()
schedule1f1b_losses = self.test_Schedule1F1B()
schedulevpp_losses = self.test_ScheduleVPP()
pp_model_with_ClipGradByGlobalNorm_losses = (
    self.test_pp_model_with_ClipGradByGlobalNorm()
)
scheduleFThenB_with_ClipGradByGlobalNorm_losses = (
    self.test_ScheduleFThenB_with_ClipGradByGlobalNorm()
)
dp_pp_losses, dp_pp_losses_md5sum = self.test_dp_pp()
dp_pp_align_mode_losses, dp_pp_align_mode_losses_md5sum = (
    self.test_dp_pp_align_mode()
@@ -520,6 +587,12 @@ def run_test(self):
    rtol=1e-5,
)

np.testing.assert_allclose(
    pp_model_with_ClipGradByGlobalNorm_losses,
    scheduleFThenB_with_ClipGradByGlobalNorm_losses,
    rtol=1e-5,
)

np.testing.assert_allclose(
    dp_pp_align_mode_losses,
    dp_pp_losses,