
Commit 811ba4f

Commit message: update
1 parent 1445637 commit 811ba4f

File tree: 7 files changed (+315, −262 lines)


paddle/phi/api/yaml/ops.yaml

Lines changed: 0 additions & 10 deletions
@@ -2348,16 +2348,6 @@
   intermediate : inv_var
   backward : rms_norm_grad
 
-- op : rms_norm_avx
-  args : (Tensor x, Tensor residual, Tensor norm_weight, float epsilon, int begin_norm_axis)
-  output : Tensor(out),Tensor(residual_out)
-  infer_meta :
-    func : RmsNormAvxInferMeta
-  kernel :
-    func : rms_norm_avx
-    data_type : x
-  optional : residual,residual_out
-
 - op : rmsprop_
   args : (Tensor param, Tensor mean_square, Tensor grad, Tensor moment, Tensor learning_rate, Tensor mean_grad, Tensor master_param, float epsilon = 1.0e-10f, float decay = 0.9f, float momentum = 0.0f, bool centered = false, bool multi_precision = false)
   output : Tensor(param_out), Tensor(moment_out), Tensor(mean_square_out), Tensor(mean_grad_out), Tensor(master_param_outs)
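
(This removes rms_norm_avx as a standalone operator. As the kernel-registration change at the bottom of this commit shows, the AVX implementation is now registered under the existing rms_norm op, which turns this yaml entry and the RmsNormAvxInferMeta definition and declaration deleted in the next two files into dead code.)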

paddle/phi/infermeta/multiary.cc

Lines changed: 0 additions & 36 deletions
@@ -3523,42 +3523,6 @@ void QuantizeLinearInferMeta(const MetaTensor& x,
   }
 }
 
-void RmsNormAvxInferMeta(const MetaTensor& x,
-                         const MetaTensor& residual,
-                         const MetaTensor& norm_weight,
-                         const float epsilon,
-                         const int begin_norm_axis,
-                         MetaTensor* out,
-                         MetaTensor* residual_out) {
-  std::vector<int64_t> x_dims_vec = common::vectorize(x.dims());
-  auto x_dims_size = x_dims_vec.size();
-
-  size_t normalized_dims = 1;
-  for (size_t i = begin_norm_axis; i < x_dims_size; ++i) {
-    normalized_dims *= x_dims_vec[i];
-  }
-  PADDLE_ENFORCE_EQ(normalized_dims,
-                    norm_weight.dims()[0],
-                    phi::errors::InvalidArgument(
-                        "The normalized size of Input(X) must equal to be"
-                        "the size of Weight, but received"
-                        "normalized size of Input(X) is [%d], received size"
-                        "of Weight is [%d]",
-                        normalized_dims,
-                        norm_weight.dims()[0]));
-
-  auto out_dims = common::make_ddim(x_dims_vec);
-  out->set_dims(out_dims);
-  out->set_dtype(x.dtype());
-  out->set_layout(x.layout());
-  out->share_lod(x);
-
-  residual_out->set_dims(out_dims);
-  residual_out->set_dtype(x.dtype());
-  residual_out->set_layout(x.layout());
-  residual_out->share_lod(x);
-}
-
 void RmsNormInferMeta(const MetaTensor& x,
                       const MetaTensor& bias,
                       const MetaTensor& residual,

paddle/phi/infermeta/multiary.h

Lines changed: 0 additions & 8 deletions
@@ -951,14 +951,6 @@ void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
                                        MetaTensor* cache_kv_out,
                                        MetaTensor* beam_cache_offset_out);
 
-void RmsNormAvxInferMeta(const MetaTensor& x,
-                         const MetaTensor& residual,
-                         const MetaTensor& norm_weight,
-                         const float epsilon,
-                         const int begin_norm_axis,
-                         MetaTensor* out,
-                         MetaTensor* residual_out);
-
 void FullWithTensorInferMeta(const MetaTensor& shape,
                              DataType dtype,
                              MetaTensor* out);

paddle/phi/kernels/fusion/cpu/rms_norm_avx_kernel.cc

Lines changed: 50 additions & 14 deletions
@@ -29,18 +29,26 @@ namespace phi {
 namespace fusion {
 
 template <typename T, typename Context>
-void RmsNormKernel(const Context& dev_ctx,
-                   const DenseTensor& x,
-                   const paddle::optional<DenseTensor>& residual,
-                   const DenseTensor& norm_weight,
-                   const float epsilon,
-                   const int begin_norm_axis,
-                   DenseTensor* out,
-                   DenseTensor* residual_out) {
+void RmsNormAvxKernel(const Context& dev_ctx,
+                      const DenseTensor& x,
+                      const paddle::optional<DenseTensor>& bias,
+                      const paddle::optional<DenseTensor>& residual,
+                      const DenseTensor& norm_weight,
+                      const paddle::optional<DenseTensor>& norm_bias,
+                      const float epsilon,
+                      const int begin_norm_axis,
+                      const float quant_scale,
+                      const int quant_round_type,
+                      const float quant_max_bound,
+                      const float quant_min_bound,
+                      DenseTensor* out,
+                      DenseTensor* residual_out,
+                      DenseTensor* inv_var) {
+  if (quant_scale > 0.0f) {
+    PD_THROW("NOT supported quant int8. ");
+  }
+
   const T* x_data = x.data<T>();
-  T* out_data = dev_ctx.template Alloc<T>(out);
-  const T* norm_weight_data = norm_weight.data<T>();
-  // x(batch_size,seq_len,hidden_size)
   int32_t rows = 1;
   int32_t cols = 1;
   for (int i = 0; i < begin_norm_axis; i++) {
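
For orientation, this is the per-row math the vectorized code below implements. A minimal scalar sketch, assuming float data; the function and pointer names (px, pr, pb, pw, pnb, py) are illustrative and not from this commit:

#include <cmath>
#include <cstddef>

// Scalar reference for one row: optional residual/bias add, reciprocal-RMS
// normalization, then norm_weight scale and optional norm_bias shift.
void rms_norm_row_ref(const float* px, const float* pr, const float* pb,
                      const float* pw, const float* pnb, float* py,
                      std::size_t size, float epsilon) {
  float sum_sq = 0.0f;
  for (std::size_t i = 0; i < size; ++i) {
    float v = px[i];
    if (pr != nullptr) {
      v += pr[i];                     // residual add
      if (pb != nullptr) v += pb[i];  // bias rides the residual path
    }
    sum_sq += v * v;
  }
  const float inv_rms =
      1.0f / std::sqrt(sum_sq / static_cast<float>(size) + epsilon);
  for (std::size_t i = 0; i < size; ++i) {
    float v = px[i];
    if (pr != nullptr) {
      v += pr[i];
      if (pb != nullptr) v += pb[i];
    }
    const float nb = (pnb != nullptr) ? pnb[i] : 0.0f;  // norm_bias defaults to 0
    py[i] = v * inv_rms * pw[i] + nb;
  }
}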
@@ -53,10 +61,16 @@ void RmsNormKernel(const Context& dev_ctx,
   int size = cols;
   auto istride = cols;
   auto ostride = cols;
+  const T* norm_weight_data = norm_weight.data<T>();
+  const T* norm_bias_data = norm_bias ? norm_bias.get().data<T>() : nullptr;
   const T* residual_data = residual ? residual.get().data<T>() : nullptr;
+  const T* bias_data = bias ? bias.get().data<T>() : nullptr;
+  T* out_data = dev_ctx.template Alloc<T>(out);
   T* residual_out_data =
       residual ? dev_ctx.template Alloc<T>(residual_out) : nullptr;
 
+  __m512 vb = _mm512_setzero_ps();
+  const T* pb = bias_data;
 #ifdef PADDLE_WITH_MKLML
 #pragma omp parallel for
 #endif
@@ -77,6 +91,10 @@ void RmsNormKernel(const Context& dev_ctx,
       if (residual) {
         __m512 residual_vx = _mm512_loadu_ps(pr + col);
         vx = _mm512_add_ps(vx, residual_vx);
+        if (bias) {
+          __m512 vb = _mm512_loadu_ps(pb + col);
+          vx = _mm512_add_ps(vx, vb);
+        }
         _mm512_storeu_ps(pr_out + col, vx);
       }
       __m512 tmp = _mm512_mul_ps(vx, vx);
@@ -88,6 +106,10 @@ void RmsNormKernel(const Context& dev_ctx,
      if (residual) {
        __m512 residual_vx = _mm512_maskz_loadu_ps(mask, pr + col);
        vx = _mm512_mask_add_ps(vx, mask, vx, residual_vx);
+       if (bias) {
+         __m512 vb = _mm512_maskz_loadu_ps(mask, pb + col);
+         vx = _mm512_mask_add_ps(vx, mask, vx, vb);
+       }
        _mm512_mask_storeu_ps(pr_out + col, mask, vx);
      }
      __m512 tmp = _mm512_mul_ps(vx, vx);
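
The masked branch above is the standard AVX-512 tail pattern: full 16-float vectors use plain loads, and the final size % 16 elements use a lane mask so memory past the row is never touched. A self-contained sketch of the same pattern (illustrative, not part of the commit; compile with -mavx512f):

#include <immintrin.h>

// Multiply `size` floats in place by `factor`, 16 lanes at a time, with a
// masked remainder. _mm512_maskz_loadu_ps zeroes the unselected lanes.
void scale_inplace_avx512(float* data, int size, float factor) {
  const __m512 vf = _mm512_set1_ps(factor);
  int col = 0;
  for (; col + 16 <= size; col += 16) {
    __m512 v = _mm512_loadu_ps(data + col);
    _mm512_storeu_ps(data + col, _mm512_mul_ps(v, vf));
  }
  if (col < size) {
    __mmask16 mask = (1 << (size - col)) - 1;  // select the low (size - col) lanes
    __m512 v = _mm512_maskz_loadu_ps(mask, data + col);
    _mm512_mask_storeu_ps(data + col, mask, _mm512_mul_ps(v, vf));
  }
}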
@@ -105,9 +127,16 @@ void RmsNormKernel(const Context& dev_ctx,
      if (residual) {
        __m512 residual_vx = _mm512_loadu_ps(pr + col);
        vx = _mm512_add_ps(vx, residual_vx);
+       if (bias) {
+         __m512 vb = _mm512_loadu_ps(pb + col);
+         vx = _mm512_add_ps(vx, vb);
+       }
      }
      __m512 vw = _mm512_loadu_ps(norm_weight_data + col);
-     __m512 vy = vx * vvar * vw;
+     if (norm_bias_data) {
+       vb = _mm512_loadu_ps(norm_bias_data + col);
+     }
+     __m512 vy = vx * vvar * vw + vb;
      _mm512_storeu_ps(py + col, vy);
    }
    if (col < size) {
@@ -116,9 +145,16 @@ void RmsNormKernel(const Context& dev_ctx,
      if (residual) {
        __m512 residual_vx = _mm512_maskz_loadu_ps(mask, pr + col);
        vx = _mm512_mask_add_ps(vx, mask, vx, residual_vx);
+       if (bias) {
+         __m512 vb = _mm512_maskz_loadu_ps(mask, pb + col);
+         vx = _mm512_mask_add_ps(vx, mask, vx, vb);
+       }
      }
      __m512 vw = _mm512_maskz_loadu_ps(mask, norm_weight_data + col);
-     __m512 vy = vx * vvar * vw;
+     if (norm_bias_data) {
+       vb = _mm512_maskz_loadu_ps(mask, norm_bias_data + col);
+     }
+     __m512 vy = vx * vvar * vw + vb;
      _mm512_mask_storeu_ps(py + col, mask, vy);
    }
  }  // end for rows
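
Note how norm_bias stays optional in both scaling loops: vb is zero-initialized with _mm512_setzero_ps() before the row loop, so when norm_bias_data is null the fused line vx * vvar * vw + vb adds zero and reduces to plain RMSNorm scaling. The inner if (bias) blocks declare their own local vb, so they never disturb that default.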
@@ -127,4 +163,4 @@
 }  // namespace phi
 
 PD_REGISTER_KERNEL(
-    rms_norm_avx, CPU, ALL_LAYOUT, phi::fusion::RmsNormKernel, float, double) {}
+    rms_norm, CPU, ALL_LAYOUT, phi::fusion::RmsNormAvxKernel, float, double) {}
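
The registration now targets the generic rms_norm op on CPU, replacing the dedicated rms_norm_avx op removed from ops.yaml above. Since the kernel uses __m512 intrinsics unconditionally, it assumes AVX-512F support; as a minimal sketch (not part of this commit), a GCC/Clang runtime probe for that requirement could look like:

#include <cstdio>

// Hypothetical runtime check; __builtin_cpu_supports is a GCC/Clang builtin.
int main() {
  if (__builtin_cpu_supports("avx512f")) {
    std::printf("AVX-512F available: the rms_norm AVX path can run.\n");
  } else {
    std::printf("AVX-512F unavailable: a non-AVX kernel would be needed.\n");
  }
  return 0;
}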
