@@ -21,7 +21,8 @@ namespace lite {
2121namespace kernels {
2222namespace xpu {
2323
// NOTE(review): this span is a rendered diff hunk ("-" = removed line,
// "+" = added line), not directly compilable source; all code lines are kept
// byte-identical below, with review comments interleaved.
// Purpose of the change: Run() becomes a template on the input element type
// (InType) and kernel precision (PType) so one implementation serves both the
// FP32 and FP16 XPU layer_norm kernels.
24- void LayerNormCompute::Run () {
24+ template <typename InType, PrecisionType PType>
25+ void LayerNormCompute<InType, PType>::Run() {
// Fetch this op's parameters and the XPU device context; `template` keywords
// are required because the enclosing class is now a template.
2526 auto & param = this ->template Param <param_t >();
2627 auto & ctx = this ->ctx_ ->template As <XPUContext>();
2728
@@ -30,16 +31,17 @@ void LayerNormCompute::Run() {
// NOTE(review): the declarations of `x_dims` and `axis` fall in the file
// lines elided between the two hunks above — confirm against the full file.
// Flatten2D collapses the input shape into a 2-D (m, n) view at the
// normalization axis: m = rows normalized independently, n = row width.
3031 auto matrix_dim = x_dims.Flatten2D (axis);
3132 float epsilon = param.epsilon ;
3233
// Removed call: input/output pointers hard-coded to float.
33- int r = xdnn::layer_norm (ctx.GetRawContext (), /* context */
34- param.X ->data <float >(), /* in */
35- param.Y ->mutable_data <float >(TARGET (kXPU )), /* out */
36- matrix_dim[0 ], /* m */
37- matrix_dim[1 ], /* n */
38- epsilon, /* epsilon */
39- param.Scale ->data <float >(), /* scale */
40- param.Bias ->data <float >(), /* bias */
41- nullptr ,
42- nullptr );
// Replacement call: in/out buffers typed on InType. Scale and Bias stay
// float even for the FP16 instantiation, and the two trailing nullptr
// arguments mean the mean/variance outputs are not produced by xdnn here.
34+ int r = xdnn::layer_norm<InType>(
35+ ctx.GetRawContext (), /* context */
36+ param.X ->template data <InType>(), /* in */
37+ param.Y ->template mutable_data <InType>(TARGET (kXPU )), /* out */
38+ matrix_dim[0 ], /* m */
39+ matrix_dim[1 ], /* n */
40+ epsilon, /* epsilon */
41+ param.Scale ->template data <float >(), /* scale */
42+ param.Bias ->template data <float >(), /* bias */
43+ nullptr ,
44+ nullptr );
4345
// xdnn APIs return 0 on success; anything else aborts via the CHECK macro.
4446 CHECK_EQ (r, 0 );
4547}
@@ -49,16 +51,25 @@ void LayerNormCompute::Run() {
4951} // namespace lite
5052} // namespace paddle
5153
52- REGISTER_LITE_KERNEL (layer_norm,
53- kXPU ,
54- kFloat ,
55- kNCHW ,
56- paddle::lite::kernels::xpu::LayerNormCompute,
57- def)
54+ namespace xpu = paddle::lite::kernels::xpu;
55+
56+ using LayerNorm_FP32 = xpu::LayerNormCompute<float , PRECISION(kFloat )>;
57+ using LayerNorm_FP16 = xpu::LayerNormCompute<float16, PRECISION(kFP16 )>;
58+ REGISTER_LITE_KERNEL (layer_norm, kXPU , kFloat , kNCHW , LayerNorm_FP32, def)
5859 .BindInput(" X" , {LiteType::GetTensorTy (TARGET (kXPU ))})
5960 .BindInput(" Scale" , {LiteType::GetTensorTy (TARGET (kXPU ))})
6061 .BindInput(" Bias" , {LiteType::GetTensorTy (TARGET (kXPU ))})
6162 .BindOutput(" Y" , {LiteType::GetTensorTy (TARGET (kXPU ))})
6263 .BindOutput(" Mean" , {LiteType::GetTensorTy (TARGET (kXPU ))})
6364 .BindOutput(" Variance" , {LiteType::GetTensorTy (TARGET (kXPU ))})
6465 .Finalize();
66+
67+ REGISTER_LITE_KERNEL (layer_norm, kXPU , kFP16 , kNCHW , LayerNorm_FP16, fp16)
68+ .BindInput(" X" , {LiteType::GetTensorTy (TARGET (kXPU ), PRECISION (kFP16 ))})
69+ .BindInput(" Scale" , {LiteType::GetTensorTy (TARGET (kXPU ))})
70+ .BindInput(" Bias" , {LiteType::GetTensorTy (TARGET (kXPU ))})
71+ .BindOutput(" Y" , {LiteType::GetTensorTy (TARGET (kXPU ), PRECISION (kFP16 ))})
72+ .BindOutput(" Mean" , {LiteType::GetTensorTy (TARGET (kXPU ), PRECISION (kFP16 ))})
73+ .BindOutput(" Variance" ,
74+ {LiteType::GetTensorTy (TARGET (kXPU ), PRECISION (kFP16 ))})
75+ .Finalize();
0 commit comments