Commit 8322b6b

[Inference cpu]fused_bias_residual_layernorm op support cpu (#63196)
* add fuse_layer_norm
* update
* update for windows
1 parent 71018f4 commit 8322b6b

7 files changed: 650 additions, 5 deletions

cmake/simd.cmake

Lines changed: 5 additions & 0 deletions

@@ -11,12 +11,17 @@ if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
   set(AVX_FLAG "-mavx")
   set(AVX2_FLAG "-mavx2")
   set(AVX512F_FLAG "-mavx512f")
+  set(Wno_Maybe_Uninitialized "-Wno-maybe-uninitialized")
+  set(FMA_FLAG "-mfma")
 elseif(MSVC)
   set(MMX_FLAG "/arch:MMX")
   set(SSE2_FLAG "/arch:SSE2")
   set(SSE3_FLAG "/arch:SSE3")
   set(AVX_FLAG "/arch:AVX")
   set(AVX2_FLAG "/arch:AVX2")
+  set(AVX512F_FLAG "/arch:AVX512")
+  set(Wno_Maybe_Uninitialized "/wd4701")
+  set(FMA_FLAG "/arch:AVX2")
 endif()

 set(CMAKE_REQUIRED_FLAGS_RETAINED ${CMAKE_REQUIRED_FLAGS})
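The three new variables give each compiler its own spelling of the same options: GCC and Clang take -mfma and -Wno-maybe-uninitialized, while MSVC has no separate FMA switch (FMA is enabled together with /arch:AVX2) and uses /wd4701 to silence its closest warning, C4701 (potentially uninitialized local variable used). The flags are consumed by paddle/phi/CMakeLists.txt further down. As a rough, hypothetical illustration of the kind of translation unit such a flag has to compile, here is a minimal AVX-512F probe; it is not the exact check source in simd.cmake.

// Hypothetical AVX-512F probe (not Paddle's actual check source): it only
// compiles when the AVX512F flag is in effect and only runs on hardware
// that implements the instructions.
#include <immintrin.h>

int main() {
  __m512 a = _mm512_set1_ps(1.0f);      // broadcast 1.0 to 16 lanes
  __m512 b = _mm512_add_ps(a, a);       // lane-wise add, 2.0 per lane
  float sum = _mm512_reduce_add_ps(b);  // horizontal sum, expected 32.0
  return sum == 32.0f ? 0 : 1;          // exit code 0 signals support
}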

paddle/fluid/pybind/pybind.cc

Lines changed: 5 additions & 0 deletions

@@ -406,6 +406,10 @@ bool SupportsInt8() {
 #endif
 }

+bool SupportsAvx512F() {
+  return phi::backends::cpu::MayIUse(phi::backends::cpu::cpu_isa_t::avx512f);
+}
+
 bool SupportsVNNI() {
 #ifndef PADDLE_WITH_DNNL
   return false;
@@ -2154,6 +2158,7 @@ All parameter, weight, gradient are variables in Paddle.
   m.def("supports_bfloat16", SupportsBfloat16);
   m.def("supports_bfloat16_fast_performance", SupportsBfloat16FastPerformance);
   m.def("supports_int8", SupportsInt8);
+  m.def("supports_avx512f", SupportsAvx512F);
   m.def("supports_vnni", SupportsVNNI);
   m.def("op_supported_infos", imperative::OpSupportedInfos);
   m.def("is_compiled_with_brpc", IsCompiledWithBrpc);

paddle/phi/CMakeLists.txt

Lines changed: 3 additions & 1 deletion

@@ -119,8 +119,10 @@ if(WITH_AVX
    AND AVX512F_FLAG
    AND WITH_MKL)
   set_source_files_properties(
+    kernels/fusion/cpu/fused_layer_norm_avx_kernel.cc
     kernels/fusion/cpu/self_dp_attention_kernel.cc
-    PROPERTIES COMPILE_FLAGS "-Wno-maybe-uninitialized -mfma ${AVX512F_FLAG}")
+    PROPERTIES COMPILE_FLAGS
+               "${Wno_Maybe_Uninitialized} ${FMA_FLAG} ${AVX512F_FLAG}")
 endif()

 if(WITH_GPU)

paddle/phi/kernels/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -262,6 +262,7 @@ if(NOT
     AND AVX512F_FOUND
     AND AVX512F_FLAG
     AND WITH_MKL))
+  list(REMOVE_ITEM kernel_cc "fusion/cpu/fused_layer_norm_avx_kernel.cc")
   list(REMOVE_ITEM kernel_cc "fusion/cpu/self_dp_attention_kernel.cc")
 endif()


paddle/phi/kernels/fusion/cpu/fused_layer_norm_avx_kernel.cc

Lines changed: 245 additions & 0 deletions

@@ -0,0 +1,245 @@
+// Copyright (c) 2024 PaddlePaddle Authors All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <immintrin.h>
+#include <math.h>
+#include <omp.h>
+#include <stdio.h>
+#include <string.h>
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/tensor_utils.h"
+
+namespace phi {
+namespace fusion {
+
+template <typename T>
+void ResidualBiasSumFunc(const T* x_data,
+                         const T* residual_data,
+                         const T* bias_data,
+                         const float residual_alpha,
+                         const int rows,
+                         const int cols,
+                         const int iStride,
+                         const int oStride,
+                         T* out_data) {
+  __m512 vresidual_alpha = _mm512_set1_ps(residual_alpha);
+  const T* pb = bias_data;
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for
+#endif
+  for (int r = 0; r < rows; ++r) {
+    const T* px = x_data + r * iStride;
+    const T* pr = residual_data ? residual_data + r * iStride : nullptr;
+    T* py = out_data + r * oStride;
+    for (int col = 0; col < cols; col += 16) {
+      int remain = cols - col;
+      __mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1);
+
+      // residual*alpha + bias + x
+      __m512 vx = _mm512_maskz_loadu_ps(mask, px + col);
+      if (residual_data) {
+        __m512 residual_vx = _mm512_maskz_loadu_ps(mask, pr + col);
+        residual_vx = _mm512_mul_ps(residual_vx, vresidual_alpha);
+        vx = _mm512_mask_add_ps(vx, mask, vx, residual_vx);
+      }
+      if (bias_data) {
+        __m512 vb = _mm512_maskz_loadu_ps(mask, pb + col);
+        vx = _mm512_mask_add_ps(vx, mask, vx, vb);
+      }
+      _mm512_mask_storeu_ps(py + col, mask, vx);
+    }
+  }
+}
+
+template <typename T>
+void LayerNormFunc(const T* x_data,
+                   const T* residual_data,
+                   const T* bias_data,
+                   const T* norm_weight_data,
+                   const T* norm_bias_data,
+                   const float epsilon,
+                   const float residual_alpha,
+                   const int rows,
+                   const int cols,
+                   const int iStride,
+                   const int oStride,
+                   T* out_data,
+                   T* residual_out_data,
+                   T* mean_out,
+                   T* var_out) {
+  auto size = cols;
+  __m512 vresidual_alpha = _mm512_set1_ps(residual_alpha);
+  __m512 vgamma = _mm512_set1_ps(1);
+  __m512 vbeta = _mm512_set1_ps(0);
+  const T* pb = bias_data;
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for
+#endif
+  for (int r = 0; r < rows; ++r) {
+    const T* px = x_data + r * iStride;
+    const T* pr = residual_data ? residual_data + r * iStride : nullptr;
+    T* pr_out = residual_out_data ? residual_out_data + r * oStride : nullptr;
+    T* py = out_data + r * oStride;
+
+    T sum = 0;
+    T squareSum = 0;
+
+    __m512 vsum = _mm512_set1_ps(0);
+    __m512 vsqare = _mm512_set1_ps(0);
+    for (int col = 0; col < size; col += 16) {
+      int remain = size - col;
+      __mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1);
+
+      // SUM(x)
+      __m512 vx = _mm512_maskz_loadu_ps(mask, px + col);
+      if (residual_data) {
+        __m512 residual_vx = _mm512_maskz_loadu_ps(mask, pr + col);
+        residual_vx = _mm512_mul_ps(residual_vx, vresidual_alpha);
+        vx = _mm512_mask_add_ps(vx, mask, vx, residual_vx);
+        if (bias_data) {
+          __m512 vb = _mm512_maskz_loadu_ps(mask, pb + col);
+          vx = _mm512_mask_add_ps(vx, mask, vx, vb);
+        }
+        _mm512_mask_storeu_ps(pr_out + col, mask, vx);
+      }
+      vsum = _mm512_add_ps(vsum, vx);
+
+      // SUM(x*x)
+      __m512 tmp = _mm512_mul_ps(vx, vx);
+      vsqare = _mm512_add_ps(vsqare, tmp);
+    }
+
+    sum = _mm512_reduce_add_ps(vsum);
+    squareSum = _mm512_reduce_add_ps(vsqare);
+
+    // Mean
+    T mean = sum / size;
+    mean_out[r] = mean;
+    __m512 vmean = _mm512_set1_ps(mean);
+
+    // Variance
+    T var = 1 / sqrt(squareSum / size - mean * mean + epsilon);
+    var_out[r] = var;
+    __m512 vvar = _mm512_set1_ps(var);
+
+    for (int col = 0; col < size; col += 16) {
+      int remain = size - col;
+      __mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1);
+
+      __m512 vx = _mm512_maskz_loadu_ps(mask, px + col);
+      if (residual_data) {
+        __m512 residual_vx = _mm512_maskz_loadu_ps(mask, pr + col);
+        residual_vx = _mm512_mul_ps(residual_vx, vresidual_alpha);
+        vx = _mm512_mask_add_ps(vx, mask, vx, residual_vx);
+        if (bias_data) {
+          __m512 vb = _mm512_maskz_loadu_ps(mask, pb + col);
+          vx = _mm512_mask_add_ps(vx, mask, vx, vb);
+        }
+      }
+      if (norm_weight_data) {
+        vgamma = _mm512_maskz_loadu_ps(mask, norm_weight_data + col);
+      }
+      if (norm_bias_data) {
+        vbeta = _mm512_maskz_loadu_ps(mask, norm_bias_data + col);
+      }
+      // (vx - vmean) * vgamma * vvar + vbeta
+      __m512 vy;
+      vx = _mm512_mask_sub_ps(vx, mask, vx, vmean);
+      vx = _mm512_mask_mul_ps(vx, mask, vx, vgamma);
+      vx = _mm512_mask_mul_ps(vx, mask, vx, vvar);
+      vy = _mm512_mask_add_ps(vy, mask, vx, vbeta);
+      _mm512_mask_storeu_ps(py + col, mask, vy);
+    }
+  }
+}
+
+template <typename T, typename Context>
+void FusedLayerNormAvxKernel(const Context& dev_ctx,
+                             const DenseTensor& x,
+                             const paddle::optional<DenseTensor>& bias,
+                             const paddle::optional<DenseTensor>& residual,
+                             const paddle::optional<DenseTensor>& norm_weight,
+                             const paddle::optional<DenseTensor>& norm_bias,
+                             const float epsilon,
+                             const float residual_alpha,
+                             const int begin_norm_axis,
+                             const float quant_scale,
+                             const int quant_round_type,
+                             const float quant_max_bound,
+                             const float quant_min_bound,
+                             DenseTensor* out,
+                             DenseTensor* residual_out,
+                             DenseTensor* mean,
+                             DenseTensor* variance) {
+  if (quant_scale > 0.0f) {
+    PD_THROW("NOT supported quant int8. ");
+  }
+  const auto x_dims = x.dims();
+  auto matrix_dim = common::flatten_to_2d(x_dims, begin_norm_axis);
+  T* out_data = dev_ctx.template Alloc<T>(out);
+  T* mean_out = dev_ctx.template Alloc<T>(mean);
+  T* var_out = dev_ctx.template Alloc<T>(variance);
+
+  const T* x_data = x.data<T>();
+  const T* bias_data = bias ? bias.get().data<T>() : nullptr;
+  const T* residual_data = residual ? residual.get().data<T>() : nullptr;
+  const T* norm_weight_data =
+      norm_weight ? norm_weight.get().data<T>() : nullptr;
+  const T* norm_bias_data = norm_bias ? norm_bias.get().data<T>() : nullptr;
+  T* residual_out_data =
+      residual ? dev_ctx.template Alloc<T>(residual_out) : nullptr;
+
+  int32_t rows = static_cast<int32_t>(matrix_dim[0]);
+  int32_t cols = static_cast<int32_t>(matrix_dim[1]);
+
+  auto iStride = cols;
+  auto oStride = cols;
+  if (!norm_weight && !norm_bias_data) {
+    ResidualBiasSumFunc(x_data,
+                        residual_data,
+                        bias_data,
+                        residual_alpha,
+                        rows,
+                        cols,
+                        iStride,
+                        oStride,
+                        out_data);
+  } else {
+    LayerNormFunc(x_data,
+                  residual_data,
+                  bias_data,
+                  norm_weight_data,
+                  norm_bias_data,
+                  epsilon,
+                  residual_alpha,
+                  rows,
+                  cols,
+                  iStride,
+                  oStride,
+                  out_data,
+                  residual_out_data,
+                  mean_out,
+                  var_out);
+  }
+}
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(fused_bias_residual_layernorm,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::fusion::FusedLayerNormAvxKernel,
+                   float) {}
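The kernel walks each row in 16-float chunks; the tail is handled with a partial mask, (1 << remain) - 1, instead of a scalar remainder loop, and mean and variance come from _mm512_reduce_add_ps reductions before a second pass applies the scale and shift. The following self-contained sketch isolates that masked-tail reduction; it is illustrative only, and RowMeanInvStd is a made-up name, not part of Paddle.

// Minimal sketch of the masked-tail mean/variance reduction used above.
// Assumes AVX-512F (compile with e.g. -mavx512f); not Paddle code.
#include <immintrin.h>
#include <cmath>
#include <cstdio>
#include <vector>

void RowMeanInvStd(const float* row, int cols, float epsilon,
                   float* mean, float* inv_std) {
  __m512 vsum = _mm512_setzero_ps();
  __m512 vsq = _mm512_setzero_ps();
  for (int col = 0; col < cols; col += 16) {
    int remain = cols - col;
    // Full mask for complete chunks, partial mask for the tail; masked-off
    // lanes of a maskz load read as 0 and so do not perturb the sums.
    __mmask16 mask = (remain >= 16) ? 0xffff : ((1 << remain) - 1);
    __m512 vx = _mm512_maskz_loadu_ps(mask, row + col);
    vsum = _mm512_add_ps(vsum, vx);
    vsq = _mm512_add_ps(vsq, _mm512_mul_ps(vx, vx));
  }
  *mean = _mm512_reduce_add_ps(vsum) / cols;
  float ex2 = _mm512_reduce_add_ps(vsq) / cols;  // E[x^2]
  *inv_std = 1.0f / std::sqrt(ex2 - (*mean) * (*mean) + epsilon);
}

int main() {
  std::vector<float> row(37, 2.0f);  // 37 floats: exercises the 5-lane tail mask
  float mean = 0.f, inv_std = 0.f;
  RowMeanInvStd(row.data(), static_cast<int>(row.size()), 1e-5f, &mean, &inv_std);
  std::printf("mean=%f inv_std=%f\n", mean, inv_std);
  return 0;
}

For the 37-element row of 2.0 this prints mean=2 and an inv_std close to 316, that is, 1/sqrt(1e-5).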

paddle/phi/kernels/fusion/cpu/self_dp_attention_kernel.cc

Lines changed: 8 additions & 4 deletions

@@ -257,7 +257,9 @@ void softmax_sum_max(float* AB,
       __mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1);

       __m512 vx = _mm512_maskz_loadu_ps(mask, buf + off);
-      vx = vexp(vx * vrefac - vmax);
+      vx = _mm512_mask_mul_ps(vx, mask, vx, vrefac);
+      vx = _mm512_mask_sub_ps(vx, mask, vx, vmax);
+      vx = vexp(vx);

       _mm512_mask_storeu_ps(buf + off, mask, vx);

@@ -275,8 +277,7 @@
       __mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1);

       __m512 vx = _mm512_maskz_loadu_ps(mask, buf + off);
-      vx = vx * vrsum;
-
+      vx = _mm512_mask_mul_ps(vx, mask, vx, vrsum);
       _mm512_mask_storeu_ps(buf + off, mask, vx);
     }
   }
@@ -301,7 +302,10 @@ void update_out_blk(float* output,
       __mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1);
       __m512 vout = _mm512_maskz_loadu_ps(mask, outbuf + off);
       __m512 vabc = _mm512_maskz_loadu_ps(mask, buf + off);
-      __m512 vupt = vout * merr * vfac + vabc;
+      vout = _mm512_mask_mul_ps(vout, mask, vout, merr);
+      vout = _mm512_mask_mul_ps(vout, mask, vout, vfac);
+      __m512 vupt;
+      vupt = _mm512_mask_add_ps(vupt, mask, vout, vabc);
       _mm512_mask_storeu_ps(outbuf + off, mask, vupt);
     }
     pre_sum[i] = sum[i];
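The self_dp_attention changes read as the "update for windows" part of the commit: GCC and Clang define __m512 as a vector type with built-in arithmetic operators, so expressions such as vx * vrefac - vmax compile there, but MSVC does not accept operators on __m512, so the math is rewritten with explicit intrinsics (here the masked _mm512_mask_* forms, which also leave lanes outside the tail mask at their loaded value). A small, assumed illustration of the difference, not taken from the file:

// Scale-then-shift on 16 floats: the operator form is a GCC/Clang vector
// extension; the intrinsic form compiles with MSVC as well.
#include <immintrin.h>

__m512 scale_shift(__m512 vx, __m512 vrefac, __m512 vmax) {
#if defined(__GNUC__) && !defined(_MSC_VER)
  return vx * vrefac - vmax;  // accepted by GCC/Clang only
#else
  return _mm512_sub_ps(_mm512_mul_ps(vx, vrefac), vmax);  // portable
#endif
}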
