Skip to content

Commit ecc01b8

Browse files
committed
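Code style cleanup and bug fix: remove the <stdio.h>, <string.h> and glog/logging.h includes, declare vy at its first use, and pass vx instead of the uninitialized vy as the pass-through source of the masked tail add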
1 parent 0ed9f34 commit ecc01b8

File tree

1 file changed

+2
-9
lines changed

1 file changed

+2
-9
lines changed

paddle/phi/kernels/fusion/cpu/rms_norm_avx_kernel.cc

Lines changed: 2 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -15,15 +15,10 @@
1515
#include <immintrin.h>
1616
#include <math.h>
1717
#include <omp.h>
18-
#include <stdio.h>
19-
#include <string.h>
20-
2118
#include "paddle/phi/backends/cpu/cpu_context.h"
2219
#include "paddle/phi/core/kernel_registry.h"
2320
#include "paddle/phi/core/tensor_utils.h"
2421

25-
#include "glog/logging.h"
26-
2722
namespace phi {
2823
namespace fusion {
2924

@@ -137,10 +132,9 @@ void RmsNormAvxKernel(const Context& dev_ctx,
137132
}
138133

139134
// vy = vx * vvar * vw + vb
140-
__m512 vy;
141135
vx = _mm512_mul_ps(vx, vvar);
142136
vx = _mm512_mul_ps(vx, vw);
143-
vy = _mm512_add_ps(vx, vb);
137+
__m512 vy = _mm512_add_ps(vx, vb);
144138
_mm512_storeu_ps(py + col, vy);
145139
}
146140
if (col < size) {
@@ -159,10 +153,9 @@ void RmsNormAvxKernel(const Context& dev_ctx,
159153
vb = _mm512_maskz_loadu_ps(mask, norm_bias_data + col);
160154
}
161155
// vy = vx * vvar * vw + vb (computed only for the lanes selected by mask)
162-
__m512 vy;
163156
vx = _mm512_mask_mul_ps(vx, mask, vx, vvar);
164157
vx = _mm512_mask_mul_ps(vx, mask, vx, vw);
165-
vy = _mm512_mask_add_ps(vy, mask, vx, vb);
158+
__m512 vy = _mm512_mask_add_ps(vx, mask, vx, vb);
166159
_mm512_mask_storeu_ps(py + col, mask, vy);
167160
}
168161
} // end for rows

0 commit comments

Comments
 (0)