
Commit 59d8b8c

[HybridParallel]fix bug of check_inf in fleet_base.py (#36651)
* fix bug of check_inf
* fix allreduce
1 parent 50778ad commit 59d8b8c

File tree: 2 files changed (+5, -6 lines)


python/paddle/distributed/fleet/base/fleet_base.py

Lines changed: 4 additions & 4 deletions
@@ -1586,16 +1586,16 @@ def unscale_method(self, optimizer):
             _C_ops.check_finite_and_unscale(param_grads_fp32, self._scale,
                                             param_grads_fp32,
                                             temp_found_inf_fp32)
+
         self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
+        is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")
 
         # TODO(shenliang03) Since dp allreduce in the optimizer is
         # after the gradscaler, check_finite needs to synchronize global
         # information. In the future, we should use check_group to speed.
         paddle.distributed.all_reduce(
-            paddle.to_tensor(
-                [self._found_inf], dtype="int32"),
-            op=paddle.distributed.ReduceOp.MAX,
-            group=None)
+            is_found_inf, op=paddle.distributed.ReduceOp.MAX, group=None)
+        self._found_inf = is_found_inf.numpy()[0]
 
         # Only tensor_parallel and pipeline_parallel need to modify scaler
         if self._hcg.get_parallel_mode() in (ParallelMode.TENSOR_PARALLEL,
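
Why this fixes the check_inf bug: the old code passed a freshly built paddle.to_tensor([...]) straight into all_reduce, so the globally reduced flag only ever lived in that temporary and self._found_inf kept its purely local value; the new code keeps the tensor in is_found_inf, reduces it in place with ReduceOp.MAX, and writes the result back. A minimal sketch of that pattern follows (not the Fleet code itself; the launcher setup and the fake local flag are assumptions):

import paddle
import paddle.distributed as dist

# Assumes the script is started with `python -m paddle.distributed.launch`
# (or an equivalent launcher) so a process group can be initialized.
dist.init_parallel_env()

# Pretend only rank 0 saw an inf/nan while unscaling (illustrative value).
local_found_inf = 1 if dist.get_rank() == 0 else 0

is_found_inf = paddle.to_tensor([local_found_inf], dtype="int32")
# MAX across ranks: if any rank saw inf/nan, every rank ends up with 1.
dist.all_reduce(is_found_inf, op=dist.ReduceOp.MAX, group=None)
# Read the reduced value back, as the patched unscale_method now does.
found_inf = is_found_inf.numpy()[0]
print("rank", dist.get_rank(), "global found_inf:", found_inf)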

python/paddle/distributed/fleet/utils/hybrid_parallel_util.py

Lines changed: 1 addition & 2 deletions
@@ -47,15 +47,14 @@ def _apply_collective_grads(parameters, comm_group):
         nranks = paddle.distributed.get_world_size(
         ) if comm_group is None else comm_group.nranks
         div_factor = paddle.to_tensor(nranks, dtype=coalesced_grad.dtype)
+        paddle.distributed.all_reduce(coalesced_grad, group=comm_group)
         paddle.fluid.framework._dygraph_tracer().trace_op(
             type="elementwise_div",
             inputs={'X': coalesced_grad,
                     'Y': div_factor},
             outputs={'Out': coalesced_grad},
             attrs={'axis': -1})
 
-        paddle.distributed.all_reduce(coalesced_grad, group=comm_group)
-
     _split_tensors(coalesced_grads_and_vars)
 
 
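
The second change reorders _apply_collective_grads so the coalesced gradient is all-reduced before the elementwise_div by nranks instead of after it: the buffer is summed across ranks first and then scaled down to the data-parallel average. A minimal sketch of that allreduce-then-divide order (illustrative only; the stand-in gradient buffer and launcher setup are assumptions):

import paddle
import paddle.distributed as dist

dist.init_parallel_env()  # assumes a paddle.distributed.launch-style start

# Stand-in for a coalesced (flattened) gradient buffer; each rank holds a
# different value so the averaging effect is visible.
coalesced_grad = paddle.full([8], float(dist.get_rank() + 1), dtype="float32")

# Sum the buffer across all ranks first ...
dist.all_reduce(coalesced_grad, group=None)
# ... then divide by the number of ranks to get the averaged gradient.
nranks = dist.get_world_size()
coalesced_grad = coalesced_grad / nranks

print("rank", dist.get_rank(), "averaged value:", coalesced_grad.numpy()[0])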

0 commit comments