
Commit 25aa56b

pangyokizhiqiu authored and committed
[NPU] support mixed precision input for npu layer norm (PaddlePaddle#31847)
* support mixed precision input for npu layer norm
* fix layer_norm npu kernel

Co-authored-by: zhiqiu <[email protected]>
1 parent 5bcfa8a commit 25aa56b
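The practical effect: an FP16 LayerNorm on NPU can now take FP32 Scale and Bias, the usual mix under AMP where master parameters stay in FP32. A minimal usage sketch under that assumption (device string, shapes, and values are illustrative; requires a Paddle build with NPU support):

# Hypothetical usage sketch of the dtype mix this commit enables on NPU.
import numpy as np
import paddle
import paddle.nn.functional as F

paddle.set_device('npu')  # assumption: an NPU-enabled Paddle build

x = paddle.to_tensor(np.random.rand(2, 3, 4).astype('float16'))
# AMP typically keeps master parameters in FP32; the kernel accepts this mix.
weight = paddle.to_tensor(np.ones(4, dtype='float32'))
bias = paddle.to_tensor(np.zeros(4, dtype='float32'))

# Normalize over the last axis. Inside the kernel, scale/bias are cast down
# to FP16 before the device op runs, and mean/variance are saved as FP32.
y = F.layer_norm(x, x.shape[-1:], weight=weight, bias=bias, epsilon=1e-5)
print(y.dtype)  # paddle.float16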

File tree

2 files changed: +229 additions, -22 deletions

paddle/fluid/operators/layer_norm_op_npu.cc

Lines changed: 210 additions & 15 deletions

@@ -21,10 +21,36 @@ namespace operators {
 using Tensor = framework::Tensor;
 using DDim = framework::DDim;
 
+using DataLayout = framework::DataLayout;
+
+template <typename T>
+class NormDataType;
+
+template <>
+class NormDataType<platform::float16> {
+ public:
+  // The scaling param type is float for HALF and FLOAT tensors
+  using ScalingParamType = const float;
+  using BatchNormParamType = float;
+};
+
+template <>
+class NormDataType<float> {
+ public:
+  using ScalingParamType = const float;
+  using BatchNormParamType = float;
+};
+
+template <typename T>
+using NormDataType = NormDataType<T>;
+template <typename T>
+using LayerNormParamType = typename NormDataType<T>::BatchNormParamType;
+
 template <typename T>
 class LayerNormNPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    using U = LayerNormParamType<T>;
     const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
     const auto epsilon = ctx.Attr<float>("epsilon");
     const auto* x = ctx.Input<Tensor>("X");
@@ -43,6 +69,7 @@ class LayerNormNPUKernel : public framework::OpKernel<T> {
     for (auto i = begin_norm_axis; i < x_dims.size(); ++i) {
       axes.push_back(x_dims[i]);
     }
+
     auto place = ctx.GetPlace();
     auto stream =
         ctx.template device_context<paddle::platform::NPUDeviceContext>()
@@ -77,16 +104,93 @@ class LayerNormNPUKernel : public framework::OpKernel<T> {
     } else {
       const_cast<Tensor*>(bias)->Resize(framework::make_ddim(axes));
     }
+
+    // cast scale from LayerNormParamType to T if needed
+    Tensor cast_scale(x->type());
+    if (x->type() == framework::proto::VarType::FP16 &&
+        scale->type() == framework::proto::VarType::FP32) {
+      cast_scale.Resize(scale->dims());
+      cast_scale.mutable_data<T>(ctx.GetPlace());
+      auto dst_dtype = ConvertToNpuDtype(x->type());
+      auto runner_cast_scale =
+          NpuOpRunner("Cast", {*scale}, {cast_scale},
+                      {{"dst_type", static_cast<int>(dst_dtype)}});
+      runner_cast_scale.Run(stream);
+    } else {
+      cast_scale.ShareDataWith(*scale);
+    }
+
+    // cast bias from LayerNormParamType to T if needed
+    Tensor cast_bias(x->type());
+    if (x->type() == framework::proto::VarType::FP16 &&
+        bias->type() == framework::proto::VarType::FP32) {
+      cast_bias.Resize(bias->dims());
+      cast_bias.mutable_data<T>(ctx.GetPlace());
+      auto dst_dtype = ConvertToNpuDtype(x->type());
+      auto runner_cast_bias =
+          NpuOpRunner("Cast", {*bias}, {cast_bias},
+                      {{"dst_type", static_cast<int>(dst_dtype)}});
+      runner_cast_bias.Run(stream);
+    } else {
+      cast_bias.ShareDataWith(*bias);
+    }
+
     y->mutable_data<T>(ctx.GetPlace());
-    mean->mutable_data<T>(ctx.GetPlace());
-    variance->mutable_data<T>(ctx.GetPlace());
-
-    auto runner =
-        NpuOpRunner("LayerNorm", {*x, *scale, *bias}, {*y, *mean, *variance},
-                    {{"begin_norm_axis", begin_norm_axis},
-                     {"begin_params_axis", begin_norm_axis},
-                     {"epsilon", epsilon}});
+
+    // mean should be of U type
+    Tensor* tmp_mean = mean;
+    Tensor cast_mean(x->type());
+    if (x->type() == framework::proto::VarType::FP16 &&
+        (scale->type() == framework::proto::VarType::FP32 ||
+         bias->type() == framework::proto::VarType::FP32)) {
+      cast_mean.Resize(mean->dims());
+      cast_mean.mutable_data<T>(ctx.GetPlace());
+      tmp_mean = &cast_mean;
+      mean->mutable_data<U>(ctx.GetPlace());
+    } else {
+      mean->mutable_data<T>(ctx.GetPlace());
+    }
+
+    // same for variance
+    Tensor* tmp_variance = variance;
+    Tensor cast_variance(x->type());
+    if (x->type() == framework::proto::VarType::FP16 &&
+        (scale->type() == framework::proto::VarType::FP32 ||
+         bias->type() == framework::proto::VarType::FP32)) {
+      cast_variance.Resize(variance->dims());
+      cast_variance.mutable_data<T>(ctx.GetPlace());
+      tmp_variance = &cast_variance;
+      variance->mutable_data<U>(ctx.GetPlace());
+    } else {
+      variance->mutable_data<T>(ctx.GetPlace());
+    }
+
+    auto runner = NpuOpRunner("LayerNorm", {*x, cast_scale, cast_bias},
+                              {*y, *tmp_mean, *tmp_variance},
+                              {{"begin_norm_axis", begin_norm_axis},
+                               {"begin_params_axis", begin_norm_axis},
+                               {"epsilon", epsilon}});
     runner.Run(stream);
+
+    // cast back from FP16 to FP32
+    if (x->type() == framework::proto::VarType::FP16 &&
+        mean->type() == framework::proto::VarType::FP32) {
+      auto dst_dtype = ConvertToNpuDtype(mean->type());
+      auto runner_cast_mean =
+          NpuOpRunner("Cast", {*tmp_mean}, {*mean},
+                      {{"dst_type", static_cast<int>(dst_dtype)}});
+      runner_cast_mean.Run(stream);
+    }
+    // same for variance
+    if (x->type() == framework::proto::VarType::FP16 &&
+        variance->type() == framework::proto::VarType::FP32) {
+      auto dst_dtype = ConvertToNpuDtype(variance->type());
+      auto runner_cast_variance =
+          NpuOpRunner("Cast", {*tmp_variance}, {*variance},
+                      {{"dst_type", static_cast<int>(dst_dtype)}});
+      runner_cast_variance.Run(stream);
+    }
+
     // revert shape of scale and bias
     // TODO(zhiqiu): better implementation, use tmp tensor to avoid write input
     // tensor.
@@ -99,6 +203,7 @@ template <typename T>
 class LayerNormGradNPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    using U = LayerNormParamType<T>;
     const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
     const auto* x = ctx.Input<Tensor>("X");
     const auto& x_dims = x->dims();
@@ -156,25 +261,115 @@ class LayerNormGradNPUKernel : public framework::OpKernel<T> {
       const_cast<Tensor*>(scale)->Resize(framework::make_ddim(axes));
     }
 
+    // cast scale from LayerNormParamType to T if needed
+    Tensor cast_scale(x->type());
+    if (x->type() == framework::proto::VarType::FP16 &&
+        scale->type() == framework::proto::VarType::FP32) {
+      cast_scale.Resize(scale->dims());
+      cast_scale.mutable_data<T>(ctx.GetPlace());
+      auto dst_dtype = ConvertToNpuDtype(x->type());
+      auto runner_cast_scale =
+          NpuOpRunner("Cast", {*scale}, {cast_scale},
+                      {{"dst_type", static_cast<int>(dst_dtype)}});
+      runner_cast_scale.Run(stream);
+    } else {
+      cast_scale.ShareDataWith(*scale);
+    }
+
+    // cast mean from LayerNormParamType to T if needed
+    Tensor cast_mean(x->type());
+    if (x->type() == framework::proto::VarType::FP16 &&
+        mean->type() == framework::proto::VarType::FP32) {
+      cast_mean.Resize(mean->dims());
+      cast_mean.mutable_data<T>(ctx.GetPlace());
+      auto dst_dtype = ConvertToNpuDtype(x->type());
+      auto runner_cast_mean =
+          NpuOpRunner("Cast", {*mean}, {cast_mean},
+                      {{"dst_type", static_cast<int>(dst_dtype)}});
+      runner_cast_mean.Run(stream);
+    } else {
+      cast_mean.ShareDataWith(*mean);
+    }
+
+    // cast variance from LayerNormParamType to T if needed
+    Tensor cast_variance(x->type());
+    if (x->type() == framework::proto::VarType::FP16 &&
+        variance->type() == framework::proto::VarType::FP32) {
+      cast_variance.Resize(variance->dims());
+      cast_variance.mutable_data<T>(ctx.GetPlace());
+      auto dst_dtype = ConvertToNpuDtype(x->type());
+      auto runner_cast_variance =
+          NpuOpRunner("Cast", {*variance}, {cast_variance},
+                      {{"dst_type", static_cast<int>(dst_dtype)}});
+      runner_cast_variance.Run(stream);
+    } else {
+      cast_variance.ShareDataWith(*variance);
+    }
+
     Tensor dx_(dy->type()), dscale_(dy->type()), dbias_(dy->type());
     dx = (dx == nullptr) ? &dx_ : dx;
     dscale = (dscale == nullptr) ? &dscale_ : dscale;
     dbias = (dbias == nullptr) ? &dbias_ : dbias;
 
+    dx->Resize(x->dims());
+    dx->mutable_data<T>(ctx.GetPlace());
+
     dscale->Resize(framework::make_ddim(axes));
-    dscale->mutable_data<T>(ctx.GetPlace());
 
     dbias->Resize(framework::make_ddim(axes));
-    dbias->mutable_data<T>(ctx.GetPlace());
 
-    dx->Resize(x->dims());
-    dx->mutable_data<T>(ctx.GetPlace());
+    // dscale should be of U type
+    Tensor* tmp_dscale = dscale;
+    Tensor cast_dscale(x->type());
+    if (x->type() == framework::proto::VarType::FP16 &&
+        (mean->type() == framework::proto::VarType::FP32 ||
+         variance->type() == framework::proto::VarType::FP32)) {
+      cast_dscale.Resize(dscale->dims());
+      cast_dscale.mutable_data<T>(ctx.GetPlace());
+      tmp_dscale = &cast_dscale;
+      dscale->mutable_data<U>(ctx.GetPlace());
+    } else {
+      dscale->mutable_data<T>(ctx.GetPlace());
+    }
 
-    auto runner =
-        NpuOpRunner("LayerNormGrad", {*dy, *x, *variance, *mean, *scale},
-                    {*dx, *dscale, *dbias}, {});
+    // same for dbias
+    Tensor* tmp_dbias = dbias;
+    Tensor cast_dbias(x->type());
+    if (x->type() == framework::proto::VarType::FP16 &&
+        (mean->type() == framework::proto::VarType::FP32 ||
+         variance->type() == framework::proto::VarType::FP32)) {
+      cast_dbias.Resize(dbias->dims());
+      cast_dbias.mutable_data<T>(ctx.GetPlace());
+      tmp_dbias = &cast_dbias;
+      dbias->mutable_data<U>(ctx.GetPlace());
+    } else {
+      dbias->mutable_data<T>(ctx.GetPlace());
+    }
+
+    auto runner = NpuOpRunner("LayerNormGrad",
+                              {*dy, *x, cast_variance, cast_mean, cast_scale},
+                              {*dx, *tmp_dscale, *tmp_dbias}, {});
     runner.Run(stream);
 
+    // cast back from FP16 to FP32
+    if (x->type() == framework::proto::VarType::FP16 &&
+        dscale->type() == framework::proto::VarType::FP32) {
+      auto dst_dtype = ConvertToNpuDtype(dscale->type());
+      auto runner_cast_dscale =
+          NpuOpRunner("Cast", {*tmp_dscale}, {*dscale},
+                      {{"dst_type", static_cast<int>(dst_dtype)}});
+      runner_cast_dscale.Run(stream);
+    }
+    // same for dbias
+    if (x->type() == framework::proto::VarType::FP16 &&
+        dbias->type() == framework::proto::VarType::FP32) {
+      auto dst_dtype = ConvertToNpuDtype(dbias->type());
+      auto runner_cast_dbias =
+          NpuOpRunner("Cast", {*tmp_dbias}, {*dbias},
+                      {{"dst_type", static_cast<int>(dst_dtype)}});
+      runner_cast_dbias.Run(stream);
    }
 
     const_cast<Tensor*>(mean)->Resize(mean_dims);
     const_cast<Tensor*>(variance)->Resize(mean_dims);
     const_cast<Tensor*>(scale)->Resize(framework::make_ddim({right}));
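Both kernels follow one pattern: the device LayerNorm op evidently needs a single dtype across x, scale, bias, and its outputs, so FP32 parameters are cast down to T (FP16) before the op runs, and the FP16 statistics it produces are cast back up to FP32 for the caller. A NumPy sketch of that forward dataflow (illustrative only; the function name and the single-dtype assumption about the device op are mine, not this file's code):

# NumPy mirror of the forward kernel's cast, run, cast-back flow.
import numpy as np

def npu_like_layer_norm_fp16(x16, scale32, bias32, eps=1e-5):
    # "cast scale/bias from LayerNormParamType to T": FP32 -> FP16
    scale16 = scale32.astype(np.float16)
    bias16 = bias32.astype(np.float16)

    # the device op runs uniformly in FP16 (here: normalize the last axis)
    mean16 = x16.mean(axis=-1, keepdims=True)
    var16 = x16.var(axis=-1, keepdims=True)
    y16 = scale16 * (x16 - mean16) / np.sqrt(var16 + np.float16(eps)) + bias16

    # "cast back from FP16 to FP32": saved statistics return as FP32
    return y16, mean16.astype(np.float32), var16.astype(np.float32)

x = np.random.rand(2, 8).astype(np.float16)
y, mean, var = npu_like_layer_norm_fp16(x, np.ones(8, np.float32),
                                        np.zeros(8, np.float32))
print(y.dtype, mean.dtype, var.dtype)  # float16 float32 float32

The gradient kernel mirrors this: FP32 variance, mean, and scale are cast down before LayerNormGrad runs, and the FP16 dscale/dbias it emits are cast back up to FP32.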

python/paddle/fluid/tests/unittests/npu/test_layer_norm_op_npu.py

Lines changed: 19 additions & 7 deletions

@@ -50,9 +50,13 @@ def set_npu(self):
 
     def init_dtype(self):
         self.dtype = np.float32
+        self.atol = 1e-4
 
     def __assert_close(self, tensor, np_array, msg, atol=1e-4):
-        self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
+        self.assertTrue(
+            np.allclose(
+                np.array(tensor).astype(np_array.dtype), np_array, atol=atol),
+            msg)
 
     def check_forward_backward(self,
                                shape,
@@ -72,13 +76,13 @@ def test_with_place(place,
             scale_shape = [D]
 
             np.random.seed(123)
-            x = np.random.random_sample(x_shape).astype(np.float32)
+            x = np.random.random_sample(x_shape).astype(self.dtype)
             scale = np.random.random_sample(scale_shape).astype(
                 np.float32) if has_scale else None
             bias = np.random.random_sample(scale_shape).astype(
                 np.float32) if has_bias else None
             y_grad = (np.random.random_sample(x_shape) *
-                      y_grad_scale).astype(np.float32)
+                      y_grad_scale).astype(self.dtype)
 
             # reference forward & backward
             y, mean, variance = _reference_layer_norm_naive(
@@ -101,7 +105,7 @@ def test_with_place(place,
                 for name in ground_truth:
                     block.create_var(
                         name=name,
-                        dtype='float32',
+                        dtype=self.dtype,
                        shape=ground_truth[name].shape)
                 inputs = {"X": block.var('x')}
                 fetch_list = [
@@ -152,18 +156,18 @@ def test_with_place(place,
                         for name in ['x', 'scale', 'bias', 'y@GRAD']
                     },
                     fetch_list=fetch_list)
-                self.__assert_close(y, out[0], "y")
+                self.__assert_close(y, out[0], "y", self.atol)
                 self.__assert_close(mean, out[1], "mean")
                 self.__assert_close(variance, out[2], "variance", 1e-3)
                 self.__assert_close(x_grad, out[3], "x_grad", 1e-2)
                 if has_scale:
                     self.__assert_close(scale_grad,
                                         out[fetch_list.index('scale@GRAD')],
-                                        "scale_grad", 1e-3)
+                                        "scale_grad", 1e-2)
                 if has_bias:
                     self.__assert_close(bias_grad,
                                         out[fetch_list.index('bias@GRAD')],
-                                        "bias_grad")
+                                        "bias_grad", self.atol)
 
         test_with_place(self.place, shape, begin_norm_axis)
 
@@ -187,5 +191,13 @@ def test_check_forward_backward_with_scale_and_bias(self):
         self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=3)
 
 
+@unittest.skipIf(not paddle.is_compiled_with_npu(),
+                 "core is not compiled with NPU")
+class TestLayerNormOpFP16(TestLayerNormOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+        self.atol = 1e-2
+
+
 if __name__ == '__main__':
     unittest.main()
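The test compares NPU results against _reference_layer_norm_naive, a NumPy reference defined earlier in the test file (outside this diff), and TestLayerNormOpFP16 reruns the same graph with np.float16 inputs under the looser 1e-2 tolerance half precision needs. A hedged sketch of what such a naive reference typically computes (an illustrative reimplementation, not the suite's actual helper):

# Illustrative naive layer-norm forward in the spirit of
# _reference_layer_norm_naive; not the test suite's actual code.
import numpy as np

def reference_layer_norm_naive(x, scale, bias, epsilon=1e-5, begin_norm_axis=1):
    shape = x.shape
    N = int(np.prod(shape[:begin_norm_axis]))   # leading dims become rows
    D = int(np.prod(shape[begin_norm_axis:]))   # trailing dims are normalized
    x2d = x.reshape([N, D]).astype(np.float32)  # reference math in FP32

    mean = np.mean(x2d, axis=1)
    var = np.var(x2d, axis=1) + epsilon
    y = (x2d - mean.reshape([N, 1])) / np.sqrt(var.reshape([N, 1]))
    if scale is not None:
        y = y * scale.reshape([1, D])
    if bias is not None:
        y = y + bias.reshape([1, D])
    return y.reshape(shape), mean, var

Comparing an FP16 run against an FP32 reference like this is also why __assert_close now casts the fetched tensor to the reference dtype before calling np.allclose.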
