
Commit 45a2af0

fix pipeline float status uninitialized
1 parent 8c26c38 commit 45a2af0

1 file changed: +8 −4 lines


python/paddle/fluid/optimizer.py

Lines changed: 8 additions & 4 deletions
@@ -4654,18 +4654,22 @@ def _add_op_device_attr_for_op(self, op, idx, block):
               op.type == 'elementwise_div'):
             device = f"{self._device}:all"
             op._set_attr(self._op_device_key, device)
-        elif self._is_weight_decay_op(op) and op.type == 'scale':
-            # set AdamW decay_coeff to device:all
-            op._set_attr(self._op_device_key, f"{self._device}:all")
         elif op.type == "alloc_float_status" or op.type == "clear_float_status":
             op._set_attr(self._op_device_key, f"{self._device}:all")
             # NOTE(wangxi): NPU should only clear the float status
             # once at each batch step
             op._set_attr(self._op_role_key, self._op_role.LRSched)
+
+            float_status_name = op.output_arg_names[0]
+            float_status_var = block.var(float_status_name)
+            # FIXME(wangxi): pipeline lr schedule will exec on sub_scope(0)
+            # while update will exec on sub_scope(last_micro_step), should
+            # set persistable to use global scope
+            float_status_var.persistable = True
         else:
             other_known_ops = [
                 'update_loss_scaling', 'reduce_any', 'concat', 'sum',
-                'check_finite_and_unscale', 'alloc_float_status', 'memcpy'
+                'check_finite_and_unscale', 'memcpy'
             ]
             assert op.type in other_known_ops, "For other ops without " \
                 "op_device set, they must be one of {}, but it " \
