Skip to content

Commit 1445637

Browse files
committed
update
1 parent af733bc commit 1445637

File tree

8 files changed

+62
-49
lines changed

8 files changed

+62
-49
lines changed

paddle/phi/CMakeLists.txt

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,7 @@ if(WITH_AVX
120120
AND WITH_MKL)
121121
set_source_files_properties(
122122
kernels/fusion/cpu/self_dp_attention_kernel.cc
123-
PROPERTIES COMPILE_FLAGS "-Wno-maybe-uninitialized -mfma ${AVX512F_FLAG}")
124-
set_source_files_properties(
125-
kernels/fusion/cpu/rms_norm_xft_kernel.cc
123+
kernels/fusion/cpu/rms_norm_avx_kernel.cc
126124
PROPERTIES COMPILE_FLAGS "-Wno-maybe-uninitialized -mfma ${AVX512F_FLAG}")
127125
endif()
128126

paddle/phi/api/yaml/ops.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2348,13 +2348,13 @@
23482348
intermediate : inv_var
23492349
backward : rms_norm_grad
23502350

2351-
- op : rms_norm_xft
2351+
- op : rms_norm_avx
23522352
args : (Tensor x, Tensor residual, Tensor norm_weight, float epsilon, int begin_norm_axis)
23532353
output : Tensor(out),Tensor(residual_out)
23542354
infer_meta :
2355-
func : RmsNormXftInferMeta
2355+
func : RmsNormAvxInferMeta
23562356
kernel :
2357-
func : rms_norm_xft
2357+
func : rms_norm_avx
23582358
data_type : x
23592359
optional : residual,residual_out
23602360

paddle/phi/infermeta/multiary.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3523,7 +3523,7 @@ void QuantizeLinearInferMeta(const MetaTensor& x,
35233523
}
35243524
}
35253525

3526-
void RmsNormXftInferMeta(const MetaTensor& x,
3526+
void RmsNormAvxInferMeta(const MetaTensor& x,
35273527
const MetaTensor& residual,
35283528
const MetaTensor& norm_weight,
35293529
const float epsilon,

paddle/phi/infermeta/multiary.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -951,7 +951,7 @@ void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
951951
MetaTensor* cache_kv_out,
952952
MetaTensor* beam_cache_offset_out);
953953

954-
void RmsNormXftInferMeta(const MetaTensor& x,
954+
void RmsNormAvxInferMeta(const MetaTensor& x,
955955
const MetaTensor& residual,
956956
const MetaTensor& norm_weight,
957957
const float epsilon,

paddle/phi/kernels/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ if(NOT
263263
AND AVX512F_FLAG
264264
AND WITH_MKL))
265265
list(REMOVE_ITEM kernel_cc "fusion/cpu/self_dp_attention_kernel.cc")
266-
list(REMOVE_ITEM kernel_cc "fusion/cpu/rms_norm_xft_kernel.cc")
266+
list(REMOVE_ITEM kernel_cc "fusion/cpu/rms_norm_avx_kernel.cc")
267267
endif()
268268

269269
file(

paddle/phi/kernels/fusion/cpu/rms_norm_xft_kernel.cc renamed to paddle/phi/kernels/fusion/cpu/rms_norm_avx_kernel.cc

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
1-
/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2-
Licensed under the Apache License, Version 2.0 (the "License");
3-
you may not use this file except in compliance with the License.
4-
You may obtain a copy of the License at
5-
http://www.apache.org/licenses/LICENSE-2.0
6-
Unless required by applicable law or agreed to in writing, software
7-
distributed under the License is distributed on an "AS IS" BASIS,
8-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9-
See the License for the specific language governing permissions and
10-
limitations under the License. */
1+
// Copyright (c) 2024 PaddlePaddle Authors And Intel Corporation.
2+
// All Rights Reserved.
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
1116
#include <immintrin.h>
1217
#include <math.h>
1318
#include <omp.h>
@@ -24,15 +29,15 @@ namespace phi {
2429
namespace fusion {
2530

2631
template <typename T, typename Context>
27-
void RmsNormXftKernel(const Context& dev_ctx,
28-
const DenseTensor& x,
29-
const paddle::optional<DenseTensor>& residual,
30-
const DenseTensor& norm_weight,
31-
const float epsilon,
32-
const int begin_norm_axis,
33-
DenseTensor* out,
34-
DenseTensor* residual_out) {
35-
const float* x_data = x.data<float>();
32+
void RmsNormKernel(const Context& dev_ctx,
33+
const DenseTensor& x,
34+
const paddle::optional<DenseTensor>& residual,
35+
const DenseTensor& norm_weight,
36+
const float epsilon,
37+
const int begin_norm_axis,
38+
DenseTensor* out,
39+
DenseTensor* residual_out) {
40+
const T* x_data = x.data<T>();
3641
T* out_data = dev_ctx.template Alloc<T>(out);
3742
const T* norm_weight_data = norm_weight.data<T>();
3843
// x(batch_size,seq_len,hidden_size)
@@ -61,7 +66,7 @@ void RmsNormXftKernel(const Context& dev_ctx,
6166
T* pr_out = residual ? residual_out_data + r * ostride : nullptr;
6267
T* py = out_data + r * ostride;
6368

64-
float squareSum = 0;
69+
T squareSum = 0;
6570

6671
__m512 vsqare = _mm512_set1_ps(0);
6772

@@ -92,7 +97,7 @@ void RmsNormXftKernel(const Context& dev_ctx,
9297
squareSum = _mm512_reduce_add_ps(vsqare);
9398

9499
// Variance
95-
float var = 1 / sqrt(squareSum / size + epsilon);
100+
T var = 1 / sqrt(squareSum / size + epsilon);
96101
__m512 vvar = _mm512_set1_ps(var);
97102

98103
for (col = 0; col + 15 < size; col += 16) {
@@ -122,4 +127,4 @@ void RmsNormXftKernel(const Context& dev_ctx,
122127
} // namespace phi
123128

124129
PD_REGISTER_KERNEL(
125-
rms_norm_xft, CPU, ALL_LAYOUT, phi::fusion::RmsNormXftKernel, float) {}
130+
rms_norm_avx, CPU, ALL_LAYOUT, phi::fusion::RmsNormKernel, float, double) {}

test/mkldnn/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,15 @@ if(WITH_MKLDNN AND NOT WIN32)
1111
list(APPEND TEST_OPS "test_onnx_format_quantization_mobilenetv1")
1212
endif()
1313

14+
if(NOT
15+
(
16+
(WITH_AVX
17+
AND AVX512F_FOUND
18+
AND AVX512F_FLAG
19+
AND WITH_MKL)))
20+
list(REMOVE_ITEM TEST_OPS "test_rms_norm_avx_op")
21+
endif()
22+
1423
foreach(TEST_OP ${TEST_OPS})
1524
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
1625
endforeach()
Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,48 +24,44 @@
2424
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
2525

2626

27-
def rms_norm_xft(
27+
def rms_norm_avx(
2828
x,
2929
norm_weight,
3030
begin_norm_axis,
3131
residual=None,
3232
epsilon=1e-6,
3333
):
3434
if paddle.in_dynamic_mode():
35-
helper = LayerHelper('rms_norm_xft', **locals())
35+
helper = LayerHelper('rms_norm_avx', **locals())
3636
attrs = (
3737
epsilon,
3838
begin_norm_axis,
3939
)
40-
output = _C_ops.rms_norm_xft(x, residual, norm_weight, *attrs)
41-
return output
40+
output, residual = _C_ops.rms_norm_avx(x, residual, norm_weight, *attrs)
41+
return output, residual
4242
else:
43-
helper = LayerHelper("rms_norm_xft", **locals())
44-
inputs = {
45-
"x": [x],
46-
"norm_weight": [norm_weight],
47-
}
43+
helper = LayerHelper("rms_norm_avx", **locals())
4844

4945
output = helper.create_variable_for_type_inference(x.dtype)
46+
residual_out = helper.create_variable_for_type_inference(x.dtype)
5047

5148
inputs = {'x': x, 'residual': residual, 'norm_weight': norm_weight}
52-
outputs = {'out': output}
49+
outputs = {'out': output, 'residual_out': residual_out}
5350

5451
helper.append_op(
55-
type="rms_norm_xft",
52+
type="rms_norm_avx",
5653
inputs=inputs,
5754
attrs={"epsilon": epsilon, "begin_norm_axis": begin_norm_axis},
5855
outputs=outputs,
5956
)
60-
return output
57+
return output, residual
6158

6259

6360
class RmsNormXFTTestCase(unittest.TestCase):
6461
def config(self):
6562
self.dtype = 'float32'
6663
self.rtol = 1e-5
6764
self.atol = 1e-5
68-
self.bias = True
6965
self.bs = 1
7066
self.seq_len = 8
7167
self.hidden_size = 4096
@@ -90,11 +86,11 @@ def raw_rms_norm(self, x, norm_weight, residual=None, epsilon=1e-6):
9086
input_x = paddle.rsqrt(variance + epsilon) * input_x
9187
return input_x * norm_weight
9288

93-
def test_rms_norm_xft(self):
89+
def test_rms_norm_avx(self):
9490
act_out = self.raw_rms_norm(
9591
self.x, self.weight, self.residual, self.epsilon
9692
)
97-
xft_out = rms_norm_xft(
93+
xft_out = rms_norm_avx(
9894
x=self.x,
9995
residual=self.residual,
10096
norm_weight=self.weight,
@@ -125,7 +121,7 @@ def setUp(self):
125121
)
126122
self.residual = None
127123

128-
def test_rms_norm_xft(self):
124+
def test_rms_norm_avx(self):
129125
paddle.enable_static()
130126
exe = base.Executor(base.CPUPlace())
131127
main_program = base.Program()
@@ -139,22 +135,27 @@ def test_rms_norm_xft(self):
139135
norm_weight = paddle.static.data(
140136
shape=[self.hidden_size], dtype='float32', name='norm_weight'
141137
)
142-
xft_out = rms_norm_xft(
138+
xft_out, residual_out = rms_norm_avx(
143139
x=x,
144140
residual=self.residual,
145141
norm_weight=norm_weight,
146142
epsilon=self.epsilon,
147143
begin_norm_axis=self.begin_norm_axis,
148144
)
149145
exe.run(startup_program)
146+
fetch_list = (
147+
[xft_out.name]
148+
if self.residual is None
149+
else [xft_out.name, residual_out.name]
150+
)
150151
xft_res = exe.run(
151152
main_program,
152153
feed={
153154
'x': self.x.numpy(),
154155
'norm_weight': self.weight.numpy(),
155156
'residual': self.residual,
156157
},
157-
fetch_list=[xft_out.name],
158+
fetch_list=fetch_list,
158159
)
159160
paddle.disable_static()
160161
act_out = self.raw_rms_norm(

0 commit comments

Comments
 (0)