@@ -29,18 +29,26 @@ namespace phi {
 namespace fusion {
 
 template <typename T, typename Context>
-void RmsNormKernel(const Context& dev_ctx,
-                   const DenseTensor& x,
-                   const paddle::optional<DenseTensor>& residual,
-                   const DenseTensor& norm_weight,
-                   const float epsilon,
-                   const int begin_norm_axis,
-                   DenseTensor* out,
-                   DenseTensor* residual_out) {
+void RmsNormAvxKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      const paddle::optional<DenseTensor>& bias,
+                      const paddle::optional<DenseTensor>& residual,
+                      const DenseTensor& norm_weight,
+                      const paddle::optional<DenseTensor>& norm_bias,
+                      const float epsilon,
+                      const int begin_norm_axis,
+                      const float quant_scale,
+                      const int quant_round_type,
+                      const float quant_max_bound,
+                      const float quant_min_bound,
+                      DenseTensor* out,
+                      DenseTensor* residual_out,
+                      DenseTensor* inv_var) {
+  if (quant_scale > 0.0f) {
+    PD_THROW("NOT supported quant int8.");
+  }
+
   const T* x_data = x.data<T>();
-  T* out_data = dev_ctx.template Alloc<T>(out);
-  const T* norm_weight_data = norm_weight.data<T>();
-  // x(batch_size,seq_len,hidden_size)
   int32_t rows = 1;
   int32_t cols = 1;
   for (int i = 0; i < begin_norm_axis; i++) {
@@ -53,10 +61,16 @@ void RmsNormKernel(const Context& dev_ctx,
   int size = cols;
   auto istride = cols;
   auto ostride = cols;
+  const T* norm_weight_data = norm_weight.data<T>();
+  const T* norm_bias_data = norm_bias ? norm_bias.get().data<T>() : nullptr;
   const T* residual_data = residual ? residual.get().data<T>() : nullptr;
+  const T* bias_data = bias ? bias.get().data<T>() : nullptr;
+  T* out_data = dev_ctx.template Alloc<T>(out);
   T* residual_out_data =
       residual ? dev_ctx.template Alloc<T>(residual_out) : nullptr;
 
+  __m512 vb = _mm512_setzero_ps();
+  const T* pb = bias_data;
 #ifdef PADDLE_WITH_MKLML
 #pragma omp parallel for
 #endif
@@ -77,6 +91,10 @@ void RmsNormKernel(const Context& dev_ctx,
       if (residual) {
         __m512 residual_vx = _mm512_loadu_ps(pr + col);
         vx = _mm512_add_ps(vx, residual_vx);
+        if (bias) {
+          __m512 vb = _mm512_loadu_ps(pb + col);
+          vx = _mm512_add_ps(vx, vb);
+        }
         _mm512_storeu_ps(pr_out + col, vx);
       }
       __m512 tmp = _mm512_mul_ps(vx, vx);
@@ -88,6 +106,10 @@ void RmsNormKernel(const Context& dev_ctx,
       if (residual) {
         __m512 residual_vx = _mm512_maskz_loadu_ps(mask, pr + col);
         vx = _mm512_mask_add_ps(vx, mask, vx, residual_vx);
+        if (bias) {
+          __m512 vb = _mm512_maskz_loadu_ps(mask, pb + col);
+          vx = _mm512_mask_add_ps(vx, mask, vx, vb);
+        }
         _mm512_mask_storeu_ps(pr_out + col, mask, vx);
       }
       __m512 tmp = _mm512_mul_ps(vx, vx);
@@ -105,9 +127,16 @@ void RmsNormKernel(const Context& dev_ctx,
       if (residual) {
         __m512 residual_vx = _mm512_loadu_ps(pr + col);
         vx = _mm512_add_ps(vx, residual_vx);
+        if (bias) {
+          __m512 vb = _mm512_loadu_ps(pb + col);
+          vx = _mm512_add_ps(vx, vb);
+        }
       }
       __m512 vw = _mm512_loadu_ps(norm_weight_data + col);
-      __m512 vy = vx * vvar * vw;
+      if (norm_bias_data) {
+        vb = _mm512_loadu_ps(norm_bias_data + col);
+      }
+      __m512 vy = vx * vvar * vw + vb;
       _mm512_storeu_ps(py + col, vy);
     }
     if (col < size) {
@@ -116,9 +145,16 @@ void RmsNormKernel(const Context& dev_ctx,
       if (residual) {
         __m512 residual_vx = _mm512_maskz_loadu_ps(mask, pr + col);
         vx = _mm512_mask_add_ps(vx, mask, vx, residual_vx);
+        if (bias) {
+          __m512 vb = _mm512_maskz_loadu_ps(mask, pb + col);
+          vx = _mm512_mask_add_ps(vx, mask, vx, vb);
+        }
       }
       __m512 vw = _mm512_maskz_loadu_ps(mask, norm_weight_data + col);
-      __m512 vy = vx * vvar * vw;
+      if (norm_bias_data) {
+        vb = _mm512_maskz_loadu_ps(mask, norm_bias_data + col);
+      }
+      __m512 vy = vx * vvar * vw + vb;
       _mm512_mask_storeu_ps(py + col, mask, vy);
     }
   }  // end for rows
@@ -127,4 +163,4 @@ void RmsNormKernel(const Context& dev_ctx,
 }  // namespace phi
 
 PD_REGISTER_KERNEL(
-    rms_norm_avx, CPU, ALL_LAYOUT, phi::fusion::RmsNormKernel, float, double) {}
+    rms_norm, CPU, ALL_LAYOUT, phi::fusion::RmsNormAvxKernel, float, double) {}
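
For reference, the math this AVX-512 kernel vectorizes can be written as a plain scalar loop over one row. The sketch below is illustrative only (`rms_norm_row_ref` and its signature are not part of the patch), and it assumes the usual `1/sqrt(mean(v^2) + epsilon)` normalizer that the elided `vvar` computation implies. Note that, as in the diff, `bias` is only added on the residual path, while `norm_bias` is added after scaling:

```cpp
#include <cmath>
#include <cstdint>

// Scalar reference for one row of length `cols` (illustrative, not the
// kernel itself): fuse the optional residual/bias adds with RMSNorm.
void rms_norm_row_ref(const float* x, const float* residual,
                      const float* bias, const float* w,
                      const float* norm_bias, float epsilon,
                      int32_t cols, float* out, float* residual_out) {
  // Pass 1: accumulate the sum of squares of the (optionally
  // residual- and bias-added) input.
  float sum_sq = 0.0f;
  for (int32_t i = 0; i < cols; ++i) {
    float v = x[i];
    if (residual) {
      v += residual[i];
      if (bias) v += bias[i];
      residual_out[i] = v;  // the fused add is written back
    }
    sum_sq += v * v;
  }
  // inv_rms plays the role of the kernel's vvar.
  float inv_rms = 1.0f / std::sqrt(sum_sq / cols + epsilon);
  // Pass 2: scale by inv_rms and the norm weight, then add norm_bias.
  for (int32_t i = 0; i < cols; ++i) {
    float v = residual ? residual_out[i] : x[i];
    out[i] = v * inv_rms * w[i] + (norm_bias ? norm_bias[i] : 0.0f);
  }
}
```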
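The masked loads and stores (`_mm512_maskz_loadu_ps`, `_mm512_mask_storeu_ps`) handle rows whose length is not a multiple of the 16-float vector width. The mask construction is not shown in this hunk; below is a minimal standalone sketch of that tail pattern with one plausible mask construction, using illustrative names (`scale_tail`, `scale`) that are not part of the kernel:

```cpp
#include <immintrin.h>

// Multiply `size` floats by `scale`: full 16-lane steps, then a k-masked
// tail so out-of-range lanes are never loaded or stored.
void scale_tail(const float* in, float scale, float* out, int size) {
  const __m512 vs = _mm512_set1_ps(scale);
  int col = 0;
  for (; col + 16 <= size; col += 16) {
    __m512 v = _mm512_loadu_ps(in + col);
    _mm512_storeu_ps(out + col, _mm512_mul_ps(v, vs));
  }
  if (col < size) {
    // Low (size - col) bits set; masked-off lanes load as zero and are
    // left untouched on the store.
    __mmask16 mask = (1u << (size - col)) - 1;
    __m512 v = _mm512_maskz_loadu_ps(mask, in + col);
    _mm512_mask_storeu_ps(out + col, mask, _mm512_mul_ps(v, vs));
  }
}
```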