1515#include < immintrin.h>
1616#include < math.h>
1717#include < omp.h>
18- #include < stdio.h>
19- #include < string.h>
20-
2118#include " paddle/phi/backends/cpu/cpu_context.h"
2219#include " paddle/phi/core/kernel_registry.h"
2320#include " paddle/phi/core/tensor_utils.h"
2421
25- #include " glog/logging.h"
26-
2722namespace phi {
2823namespace fusion {
2924
@@ -137,10 +132,9 @@ void RmsNormAvxKernel(const Context& dev_ctx,
137132 }
138133
139134 // vy = vx * vvar * vw + vb
140- __m512 vy;
141135 vx = _mm512_mul_ps (vx, vvar);
142136 vx = _mm512_mul_ps (vx, vw);
143- vy = _mm512_add_ps (vx, vb);
137+ __m512 vy = _mm512_add_ps (vx, vb);
144138 _mm512_storeu_ps (py + col, vy);
145139 }
146140 if (col < size) {
@@ -159,10 +153,9 @@ void RmsNormAvxKernel(const Context& dev_ctx,
159153 vb = _mm512_maskz_loadu_ps (mask, norm_bias_data + col);
160154 }
161155 // vx * vvar * vw + vb
162- __m512 vy;
163156 vx = _mm512_mask_mul_ps (vx, mask, vx, vvar);
164157 vx = _mm512_mask_mul_ps (vx, mask, vx, vw);
165- vy = _mm512_mask_add_ps (vy , mask, vx, vb);
158+ __m512 vy = _mm512_mask_add_ps (vx , mask, vx, vb);
166159 _mm512_mask_storeu_ps (py + col, mask, vy);
167160 }
168161 } // end for rows
0 commit comments