Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions python/paddle/fluid/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4654,15 +4654,22 @@ def _add_op_device_attr_for_op(self, op, idx, block):
op.type == 'elementwise_div'):
device = f"{self._device}:all"
op._set_attr(self._op_device_key, device)
elif self._is_weight_decay_op(op) and op.type == 'scale':
# set AdamW decay_coeff to device:all
op._set_attr(self._op_device_key, f"{self._device}:all")
elif op.type == "alloc_float_status" or op.type == "clear_float_status":
op._set_attr(self._op_device_key, f"{self._device}:all")
# NOTE(wangxi): NPU should only clear the float status
# once at each batch step
op._set_attr(self._op_role_key, self._op_role.LRSched)

float_status_name = op.output_arg_names[0]
float_status_var = block.var(float_status_name)
# FIXME(wangxi): pipeline lr schedule will exec on sub_scope(0)
# while update will exec on sub_scope(last_micro_step), should
# set persistable to use global scope
float_status_var.persistable = True
else:
other_known_ops = [
'update_loss_scaling', 'reduce_any', 'concat', 'sum',
'check_finite_and_unscale', 'alloc_float_status', 'memcpy'
'check_finite_and_unscale', 'memcpy'
]
assert op.type in other_known_ops, "For other ops without " \
"op_device set, they must be one of {}, but it " \
Expand Down