Commit a2b9724

Commit message: update
1 parent 1445637 commit a2b9724

File tree

7 files changed: +316 −264 lines

paddle/phi/api/yaml/ops.yaml

Lines changed: 0 additions & 10 deletions
@@ -2348,16 +2348,6 @@
   intermediate : inv_var
   backward : rms_norm_grad
 
-- op : rms_norm_avx
-  args : (Tensor x, Tensor residual, Tensor norm_weight, float epsilon, int begin_norm_axis)
-  output : Tensor(out),Tensor(residual_out)
-  infer_meta :
-    func : RmsNormAvxInferMeta
-  kernel :
-    func : rms_norm_avx
-    data_type : x
-  optional : residual,residual_out
-
 - op : rmsprop_
   args : (Tensor param, Tensor mean_square, Tensor grad, Tensor moment, Tensor learning_rate, Tensor mean_grad, Tensor master_param, float epsilon = 1.0e-10f, float decay = 0.9f, float momentum = 0.0f, bool centered = false, bool multi_precision = false)
   output : Tensor(param_out), Tensor(moment_out), Tensor(mean_square_out), Tensor(mean_grad_out), Tensor(master_param_outs)
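For context, begin_norm_axis in the removed args list controls how the input is flattened: axes before it become independent rows, and axes from it onward are multiplied into the normalized width that norm_weight must match. A minimal stand-alone sketch of that split (an illustrative helper under assumed shapes, not code from this commit):

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Flatten a shape into (rows, cols) around begin_norm_axis, mirroring the
// rows/cols loops in the AVX kernel further down. Hypothetical helper.
std::pair<int64_t, int64_t> SplitAtNormAxis(const std::vector<int64_t>& dims,
                                            int begin_norm_axis) {
  int64_t rows = 1;
  int64_t cols = 1;
  for (int i = 0; i < begin_norm_axis; ++i) rows *= dims[i];
  for (size_t i = begin_norm_axis; i < dims.size(); ++i) cols *= dims[i];
  return {rows, cols};
}

// E.g. an x of shape [batch = 2, seq_len = 8, hidden = 4096] with
// begin_norm_axis = 2 gives 16 rows, each normalized over 4096 columns,
// so norm_weight must hold 4096 elements.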

paddle/phi/infermeta/multiary.cc

Lines changed: 0 additions & 36 deletions
@@ -3523,42 +3523,6 @@ void QuantizeLinearInferMeta(const MetaTensor& x,
   }
 }
 
-void RmsNormAvxInferMeta(const MetaTensor& x,
-                         const MetaTensor& residual,
-                         const MetaTensor& norm_weight,
-                         const float epsilon,
-                         const int begin_norm_axis,
-                         MetaTensor* out,
-                         MetaTensor* residual_out) {
-  std::vector<int64_t> x_dims_vec = common::vectorize(x.dims());
-  auto x_dims_size = x_dims_vec.size();
-
-  size_t normalized_dims = 1;
-  for (size_t i = begin_norm_axis; i < x_dims_size; ++i) {
-    normalized_dims *= x_dims_vec[i];
-  }
-  PADDLE_ENFORCE_EQ(normalized_dims,
-                    norm_weight.dims()[0],
-                    phi::errors::InvalidArgument(
-                        "The normalized size of Input(X) must equal to be"
-                        "the size of Weight, but received"
-                        "normalized size of Input(X) is [%d], received size"
-                        "of Weight is [%d]",
-                        normalized_dims,
-                        norm_weight.dims()[0]));
-
-  auto out_dims = common::make_ddim(x_dims_vec);
-  out->set_dims(out_dims);
-  out->set_dtype(x.dtype());
-  out->set_layout(x.layout());
-  out->share_lod(x);
-
-  residual_out->set_dims(out_dims);
-  residual_out->set_dtype(x.dtype());
-  residual_out->set_layout(x.layout());
-  residual_out->share_lod(x);
-}
-
 void RmsNormInferMeta(const MetaTensor& x,
                       const MetaTensor& bias,
                       const MetaTensor& residual,

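The deleted InferMeta reduces to one constraint plus metadata forwarding: the product of the normalized axes of x must equal the length of norm_weight, and both outputs copy x's shape, dtype, layout, and LoD. The same check restated without phi's enforcement macros (a hypothetical stand-alone sketch):

#include <cstdint>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <vector>

// Validation equivalent to the removed RmsNormAvxInferMeta, minus the
// MetaTensor plumbing. Hypothetical helper, not part of the codebase.
void CheckRmsNormShapes(const std::vector<int64_t>& x_dims,
                        int begin_norm_axis,
                        int64_t weight_len) {
  const int64_t normalized_dims =
      std::accumulate(x_dims.begin() + begin_norm_axis, x_dims.end(),
                      int64_t{1}, std::multiplies<int64_t>());
  if (normalized_dims != weight_len) {
    throw std::invalid_argument(
        "normalized size of x must equal the size of norm_weight");
  }
  // On success, both outputs (out, residual_out) simply reuse x's full shape.
}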
paddle/phi/infermeta/multiary.h

Lines changed: 0 additions & 8 deletions
@@ -951,14 +951,6 @@ void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
                                        MetaTensor* cache_kv_out,
                                        MetaTensor* beam_cache_offset_out);
 
-void RmsNormAvxInferMeta(const MetaTensor& x,
-                         const MetaTensor& residual,
-                         const MetaTensor& norm_weight,
-                         const float epsilon,
-                         const int begin_norm_axis,
-                         MetaTensor* out,
-                         MetaTensor* residual_out);
-
 void FullWithTensorInferMeta(const MetaTensor& shape,
                              DataType dtype,
                              MetaTensor* out);

paddle/phi/kernels/fusion/cpu/rms_norm_avx_kernel.cc

Lines changed: 51 additions & 16 deletions
@@ -1,5 +1,4 @@
-// Copyright (c) 2024 PaddlePaddle Authors And Intel Corporation.
-// All Rights Reserved.
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -29,18 +28,26 @@ namespace phi {
 namespace fusion {
 
 template <typename T, typename Context>
-void RmsNormKernel(const Context& dev_ctx,
-                   const DenseTensor& x,
-                   const paddle::optional<DenseTensor>& residual,
-                   const DenseTensor& norm_weight,
-                   const float epsilon,
-                   const int begin_norm_axis,
-                   DenseTensor* out,
-                   DenseTensor* residual_out) {
+void RmsNormAvxKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      const paddle::optional<DenseTensor>& bias,
+                      const paddle::optional<DenseTensor>& residual,
+                      const DenseTensor& norm_weight,
+                      const paddle::optional<DenseTensor>& norm_bias,
+                      const float epsilon,
+                      const int begin_norm_axis,
+                      const float quant_scale,
+                      const int quant_round_type,
+                      const float quant_max_bound,
+                      const float quant_min_bound,
+                      DenseTensor* out,
+                      DenseTensor* residual_out,
+                      DenseTensor* inv_var) {
+  if (quant_scale > 0.0f) {
+    PD_THROW("NOT supported quant int8. ");
+  }
+
   const T* x_data = x.data<T>();
-  T* out_data = dev_ctx.template Alloc<T>(out);
-  const T* norm_weight_data = norm_weight.data<T>();
-  // x(batch_size,seq_len,hidden_size)
   int32_t rows = 1;
   int32_t cols = 1;
   for (int i = 0; i < begin_norm_axis; i++) {
@@ -53,10 +60,16 @@ void RmsNormKernel(const Context& dev_ctx,
   int size = cols;
   auto istride = cols;
   auto ostride = cols;
+  const T* norm_weight_data = norm_weight.data<T>();
+  const T* norm_bias_data = norm_bias ? norm_bias.get().data<T>() : nullptr;
   const T* residual_data = residual ? residual.get().data<T>() : nullptr;
+  const T* bias_data = bias ? bias.get().data<T>() : nullptr;
+  T* out_data = dev_ctx.template Alloc<T>(out);
   T* residual_out_data =
       residual ? dev_ctx.template Alloc<T>(residual_out) : nullptr;
 
+  __m512 vb = _mm512_setzero_ps();
+  const T* pb = bias_data;
 #ifdef PADDLE_WITH_MKLML
 #pragma omp parallel for
 #endif
@@ -77,6 +90,10 @@ void RmsNormKernel(const Context& dev_ctx,
       if (residual) {
         __m512 residual_vx = _mm512_loadu_ps(pr + col);
         vx = _mm512_add_ps(vx, residual_vx);
+        if (bias) {
+          __m512 vb = _mm512_loadu_ps(pb + col);
+          vx = _mm512_add_ps(vx, vb);
+        }
         _mm512_storeu_ps(pr_out + col, vx);
       }
       __m512 tmp = _mm512_mul_ps(vx, vx);
@@ -88,6 +105,10 @@ void RmsNormKernel(const Context& dev_ctx,
      if (residual) {
        __m512 residual_vx = _mm512_maskz_loadu_ps(mask, pr + col);
        vx = _mm512_mask_add_ps(vx, mask, vx, residual_vx);
+       if (bias) {
+         __m512 vb = _mm512_maskz_loadu_ps(mask, pb + col);
+         vx = _mm512_mask_add_ps(vx, mask, vx, vb);
+       }
        _mm512_mask_storeu_ps(pr_out + col, mask, vx);
      }
      __m512 tmp = _mm512_mul_ps(vx, vx);
@@ -105,9 +126,16 @@ void RmsNormKernel(const Context& dev_ctx,
       if (residual) {
         __m512 residual_vx = _mm512_loadu_ps(pr + col);
         vx = _mm512_add_ps(vx, residual_vx);
+        if (bias) {
+          __m512 vb = _mm512_loadu_ps(pb + col);
+          vx = _mm512_add_ps(vx, vb);
+        }
       }
       __m512 vw = _mm512_loadu_ps(norm_weight_data + col);
-      __m512 vy = vx * vvar * vw;
+      if (norm_bias_data) {
+        vb = _mm512_loadu_ps(norm_bias_data + col);
+      }
+      __m512 vy = vx * vvar * vw + vb;
       _mm512_storeu_ps(py + col, vy);
     }
     if (col < size) {
@@ -116,9 +144,16 @@ void RmsNormKernel(const Context& dev_ctx,
      if (residual) {
        __m512 residual_vx = _mm512_maskz_loadu_ps(mask, pr + col);
        vx = _mm512_mask_add_ps(vx, mask, vx, residual_vx);
+       if (bias) {
+         __m512 vb = _mm512_maskz_loadu_ps(mask, pb + col);
+         vx = _mm512_mask_add_ps(vx, mask, vx, vb);
+       }
      }
      __m512 vw = _mm512_maskz_loadu_ps(mask, norm_weight_data + col);
-     __m512 vy = vx * vvar * vw;
+     if (norm_bias_data) {
+       vb = _mm512_maskz_loadu_ps(mask, norm_bias_data + col);
+     }
+     __m512 vy = vx * vvar * vw + vb;
      _mm512_mask_storeu_ps(py + col, mask, vy);
     }
   }  // end for rows
@@ -127,4 +162,4 @@ void RmsNormKernel(const Context& dev_ctx,
 }  // namespace phi
 
 PD_REGISTER_KERNEL(
-    rms_norm_avx, CPU, ALL_LAYOUT, phi::fusion::RmsNormKernel, float, double) {}
+    rms_norm, CPU, ALL_LAYOUT, phi::fusion::RmsNormAvxKernel, float, double) {}
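Stripped of the AVX-512 intrinsics and the masked handling of the final partial 16-float chunk, each row of the kernel computes plain RMSNorm over a fused x + residual (+ bias) input, adding norm_bias after the weight scale. A scalar reference with the same semantics, useful for checking the vector path (a sketch, not code from the commit; note that, as in the intrinsics, bias only participates when residual is present):

#include <cmath>
#include <cstddef>

// Scalar reference for one row of the fused kernel above:
//   v = x + residual (+ bias);  y = v * rsqrt(mean(v^2) + eps) * w (+ norm_bias)
// Null pointers stand in for the absent optional tensors.
void RmsNormRowRef(const float* x, const float* residual, const float* bias,
                   const float* w, const float* norm_bias, float eps,
                   size_t cols, float* y, float* residual_out) {
  float sum_sq = 0.0f;
  for (size_t i = 0; i < cols; ++i) {
    float v = x[i];
    if (residual) {
      v += residual[i];
      if (bias) v += bias[i];  // bias is fused only on the residual path
      residual_out[i] = v;     // the fused sum is written back, as in the kernel
    }
    sum_sq += v * v;
  }
  const float inv_rms = 1.0f / std::sqrt(sum_sq / cols + eps);
  for (size_t i = 0; i < cols; ++i) {
    float v = x[i];
    if (residual) {
      v += residual[i];
      if (bias) v += bias[i];
    }
    y[i] = v * inv_rms * w[i] + (norm_bias ? norm_bias[i] : 0.0f);
  }
}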
