Skip to content

Commit 32ca78a

Browse files
authored
【PIR Dist Op Reg No.18】 reg lars_momentum (#60838)
* fix
* fix
* fix
* fix
* fix
* add f
* fix
* change sequence
* add lars_momentum_
* fix
* fix
1 parent 7305dae commit 32ca78a

3 files changed

Lines changed: 9 additions & 7 deletions

File tree

paddle/fluid/operators/optimizers/lars_momentum_op.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,14 @@ class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
7171
.AsDispensable();
7272
AddAttr<float>("mu", "(float) Momentum coefficient");
7373
AddAttr<float>("lars_coeff", "(float, default 0.001) LARS coefficient.")
74-
.SetDefault(0.001);
74+
.SetDefault(0.001f);
7575
AddAttr<std::vector<float>>(
7676
"lars_weight_decay",
7777
"(std::vector<float>, default 0.0005) LARS weight decay params")
78-
.SetDefault({0.0005});
78+
.SetDefault({0.0005f});
7979
AddAttr<float>("epsilon",
8080
"(float, default 0.0) epsilon to avoid Division by Zero.")
81-
.SetDefault(0.0);
81+
.SetDefault(0.0f);
8282
AddAttr<bool>("multi_precision",
8383
"(bool, default false) "
8484
"Whether to use multi-precision during weight updating.")

paddle/fluid/pir/dialect/op_generator/ops_api_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@
137137
'fused_dot_product_attention',
138138
'nce',
139139
'lars_momentum',
140+
'lars_momentum_',
140141
'max_pool2d_v2',
141142
'recv_v2',
142143
'rnn_',

paddle/fluid/pir/dialect/operator/ir/ops.yaml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1583,16 +1583,17 @@
15831583
backward: fused_feedforward_grad
15841584

15851585
- op: lars_momentum
1586-
args: (Tensor param, Tensor velocity, Tensor grad, Tensor learning_rate, Tensor master_param, float mu, float lars_coeff=0.001f, float[] lars_weight_decay={0.0005}, float epsilon=0, bool multi_precision=false, float rescale_grad=1.0f)
1587-
output: Tensor(param_out), Tensor(velocity_out), Tensor(master_param_out)
1586+
args: (Tensor[] param, Tensor[] grad, Tensor[] velocity, Tensor[] learning_rate, Tensor[] master_param, float mu, float lars_coeff=0.001f, float[] lars_weight_decay={0.0005f}, float epsilon=0.0f, bool multi_precision=false, float rescale_grad=1.0f)
1587+
output: Tensor[](param_out){param.size()}, Tensor[](velocity_out){param.size()}, Tensor[](master_param_out){param.size()}
15881588
infer_meta:
1589-
func: SparseMomentumInferMeta
1590-
param: [param, learning_rate, velocity]
1589+
func: LarsMomentumInferMeta
1590+
param: [param, velocity, learning_rate, grad, master_param, lars_weight_decay, mu, lars_coeff, epsilon, multi_precision, rescale_grad]
15911591
kernel:
15921592
func: lars_momentum
15931593
param: [param, velocity, learning_rate, grad, master_param, lars_weight_decay, mu, lars_coeff, epsilon, multi_precision, rescale_grad]
15941594
data_type: param
15951595
optional: master_param, master_param_out
1596+
inplace : master_param -> master_param_out
15961597

15971598
- op: match_matrix_tensor
15981599
args: (Tensor x, Tensor y, Tensor w, int dim_t=1)

0 commit comments

Comments (0)