@@ -13,55 +13,105 @@ See the License for the specific language governing permissions and
1313limitations under the License. */
1414
1515#include " paddle/fluid/framework/op_registry.h"
16+ #include " paddle/fluid/operators/amp/fp16_type_traits.h"
1617#include " paddle/fluid/operators/optimizers/lars_momentum_op.h"
1718
1819namespace paddle {
1920namespace operators {
2021
2122template <typename T>
22- __global__ void MomentumLarsKernel (const T* p, const T* g, const T* v,
23- const T* learning_rate, const T mu,
24- const int64_t num, const T lars_coeff,
25- const T lars_weight_decay, const T* p_norm,
26- const T* g_norm, T* p_out, T* v_out,
27- const T epsilon) {
28- T lr = learning_rate[0 ];
29- T local_lr = learning_rate[0 ];
23+ using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;
24+
25+ template <typename T, typename MT>
26+ __global__ void MomentumLarsKernel (
27+ const T* p, const T* g, const MT* v,
28+ const MultiPrecisionType<T>* learning_rate, const MT mu, const int64_t num,
29+ const MT lars_coeff, const MT lars_weight_decay,
30+ const MultiPrecisionType<T>* p_norm, const MultiPrecisionType<T>* g_norm,
31+ T* p_out, MT* v_out, const MT epsilon, const MT* master_p, MT* master_p_out,
32+ const MultiPrecisionType<T> rescale_grad) {
33+ const MT lr = static_cast <MT>(learning_rate[0 ]);
34+ MT local_lr = lr;
35+ const MT p_n = static_cast <MT>(p_norm[0 ]);
36+ const MT g_n = static_cast <MT>(g_norm[0 ]);
37+
38+ if (lars_weight_decay > static_cast <MT>(0 ) && p_n > static_cast <MT>(0 ) &&
39+ g_n > static_cast <MT>(0 )) {
40+ local_lr =
41+ lr * lars_coeff * p_n / (g_n + lars_weight_decay * p_n + epsilon);
42+ }
3043 CUDA_KERNEL_LOOP (i, num) {
31- if (lars_weight_decay > 0 && p_norm[0 ] > 0 && g_norm[0 ] > 0 ) {
32- local_lr = lr * lars_coeff * p_norm[0 ] /
33- (g_norm[0 ] + lars_weight_decay * p_norm[0 ] + epsilon);
34- }
44+ MT grad = static_cast <MT>(g[i]) * static_cast <MT>(rescale_grad);
45+ MT param = master_p ? master_p[i] : static_cast <MT>(p[i]);
46+
47+ MT v_new = v[i] * mu + local_lr * (grad + lars_weight_decay * param);
48+ MT p_new = param - v_new;
3549
36- T v_new = v[i] * mu + local_lr * (g[i] + lars_weight_decay * p[i]);
3750 v_out[i] = v_new;
38- p_out[i] = p[i] - v_new;
51+ p_out[i] = static_cast <T>(p_new);
52+ if (master_p_out) master_p_out[i] = p_new;
3953 }
4054}
4155
4256template <typename DeviceContext, typename T>
4357class LarsMomentumOpCUDAKernel : public framework ::OpKernel<T> {
58+ using MPDType = MultiPrecisionType<T>;
59+
4460 public:
4561 void Compute (const framework::ExecutionContext& ctx) const override {
62+ const bool multi_precision = ctx.Attr <bool >(" multi_precision" );
63+ if (multi_precision) {
64+ InnerCompute<MPDType>(ctx, multi_precision);
65+ } else {
66+ InnerCompute<T>(ctx, multi_precision);
67+ }
68+ }
69+
70+ private:
71+ template <typename MT>
72+ void InnerCompute (const framework::ExecutionContext& ctx,
73+ const bool multi_precision) const {
4674 auto param_out = ctx.Output <framework::LoDTensor>(" ParamOut" );
4775 auto velocity_out = ctx.Output <framework::LoDTensor>(" VelocityOut" );
4876 auto param = ctx.Input <framework::LoDTensor>(" Param" );
4977 auto velocity = ctx.Input <framework::LoDTensor>(" Velocity" );
5078 auto grad = ctx.Input <framework::LoDTensor>(" Grad" );
5179 auto learning_rate = ctx.Input <framework::LoDTensor>(" LearningRate" );
5280
81+ const framework::Tensor* master_param = nullptr ;
82+ framework::Tensor* master_param_out = nullptr ;
83+ if (multi_precision) {
84+ bool has_master =
85+ ctx.HasInput (" MasterParam" ) && ctx.HasOutput (" MasterParamOut" );
86+ PADDLE_ENFORCE_EQ (has_master, true ,
87+ platform::errors::InvalidArgument (
88+ " The Input(MasterParam) and Output(MasterParamOut) "
89+ " should not be null when "
90+ " the attr `multi_precision` is true" ));
91+ master_param = ctx.Input <framework::Tensor>(" MasterParam" );
92+ master_param_out = ctx.Output <framework::Tensor>(" MasterParamOut" );
93+ }
94+
95+ const MT* master_p = multi_precision ? master_param->data <MT>() : nullptr ;
96+ MT* master_p_out = multi_precision
97+ ? master_param_out->mutable_data <MT>(ctx.GetPlace ())
98+ : nullptr ;
99+
53100 T* p_out = param_out->mutable_data <T>(ctx.GetPlace ());
54- T * v_out = velocity_out->mutable_data <T >(ctx.GetPlace ());
101+ MT * v_out = velocity_out->mutable_data <MT >(ctx.GetPlace ());
55102
56- T mu = static_cast <T>(ctx.Attr <float >(" mu" ));
57- T lars_coeff = ctx.Attr <float >(" lars_coeff" );
58- T lars_weight_decay = ctx.Attr <float >(" lars_weight_decay" );
59- T epsilon = ctx.Attr <float >(" epsilon" );
103+ MT mu = static_cast <MT>(ctx.Attr <float >(" mu" ));
104+ MT lars_coeff = static_cast <MT>(ctx.Attr <float >(" lars_coeff" ));
105+ MT lars_weight_decay =
106+ static_cast <MT>(ctx.Attr <float >(" lars_weight_decay" ));
107+ MT epsilon = static_cast <MT>(ctx.Attr <float >(" epsilon" ));
108+ MPDType rescale_grad =
109+ static_cast <MPDType>(ctx.Attr <float >(" rescale_grad" ));
60110
61111 auto * p = param->data <T>();
62- auto * v = velocity->data <T>();
63112 auto * g = grad->data <T>();
64- auto * lr = learning_rate->data <T>();
113+ auto * v = velocity->data <MT>();
114+ auto * lr = learning_rate->data <MPDType>();
65115
66116 int block = 512 ;
67117 int grid = (param->numel () + block - 1 ) / block;
@@ -72,17 +122,24 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
72122 framework::Tensor p_norm_t , g_norm_t ;
73123 p_norm_t .Resize ({1 });
74124 g_norm_t .Resize ({1 });
75- auto * p_norm_data = p_norm_t .mutable_data <T >(ctx.GetPlace ());
76- auto * g_norm_data = g_norm_t .mutable_data <T >(ctx.GetPlace ());
77- auto ep_norm = framework::EigenScalar<T >::From (p_norm_t );
78- auto eg_norm = framework::EigenScalar<T >::From (g_norm_t );
125+ auto * p_norm_data = p_norm_t .mutable_data <MPDType >(ctx.GetPlace ());
126+ auto * g_norm_data = g_norm_t .mutable_data <MPDType >(ctx.GetPlace ());
127+ auto ep_norm = framework::EigenScalar<MPDType >::From (p_norm_t );
128+ auto eg_norm = framework::EigenScalar<MPDType >::From (g_norm_t );
79129
80130 auto * place = ctx.template device_context <DeviceContext>().eigen_device ();
81- ep_norm.device (*place) = eigen_p.square ().sum ().sqrt ();
82- eg_norm.device (*place) = eigen_g.square ().sum ().sqrt ();
83- MomentumLarsKernel<<<grid, block, 0 , ctx.cuda_device_context().stream()>>> (
131+
132+ // eigen unsupport fp16 l2-norm
133+ ep_norm.device (*place) =
134+ eigen_p.template cast <MPDType>().square ().sum ().sqrt ();
135+ eg_norm.device (*place) =
136+ (eigen_g.template cast <MPDType>() * rescale_grad).square ().sum ().sqrt ();
137+
138+ MomentumLarsKernel<
139+ T, MT><<<grid, block, 0 , ctx.cuda_device_context().stream()>>> (
84140 p, g, v, lr, mu, param->numel (), lars_coeff, lars_weight_decay,
85- p_norm_data, g_norm_data, p_out, v_out, epsilon);
141+ p_norm_data, g_norm_data, p_out, v_out, epsilon, master_p, master_p_out,
142+ rescale_grad);
86143 }
87144};
88145
@@ -93,4 +150,6 @@ namespace ops = paddle::operators;
93150REGISTER_OP_CUDA_KERNEL (
94151 lars_momentum,
95152 ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext, float >,
96- ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext, double >);
153+ ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext, double >,
154+ ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext,
155+ paddle::platform::float16>);
0 commit comments