@@ -21,10 +21,36 @@ namespace operators {
2121using Tensor = framework::Tensor;
2222using DDim = framework::DDim;
2323
24+ using DataLayout = framework::DataLayout;
25+
26+ template <typename T>
27+ class NormDataType ;
28+
29+ template <>
30+ class NormDataType <platform::float16> {
31+ public:
32+ // The scaling param type is float for HALF and FLOAT tensors
33+ using ScalingParamType = const float ;
34+ using BatchNormParamType = float ;
35+ };
36+
37+ template <>
38+ class NormDataType <float > {
39+ public:
40+ using ScalingParamType = const float ;
41+ using BatchNormParamType = float ;
42+ };
43+
44+ template <typename T>
45+ using NormDataType = NormDataType<T>;
46+ template <typename T>
47+ using LayerNormParamType = typename NormDataType<T>::BatchNormParamType;
48+
2449template <typename T>
2550class LayerNormNPUKernel : public framework ::OpKernel<T> {
2651 public:
2752 void Compute (const framework::ExecutionContext& ctx) const override {
53+ using U = LayerNormParamType<T>;
2854 const auto begin_norm_axis = ctx.Attr <int >(" begin_norm_axis" );
2955 const auto epsilon = ctx.Attr <float >(" epsilon" );
3056 const auto * x = ctx.Input <Tensor>(" X" );
@@ -43,6 +69,7 @@ class LayerNormNPUKernel : public framework::OpKernel<T> {
4369 for (auto i = begin_norm_axis; i < x_dims.size (); ++i) {
4470 axes.push_back (x_dims[i]);
4571 }
72+
4673 auto place = ctx.GetPlace ();
4774 auto stream =
4875 ctx.template device_context <paddle::platform::NPUDeviceContext>()
@@ -77,16 +104,93 @@ class LayerNormNPUKernel : public framework::OpKernel<T> {
77104 } else {
78105 const_cast <Tensor*>(bias)->Resize (framework::make_ddim (axes));
79106 }
107+
108+ // cast scale from LayerNormParamType to T if needed
109+ Tensor cast_scale (x->type ());
110+ if (x->type () == framework::proto::VarType::FP16 &&
111+ scale->type () == framework::proto::VarType::FP32) {
112+ cast_scale.Resize (scale->dims ());
113+ cast_scale.mutable_data <T>(ctx.GetPlace ());
114+ auto dst_dtype = ConvertToNpuDtype (x->type ());
115+ auto runner_cast_scale =
116+ NpuOpRunner (" Cast" , {*scale}, {cast_scale},
117+ {{" dst_type" , static_cast <int >(dst_dtype)}});
118+ runner_cast_scale.Run (stream);
119+ } else {
120+ cast_scale.ShareDataWith (*scale);
121+ }
122+
123+ // cast bias from LayerNormParamType to T if needed
124+ Tensor cast_bias (x->type ());
125+ if (x->type () == framework::proto::VarType::FP16 &&
126+ bias->type () == framework::proto::VarType::FP32) {
127+ cast_bias.Resize (bias->dims ());
128+ cast_bias.mutable_data <T>(ctx.GetPlace ());
129+ auto dst_dtype = ConvertToNpuDtype (x->type ());
130+ auto runner_cast_bias =
131+ NpuOpRunner (" Cast" , {*bias}, {cast_bias},
132+ {{" dst_type" , static_cast <int >(dst_dtype)}});
133+ runner_cast_bias.Run (stream);
134+ } else {
135+ cast_bias.ShareDataWith (*bias);
136+ }
137+
80138 y->mutable_data <T>(ctx.GetPlace ());
81- mean->mutable_data <T>(ctx.GetPlace ());
82- variance->mutable_data <T>(ctx.GetPlace ());
83-
84- auto runner =
85- NpuOpRunner (" LayerNorm" , {*x, *scale, *bias}, {*y, *mean, *variance},
86- {{" begin_norm_axis" , begin_norm_axis},
87- {" begin_params_axis" , begin_norm_axis},
88- {" epsilon" , epsilon}});
139+
140+ // mean should be of U type
141+ Tensor* tmp_mean = mean;
142+ Tensor cast_mean (x->type ());
143+ if (x->type () == framework::proto::VarType::FP16 &&
144+ (scale->type () == framework::proto::VarType::FP32 ||
145+ bias->type () == framework::proto::VarType::FP32)) {
146+ cast_mean.Resize (mean->dims ());
147+ cast_mean.mutable_data <T>(ctx.GetPlace ());
148+ tmp_mean = &cast_mean;
149+ mean->mutable_data <U>(ctx.GetPlace ());
150+ } else {
151+ mean->mutable_data <T>(ctx.GetPlace ());
152+ }
153+
154+ // same for variance
155+ Tensor* tmp_variance = variance;
156+ Tensor cast_variance (x->type ());
157+ if (x->type () == framework::proto::VarType::FP16 &&
158+ (scale->type () == framework::proto::VarType::FP32 ||
159+ bias->type () == framework::proto::VarType::FP32)) {
160+ cast_variance.Resize (variance->dims ());
161+ cast_variance.mutable_data <T>(ctx.GetPlace ());
162+ tmp_variance = &cast_variance;
163+ variance->mutable_data <U>(ctx.GetPlace ());
164+ } else {
165+ variance->mutable_data <T>(ctx.GetPlace ());
166+ }
167+
168+ auto runner = NpuOpRunner (" LayerNorm" , {*x, cast_scale, cast_bias},
169+ {*y, *tmp_mean, *tmp_variance},
170+ {{" begin_norm_axis" , begin_norm_axis},
171+ {" begin_params_axis" , begin_norm_axis},
172+ {" epsilon" , epsilon}});
89173 runner.Run (stream);
174+
175+ // cast back from FP16 to FP32
176+ if (x->type () == framework::proto::VarType::FP16 &&
177+ mean->type () == framework::proto::VarType::FP32) {
178+ auto dst_dtype = ConvertToNpuDtype (mean->type ());
179+ auto runner_cast_mean =
180+ NpuOpRunner (" Cast" , {*tmp_mean}, {*mean},
181+ {{" dst_type" , static_cast <int >(dst_dtype)}});
182+ runner_cast_mean.Run (stream);
183+ }
184+ // same for variance
185+ if (x->type () == framework::proto::VarType::FP16 &&
186+ variance->type () == framework::proto::VarType::FP32) {
187+ auto dst_dtype = ConvertToNpuDtype (variance->type ());
188+ auto runner_cast_variance =
189+ NpuOpRunner (" Cast" , {*tmp_variance}, {*variance},
190+ {{" dst_type" , static_cast <int >(dst_dtype)}});
191+ runner_cast_variance.Run (stream);
192+ }
193+
90194 // revert shape of scale and bias
91195 // TODO(zhiqiu): better implementation, use tmp tensor to avoid write input
92196 // tensor.
@@ -99,6 +203,7 @@ template <typename T>
99203class LayerNormGradNPUKernel : public framework ::OpKernel<T> {
100204 public:
101205 void Compute (const framework::ExecutionContext& ctx) const override {
206+ using U = LayerNormParamType<T>;
102207 const auto begin_norm_axis = ctx.Attr <int >(" begin_norm_axis" );
103208 const auto * x = ctx.Input <Tensor>(" X" );
104209 const auto & x_dims = x->dims ();
@@ -156,25 +261,115 @@ class LayerNormGradNPUKernel : public framework::OpKernel<T> {
156261 const_cast <Tensor*>(scale)->Resize (framework::make_ddim (axes));
157262 }
158263
264+ // cast scale from LayerNormParamType to T if needed
265+ Tensor cast_scale (x->type ());
266+ if (x->type () == framework::proto::VarType::FP16 &&
267+ scale->type () == framework::proto::VarType::FP32) {
268+ cast_scale.Resize (scale->dims ());
269+ cast_scale.mutable_data <T>(ctx.GetPlace ());
270+ auto dst_dtype = ConvertToNpuDtype (x->type ());
271+ auto runner_cast_scale =
272+ NpuOpRunner (" Cast" , {*scale}, {cast_scale},
273+ {{" dst_type" , static_cast <int >(dst_dtype)}});
274+ runner_cast_scale.Run (stream);
275+ } else {
276+ cast_scale.ShareDataWith (*scale);
277+ }
278+
279+ // cast mean from LayerNormParamType to T if needed
280+ Tensor cast_mean (x->type ());
281+ if (x->type () == framework::proto::VarType::FP16 &&
282+ mean->type () == framework::proto::VarType::FP32) {
283+ cast_mean.Resize (mean->dims ());
284+ cast_mean.mutable_data <T>(ctx.GetPlace ());
285+ auto dst_dtype = ConvertToNpuDtype (x->type ());
286+ auto runner_cast_mean =
287+ NpuOpRunner (" Cast" , {*mean}, {cast_mean},
288+ {{" dst_type" , static_cast <int >(dst_dtype)}});
289+ runner_cast_mean.Run (stream);
290+ } else {
291+ cast_mean.ShareDataWith (*mean);
292+ }
293+
294+ // cast variance from LayerNormParamType to T if needed
295+ Tensor cast_variance (x->type ());
296+ if (x->type () == framework::proto::VarType::FP16 &&
297+ variance->type () == framework::proto::VarType::FP32) {
298+ cast_variance.Resize (variance->dims ());
299+ cast_variance.mutable_data <T>(ctx.GetPlace ());
300+ auto dst_dtype = ConvertToNpuDtype (x->type ());
301+ auto runner_cast_variance =
302+ NpuOpRunner (" Cast" , {*variance}, {cast_variance},
303+ {{" dst_type" , static_cast <int >(dst_dtype)}});
304+ runner_cast_variance.Run (stream);
305+ } else {
306+ cast_variance.ShareDataWith (*variance);
307+ }
308+
159309 Tensor dx_ (dy->type ()), dscale_ (dy->type ()), dbias_ (dy->type ());
160310 dx = (dx == nullptr ) ? &dx_ : dx;
161311 dscale = (dscale == nullptr ) ? &dscale_ : dscale;
162312 dbias = (dbias == nullptr ) ? &dbias_ : dbias;
163313
314+ dx->Resize (x->dims ());
315+ dx->mutable_data <T>(ctx.GetPlace ());
316+
164317 dscale->Resize (framework::make_ddim (axes));
165- dscale->mutable_data <T>(ctx.GetPlace ());
166318
167319 dbias->Resize (framework::make_ddim (axes));
168- dbias->mutable_data <T>(ctx.GetPlace ());
169320
170- dx->Resize (x->dims ());
171- dx->mutable_data <T>(ctx.GetPlace ());
321+ // dscale should be of U type
322+ Tensor* tmp_dscale = dscale;
323+ Tensor cast_dscale (x->type ());
324+ if (x->type () == framework::proto::VarType::FP16 &&
325+ (mean->type () == framework::proto::VarType::FP32 ||
326+ variance->type () == framework::proto::VarType::FP32)) {
327+ cast_dscale.Resize (dscale->dims ());
328+ cast_dscale.mutable_data <T>(ctx.GetPlace ());
329+ tmp_dscale = &cast_dscale;
330+ dscale->mutable_data <U>(ctx.GetPlace ());
331+ } else {
332+ dscale->mutable_data <T>(ctx.GetPlace ());
333+ }
172334
173- auto runner =
174- NpuOpRunner (" LayerNormGrad" , {*dy, *x, *variance, *mean, *scale},
175- {*dx, *dscale, *dbias}, {});
335+ // same for dbias
336+ Tensor* tmp_dbias = dbias;
337+ Tensor cast_dbias (x->type ());
338+ if (x->type () == framework::proto::VarType::FP16 &&
339+ (mean->type () == framework::proto::VarType::FP32 ||
340+ variance->type () == framework::proto::VarType::FP32)) {
341+ cast_dbias.Resize (dbias->dims ());
342+ cast_dbias.mutable_data <T>(ctx.GetPlace ());
343+ tmp_dbias = &cast_dbias;
344+ dbias->mutable_data <U>(ctx.GetPlace ());
345+ } else {
346+ dbias->mutable_data <T>(ctx.GetPlace ());
347+ }
348+
349+ auto runner = NpuOpRunner (" LayerNormGrad" ,
350+ {*dy, *x, cast_variance, cast_mean, cast_scale},
351+ {*dx, *tmp_dscale, *tmp_dbias}, {});
176352 runner.Run (stream);
177353
354+ // cast back from FP16 to FP32
355+ if (x->type () == framework::proto::VarType::FP16 &&
356+ dscale->type () == framework::proto::VarType::FP32) {
357+ auto dst_dtype = ConvertToNpuDtype (dscale->type ());
358+ auto runner_cast_dscale =
359+ NpuOpRunner (" Cast" , {*tmp_dscale}, {*dscale},
360+ {{" dst_type" , static_cast <int >(dst_dtype)}});
361+ runner_cast_dscale.Run (stream);
362+ }
363+ // same for dbias
364+ if (x->type () == framework::proto::VarType::FP16 &&
365+ dbias->type () == framework::proto::VarType::FP32) {
366+ auto dst_dtype = ConvertToNpuDtype (dbias->type ());
367+ auto runner_cast_dbias =
368+ NpuOpRunner (" Cast" , {*tmp_dbias}, {*dbias},
369+ {{" dst_type" , static_cast <int >(dst_dtype)}});
370+ runner_cast_dbias.Run (stream);
371+ }
372+
178373 const_cast <Tensor*>(mean)->Resize (mean_dims);
179374 const_cast <Tensor*>(variance)->Resize (mean_dims);
180375 const_cast <Tensor*>(scale)->Resize (framework::make_ddim ({right}));
0 commit comments