@@ -4654,18 +4654,22 @@ def _add_op_device_attr_for_op(self, op, idx, block):
46544654 op .type == 'elementwise_div' ):
46554655 device = f"{ self ._device } :all"
46564656 op ._set_attr (self ._op_device_key , device )
4657- elif self ._is_weight_decay_op (op ) and op .type == 'scale' :
4658- # set AdamW decay_coeff to device:all
4659- op ._set_attr (self ._op_device_key , f"{ self ._device } :all" )
46604657 elif op .type == "alloc_float_status" or op .type == "clear_float_status" :
46614658 op ._set_attr (self ._op_device_key , f"{ self ._device } :all" )
46624659 # NOTE(wangxi): NPU should only clear the float status
46634660 # once at each batch step
46644661 op ._set_attr (self ._op_role_key , self ._op_role .LRSched )
4662+
4663+ float_status_name = op .output_arg_names [0 ]
4664+ float_status_var = block .var (float_status_name )
4665+ # FIXME(wangxi): pipeline lr schedule will exec on sub_scope(0)
4666+ # while update will exec on sub_scope(last_micro_step), should
4667+ # set persistable to use global scope
4668+ float_status_var .persistable = True
46654669 else :
46664670 other_known_ops = [
46674671 'update_loss_scaling' , 'reduce_any' , 'concat' , 'sum' ,
4668- 'check_finite_and_unscale' , 'alloc_float_status' , ' memcpy'
4672+ 'check_finite_and_unscale' , 'memcpy'
46694673 ]
46704674 assert op .type in other_known_ops , "For other ops without " \
46714675 "op_device set, they must be one of {}, but it " \
0 commit comments