Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions paddle/fluid/operators/optimizers/lars_momentum_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,14 @@ class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
.AsDispensable();
AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<float>("lars_coeff", "(float, default 0.001) LARS coefficient.")
.SetDefault(0.001);
.SetDefault(0.001f);
AddAttr<std::vector<float>>(
"lars_weight_decay",
"(std::vector<float>, default 0.0005) LARS weight decay params")
.SetDefault({0.0005});
.SetDefault({0.0005f});
AddAttr<float>("epsilon",
"(float, default 0.0) epsilon to avoid Division by Zero.")
.SetDefault(0.0);
.SetDefault(0.0f);
AddAttr<bool>("multi_precision",
"(bool, default false) "
"Whether to use multi-precision during weight updating.")
Expand Down
9 changes: 4 additions & 5 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1503,16 +1503,15 @@
backward: fused_feedforward_grad

- op: lars_momentum
args: (Tensor param, Tensor velocity, Tensor grad, Tensor learning_rate, Tensor master_param, float mu, float lars_coeff=0.001f, float[] lars_weight_decay={0.0005}, float epsilon=0, bool multi_precision=false, float rescale_grad=1.0f)
output: Tensor(param_out), Tensor(velocity_out), Tensor(master_param_out)
args: (Tensor[] param, Tensor[] velocity, Tensor[] learning_rate, Tensor[] grad, Tensor[] master_param, float[] lars_weight_decay={0.0005f}, float mu, float lars_coeff=0.001f, float epsilon=0.0f, bool multi_precision=false, float rescale_grad=1.0f)
output: Tensor[](param_out){param.size()}, Tensor[](velocity_out){param.size()}, Tensor[](master_param_out){param.size()}
infer_meta:
func: SparseMomentumInferMeta
param: [param, learning_rate, velocity]
func: LarsMomentumInferMeta
kernel:
func: lars_momentum
param: [param, velocity, learning_rate, grad, master_param, lars_weight_decay, mu, lars_coeff, epsilon, multi_precision, rescale_grad]
data_type: param
optional: master_param, master_param_out
inplace : master_param -> master_param_out

- op: nce
args: (Tensor input, Tensor label, Tensor weight, Tensor bias, Tensor sample_weight, Tensor custom_dist_probs, Tensor custom_dist_alias, Tensor custom_dist_alias_probs, int num_total_classes, int[] custom_neg_classes={}, int num_neg_samples=10, int sampler=0, int seed=0, bool is_sparse=false, bool remote_prefetch=false, bool is_test=false)
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/pir/dialect/operator/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ const std::unordered_set<std::string> LegacyOpList = {
paddle::onednn::dialect::LrnGradOp::name(),
#endif
CReduceMinOp::name(),
PushSparseV2Op::name()};
PushSparseV2Op::name(),
LarsMomentumOp::name()};

enum class AttrType {
UNDEFINED = 0,
Expand Down