1- // Copyright (c) 2024 PaddlePaddle Authors And Intel Corporation.
2- // All Rights Reserved.
1+ // Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
32//
43// Licensed under the Apache License, Version 2.0 (the "License");
54// you may not use this file except in compliance with the License.
@@ -29,18 +28,26 @@ namespace phi {
2928namespace fusion {
3029
// Diff hunk: RmsNormKernel is renamed to RmsNormAvxKernel and its parameter
// list gains bias, norm_bias, quant_scale, quant_round_type, quant_max_bound,
// quant_min_bound and inv_var. Only part of the body is visible in this view
// (later hunks elide lines), so the notes here cover visible statements only.
//
// NOTE(review): quant_round_type, quant_max_bound, quant_min_bound and inv_var
// are not referenced in any visible hunk - confirm they are intentionally
// unused and that inv_var does not need to be allocated or written.
3130template <typename T, typename Context>
32- void RmsNormKernel (const Context& dev_ctx,
33- const DenseTensor& x,
34- const paddle::optional<DenseTensor>& residual,
35- const DenseTensor& norm_weight,
36- const float epsilon,
37- const int begin_norm_axis,
38- DenseTensor* out,
39- DenseTensor* residual_out) {
31+ void RmsNormAvxKernel (const Context& dev_ctx,
32+ const DenseTensor& x,
33+ const paddle::optional<DenseTensor>& bias,
34+ const paddle::optional<DenseTensor>& residual,
35+ const DenseTensor& norm_weight,
36+ const paddle::optional<DenseTensor>& norm_bias,
37+ const float epsilon,
38+ const int begin_norm_axis,
39+ const float quant_scale,
40+ const int quant_round_type,
41+ const float quant_max_bound,
42+ const float quant_min_bound,
43+ DenseTensor* out,
44+ DenseTensor* residual_out,
45+ DenseTensor* inv_var) {
// Quantized output is not implemented here: any positive quant_scale throws.
46+ if (quant_scale > 0 .0f ) {
47+ PD_THROW (" NOT supported quant int8. " );
48+ }
49+
4050 const T* x_data = x.data <T>();
// The out_data / norm_weight_data setup below is removed by this diff; the
// added lines re-introduce both in the next hunk.
41- T* out_data = dev_ctx.template Alloc <T>(out);
42- const T* norm_weight_data = norm_weight.data <T>();
43- // x(batch_size,seq_len,hidden_size)
// rows and cols both start at 1; the loop below walks the axes before
// begin_norm_axis. Its body is elided from this view (presumably it folds
// those leading dims into rows - TODO confirm).
4451 int32_t rows = 1 ;
4552 int32_t cols = 1 ;
4653 for (int i = 0 ; i < begin_norm_axis; i++) {
@@ -53,10 +60,16 @@ void RmsNormKernel(const Context& dev_ctx,
// Pointer setup for the row loop. The element count per row and both the
// input and output strides are set to cols.
5360 int size = cols;
5461 auto istride = cols;
5562 auto ostride = cols;
// Optional tensors (norm_bias, residual, bias) become nullptr when absent.
63+ const T* norm_weight_data = norm_weight.data <T>();
64+ const T* norm_bias_data = norm_bias ? norm_bias.get ().data <T>() : nullptr ;
5665 const T* residual_data = residual ? residual.get ().data <T>() : nullptr ;
66+ const T* bias_data = bias ? bias.get ().data <T>() : nullptr ;
// out is always allocated; residual_out is allocated only when a residual
// input is provided.
67+ T* out_data = dev_ctx.template Alloc <T>(out);
5768 T* residual_out_data =
5869 residual ? dev_ctx.template Alloc <T>(residual_out) : nullptr ;
5970
// vb starts as all zeros so that the `+ vb` term in the output expression of
// the later hunks adds nothing when norm_bias is absent.
// NOTE(review): vb is declared before the omp parallel-for pragma below but is
// assigned inside the row loop in later hunks (vb = load(norm_bias_data + col)).
// OpenMP shares variables declared outside the parallel region by default, so
// this looks like a data race between threads - consider declaring the
// norm-bias vector inside the loop body. The loop header is elided; confirm.
// Also note the bias branches declare their own local `__m512 vb`, which
// shadows this one.
71+ __m512 vb = _mm512_setzero_ps ();
72+ const T* pb = bias_data;
// The row loop is parallelised with OpenMP only when PADDLE_WITH_MKLML is
// defined.
6073#ifdef PADDLE_WITH_MKLML
6174#pragma omp parallel for
6275#endif
@@ -77,6 +90,10 @@ void RmsNormKernel(const Context& dev_ctx,
// Full-width (unmasked 512-bit) step; the enclosing loop header and the load
// of vx are elided from this view. When a residual is given, vx becomes
// x + residual, plus bias if present, and the sum is stored through pr_out
// (presumably the current row of residual_out - TODO confirm).
// NOTE(review): bias is only applied inside the residual branch, so a bias
// supplied without a residual is ignored - confirm this is intended.
7790 if (residual) {
7891 __m512 residual_vx = _mm512_loadu_ps (pr + col);
7992 vx = _mm512_add_ps (vx, residual_vx);
93+ if (bias) {
// Local vb shadows the function-level vb; pb indexes bias by column only.
94+ __m512 vb = _mm512_loadu_ps (pb + col);
95+ vx = _mm512_add_ps (vx, vb);
96+ }
8097 _mm512_storeu_ps (pr_out + col, vx);
8198 }
// Element-wise square of the (possibly updated) input; how tmp is accumulated
// is outside this hunk.
8299 __m512 tmp = _mm512_mul_ps (vx, vx);
@@ -88,6 +105,10 @@ void RmsNormKernel(const Context& dev_ctx,
// Masked variant of the residual/bias add for the tail of a row (the mask
// definition is elided from this view). maskz loads zero the lanes that are
// not selected, mask_add keeps the old vx in unselected lanes, and the masked
// store writes only the selected lanes through pr_out.
88105 if (residual) {
89106 __m512 residual_vx = _mm512_maskz_loadu_ps (mask, pr + col);
90107 vx = _mm512_mask_add_ps (vx, mask, vx, residual_vx);
108+ if (bias) {
// Local vb shadows the function-level vb.
109+ __m512 vb = _mm512_maskz_loadu_ps (mask, pb + col);
110+ vx = _mm512_mask_add_ps (vx, mask, vx, vb);
111+ }
91112 _mm512_mask_storeu_ps (pr_out + col, mask, vx);
92113 }
// Element-wise square; accumulation happens outside this hunk.
93114 __m512 tmp = _mm512_mul_ps (vx, vx);
@@ -105,9 +126,16 @@ void RmsNormKernel(const Context& dev_ctx,
// Full-width step of the output pass (loop header, vx load and vvar are
// elided from this view). x + residual (+ bias) is recomputed here from the
// inputs rather than read back from the stored residual output.
105126 if (residual) {
106127 __m512 residual_vx = _mm512_loadu_ps (pr + col);
107128 vx = _mm512_add_ps (vx, residual_vx);
129+ if (bias) {
// Local vb shadows the function-level vb used for norm_bias below.
130+ __m512 vb = _mm512_loadu_ps (pb + col);
131+ vx = _mm512_add_ps (vx, vb);
132+ }
108133 }
109134 __m512 vw = _mm512_loadu_ps (norm_weight_data + col);
110- __m512 vy = vx * vvar * vw;
// This assigns the function-level vb, which stays all-zero when norm_bias is
// absent.
// NOTE(review): that vb is declared outside the parallel row loop, so this
// write appears to be shared across OpenMP threads - confirm and consider a
// loop-local variable instead.
135+ if (norm_bias_data) {
136+ vb = _mm512_loadu_ps (norm_bias_data + col);
137+ }
// y = x * vvar * weight + norm_bias, using arithmetic operators directly on
// __m512 values (vvar is defined in an elided region).
138+ __m512 vy = vx * vvar * vw + vb;
111139 _mm512_storeu_ps (py + col, vy);
112140 }
// Start of the tail handling for columns left over after the full-width loop.
113141 if (col < size) {
@@ -116,9 +144,16 @@ void RmsNormKernel(const Context& dev_ctx,
// Masked tail of the output pass (the mask definition and the vx load are
// elided from this view). Mirrors the full-width step with masked
// loads, adds and stores.
116144 if (residual) {
117145 __m512 residual_vx = _mm512_maskz_loadu_ps (mask, pr + col);
118146 vx = _mm512_mask_add_ps (vx, mask, vx, residual_vx);
147+ if (bias) {
// Local vb shadows the function-level vb used for norm_bias below.
148+ __m512 vb = _mm512_maskz_loadu_ps (mask, pb + col);
149+ vx = _mm512_mask_add_ps (vx, mask, vx, vb);
150+ }
119151 }
120152 __m512 vw = _mm512_maskz_loadu_ps (mask, norm_weight_data + col);
121- __m512 vy = vx * vvar * vw;
// Writes the function-level vb; unselected lanes are zeroed by the maskz load.
// NOTE(review): same shared-variable concern as in the full-width step -
// vb is declared outside the parallel row loop; confirm.
153+ if (norm_bias_data) {
154+ vb = _mm512_maskz_loadu_ps (mask, norm_bias_data + col);
155+ }
156+ __m512 vy = vx * vvar * vw + vb;
// Only the lanes selected by mask are written to the output row.
122157 _mm512_mask_storeu_ps (py + col, mask, vy);
123158 }
124159 } // end for rows
@@ -127,4 +162,4 @@ void RmsNormKernel(const Context& dev_ctx,
127162} // namespace phi
128163
// Kernel registration: the diff changes the registered name from rms_norm_avx
// to rms_norm and points it at RmsNormAvxKernel, for CPU, ALL_LAYOUT, with
// T = float and T = double.
// NOTE(review): the visible body passes T* pointers to single-precision _ps
// intrinsics; with T = double that would read the buffers as floats - confirm
// whether the double registration is intended.
129164PD_REGISTER_KERNEL (
130- rms_norm_avx , CPU, ALL_LAYOUT, phi::fusion::RmsNormKernel , float , double ) {}
165+ rms_norm , CPU, ALL_LAYOUT, phi::fusion::RmsNormAvxKernel , float , double ) {}
0 commit comments