|
1583 | 1583 | backward: fused_feedforward_grad |
1584 | 1584 |
|
1585 | 1585 | - op: lars_momentum |
1586 | | - args: (Tensor param, Tensor velocity, Tensor grad, Tensor learning_rate, Tensor master_param, float mu, float lars_coeff=0.001f, float[] lars_weight_decay={0.0005}, float epsilon=0, bool multi_precision=false, float rescale_grad=1.0f) |
1587 | | - output: Tensor(param_out), Tensor(velocity_out), Tensor(master_param_out) |
| 1586 | + args: (Tensor[] param, Tensor[] grad, Tensor[] velocity, Tensor[] learning_rate, Tensor[] master_param, float mu, float lars_coeff=0.001f, float[] lars_weight_decay={0.0005f}, float epsilon=0.0f, bool multi_precision=false, float rescale_grad=1.0f) |
| 1587 | + output: Tensor[](param_out){param.size()}, Tensor[](velocity_out){param.size()}, Tensor[](master_param_out){param.size()} |
1588 | 1588 | infer_meta: |
1589 | | - func: SparseMomentumInferMeta |
1590 | | - param: [param, learning_rate, velocity] |
| 1589 | + func: LarsMomentumInferMeta |
| 1590 | + param: [param, velocity, learning_rate, grad, master_param, lars_weight_decay, mu, lars_coeff, epsilon, multi_precision, rescale_grad] |
1591 | 1591 | kernel: |
1592 | 1592 | func: lars_momentum |
1593 | 1593 | param: [param, velocity, learning_rate, grad, master_param, lars_weight_decay, mu, lars_coeff, epsilon, multi_precision, rescale_grad] |
1594 | 1594 | data_type: param |
1595 | 1595 | optional: master_param, master_param_out |
| 1596 | + inplace : master_param -> master_param_out |
1596 | 1597 |
|
1597 | 1598 | - op: match_matrix_tensor |
1598 | 1599 | args: (Tensor x, Tensor y, Tensor w, int dim_t=1) |
|
0 commit comments