Skip to content

Commit 8c26c38

Browse files
committed
fix npu clear float status in pipeline
1 parent 537cee9 commit 8c26c38

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

python/paddle/fluid/optimizer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4659,6 +4659,9 @@ def _add_op_device_attr_for_op(self, op, idx, block):
46594659
op._set_attr(self._op_device_key, f"{self._device}:all")
46604660
elif op.type == "alloc_float_status" or op.type == "clear_float_status":
46614661
op._set_attr(self._op_device_key, f"{self._device}:all")
4662+
# NOTE(wangxi): NPU should only clear the float status
4663+
# once at each batch step
4664+
op._set_attr(self._op_role_key, self._op_role.LRSched)
46624665
else:
46634666
other_known_ops = [
46644667
'update_loss_scaling', 'reduce_any', 'concat', 'sum',

0 commit comments

Comments
 (0)