Skip to content

Commit e32d596

Browse files
committed
fix npu found_finite in hybrid
1 parent 881e55e commit e32d596

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -371,8 +371,11 @@ def _adapt_amp_clip_without_sharding(self):
         # FIXME(wangxi): mp should prune duplicated param_grads when calc
         # amp inf_var & clip global_norm_var

-        FP16Utils.sync_amp_check_nan_inf(main_block,
-                                         [self.mp_ring_id, self.pp_ring_id])
+        rings = [self.mp_ring_id, self.pp_ring_id]
+        # FIXME(wangxi): some problem with NPU found_finite, need sync with DP
+        if core.is_compiled_with_npu():
+            rings += [self.dp_ring_id]
+        FP16Utils.sync_amp_check_nan_inf(main_block, rings)

         gradientclip_helper = GradientClipHelper(None)
         gradientclip_helper.sync_global_norm(

0 commit comments

Comments
 (0)