-/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
+// Copyright (c) 2024 PaddlePaddle Authors And Intel Corporation.
+// All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 #include <immintrin.h>
 #include <math.h>
 #include <omp.h>
@@ -24,15 +29,15 @@ namespace phi {
 namespace fusion {

 template <typename T, typename Context>
-void RmsNormXftKernel(const Context& dev_ctx,
-                      const DenseTensor& x,
-                      const paddle::optional<DenseTensor>& residual,
-                      const DenseTensor& norm_weight,
-                      const float epsilon,
-                      const int begin_norm_axis,
-                      DenseTensor* out,
-                      DenseTensor* residual_out) {
-  const float* x_data = x.data<float>();
+void RmsNormKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const paddle::optional<DenseTensor>& residual,
+                   const DenseTensor& norm_weight,
+                   const float epsilon,
+                   const int begin_norm_axis,
+                   DenseTensor* out,
+                   DenseTensor* residual_out) {
+  const T* x_data = x.data<T>();
   T* out_data = dev_ctx.template Alloc<T>(out);
   const T* norm_weight_data = norm_weight.data<T>();
   // x(batch_size,seq_len,hidden_size)
@@ -61,7 +66,7 @@ void RmsNormXftKernel(const Context& dev_ctx,
     T* pr_out = residual ? residual_out_data + r * ostride : nullptr;
     T* py = out_data + r * ostride;

-    float squareSum = 0;
+    T squareSum = 0;

     __m512 vsqare = _mm512_set1_ps(0);

@@ -92,7 +97,7 @@ void RmsNormXftKernel(const Context& dev_ctx,
     squareSum = _mm512_reduce_add_ps(vsqare);

     // Variance
-    float var = 1 / sqrt(squareSum / size + epsilon);
+    T var = 1 / sqrt(squareSum / size + epsilon);
     __m512 vvar = _mm512_set1_ps(var);

     for (col = 0; col + 15 < size; col += 16) {
@@ -122,4 +127,4 @@ void RmsNormXftKernel(const Context& dev_ctx,
 }  // namespace phi

 PD_REGISTER_KERNEL(
-    rms_norm_xft, CPU, ALL_LAYOUT, phi::fusion::RmsNormXftKernel, float) {}
+    rms_norm_avx, CPU, ALL_LAYOUT, phi::fusion::RmsNormKernel, float, double) {}
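
For reference, the per-row computation that this AVX-512 kernel vectorizes can be sketched in plain scalar C++ as below. This is only a minimal sketch of the RMSNorm-with-optional-residual math visible in the diff; the helper name RmsNormRowReference, its argument list, and the residual-add ordering are illustrative assumptions, not the kernel's actual interface.

// Minimal scalar sketch of the fused residual-add + RMSNorm math that the
// AVX-512 kernel above vectorizes per row. Names, signature, and the exact
// residual handling are illustrative assumptions, not the kernel's API.
#include <cmath>
#include <cstddef>

template <typename T>
void RmsNormRowReference(const T* x,           // one row of length `size`
                         const T* residual,    // optional; may be nullptr
                         const T* weight,      // per-element norm weight
                         float epsilon,
                         std::size_t size,
                         T* out,
                         T* residual_out) {    // may be nullptr if no residual
  // Accumulate the sum of squares of the (optionally residual-added) input.
  T square_sum = 0;
  for (std::size_t i = 0; i < size; ++i) {
    T v = residual ? x[i] + residual[i] : x[i];
    if (residual_out) residual_out[i] = v;  // write back the summed value
    square_sum += v * v;
  }
  // Inverse RMS: 1 / sqrt(mean(v^2) + eps); corresponds to `var` above.
  T inv_rms = static_cast<T>(1) / std::sqrt(square_sum / size + epsilon);
  // Scale each element by the inverse RMS and the learned norm weight.
  for (std::size_t i = 0; i < size; ++i) {
    T v = residual ? x[i] + residual[i] : x[i];
    out[i] = v * inv_rms * weight[i];
  }
}

In the kernel itself these loops appear fused and vectorized with _mm512_* intrinsics (the square-sum reduction is finished with _mm512_reduce_add_ps), and the omp.h include suggests rows are processed in parallel with OpenMP.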