@@ -14,6 +14,8 @@ limitations under the License. */
 
 #pragma once
 #include <algorithm>
+#include <array>
+#include <numeric>
 #include <string>
 #include "paddle/fluid/framework/data_layout.h"
 #include "paddle/fluid/framework/eigen.h"
@@ -73,6 +75,11 @@ class GroupNormKernel : public framework::OpKernel<T> {
     auto* iter_y_data = y_data;
     for (int bid = 0; bid < x_dims[0]; bid++) {
       for (int gid = 0; gid < groups; gid++) {
+        const int64_t M = 8;
+        std::array<T, M> x_mean_arr;
+        std::array<T, M> x_var_arr;
+        std::fill(x_mean_arr.begin(), x_mean_arr.end(), T(0));
+        std::fill(x_var_arr.begin(), x_var_arr.end(), T(0));
         T x_mean = 0, x_var = 0;
         int number =
             std::min(group_size, static_cast<int>(C - gid * group_size));
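For context on the hunk above: keeping `M = 8` independent partial sums lets each lane live in its own register so the compiler can auto-vectorize, and it also improves precision, since each lane accumulates only about `imsize / M` terms instead of all of them. A minimal self-contained sketch of this multi-accumulator pattern (the helper name `SumAndSquareSum` is hypothetical, not part of the patch):

```cpp
#include <array>
#include <cstddef>
#include <numeric>
#include <utility>

// Hypothetical helper showing the M-lane accumulation pattern from the patch:
// compute sum(x) and sum(x * x) over a contiguous buffer of n elements.
template <typename T>
std::pair<T, T> SumAndSquareSum(const T* data, std::size_t n) {
  constexpr std::size_t M = 8;  // independent accumulator lanes
  std::array<T, M> sum_arr{};   // value-initialized to T(0)
  std::array<T, M> sq_arr{};
  std::size_t i = 0;
  // Main loop: each lane owns its own running sums, so iterations are
  // independent and the compiler can keep the lanes in SIMD registers.
  for (; i + M <= n; i += M) {
    for (std::size_t lane = 0; lane < M; ++lane) {
      sum_arr[lane] += data[i + lane];
      sq_arr[lane] += data[i + lane] * data[i + lane];
    }
  }
  // Fold the lanes into scalars, then handle the n % M tail elements.
  T sum = std::accumulate(sum_arr.cbegin(), sum_arr.cend(), T(0));
  T sq = std::accumulate(sq_arr.cbegin(), sq_arr.cend(), T(0));
  for (; i < n; ++i) {
    sum += data[i];
    sq += data[i] * data[i];
  }
  return {sum, sq};
}
```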
@@ -83,15 +90,75 @@ class GroupNormKernel : public framework::OpKernel<T> {
 
         if (data_layout == DataLayout::kNCHW) {
           for (int cid = 0; cid < number; cid++) {
-            for (int imid = 0; imid < imsize; imid++, iter_x_data++) {
+            int imid;
+            for (imid = 0; imid < imsize - (imsize % M);
+                 imid += M, iter_x_data += M) {
+              // TODO(gaoxiang): Because AVX/AVX2/AVX512 cannot be used
+              // directly in a template class/function, unroll the loop
+              // manually for high precision and performance until a
+              // high-performance CPU vector extension is complete.
+              x_mean_arr[0] += iter_x_data[0];
+              x_var_arr[0] += iter_x_data[0] * iter_x_data[0];
+              x_mean_arr[1] += iter_x_data[1];
+              x_var_arr[1] += iter_x_data[1] * iter_x_data[1];
+              x_mean_arr[2] += iter_x_data[2];
+              x_var_arr[2] += iter_x_data[2] * iter_x_data[2];
+              x_mean_arr[3] += iter_x_data[3];
+              x_var_arr[3] += iter_x_data[3] * iter_x_data[3];
+              x_mean_arr[4] += iter_x_data[4];
+              x_var_arr[4] += iter_x_data[4] * iter_x_data[4];
+              x_mean_arr[5] += iter_x_data[5];
+              x_var_arr[5] += iter_x_data[5] * iter_x_data[5];
+              x_mean_arr[6] += iter_x_data[6];
+              x_var_arr[6] += iter_x_data[6] * iter_x_data[6];
+              x_mean_arr[7] += iter_x_data[7];
+              x_var_arr[7] += iter_x_data[7] * iter_x_data[7];
+            }
+            x_mean =
+                std::accumulate(x_mean_arr.cbegin(), x_mean_arr.cend(), x_mean);
+            x_var =
+                std::accumulate(x_var_arr.cbegin(), x_var_arr.cend(), x_var);
+            std::fill(x_mean_arr.begin(), x_mean_arr.end(), T(0));
+            std::fill(x_var_arr.begin(), x_var_arr.end(), T(0));
+            for (; imid < imsize; imid++, iter_x_data++) {
               x_mean += iter_x_data[0];
               x_var += iter_x_data[0] * iter_x_data[0];
             }
           }
         } else {
           for (int cid = 0; cid < number; cid++) {
             iter_x_data = tmp_x + cid;
-            for (int imid = 0; imid < imsize; imid++, iter_x_data += C) {
+            int imid;
+            for (imid = 0; imid < imsize - (imsize % M);
+                 imid += M, iter_x_data += M * C) {
+              // TODO(gaoxiang): Because AVX/AVX2/AVX512 cannot be used
+              // directly in a template class/function, unroll the loop
+              // manually for high precision and performance until a
+              // high-performance CPU vector extension is complete.
+              x_mean_arr[0] += iter_x_data[0 * C];
+              x_var_arr[0] += iter_x_data[0 * C] * iter_x_data[0 * C];
+              x_mean_arr[1] += iter_x_data[1 * C];
+              x_var_arr[1] += iter_x_data[1 * C] * iter_x_data[1 * C];
+              x_mean_arr[2] += iter_x_data[2 * C];
+              x_var_arr[2] += iter_x_data[2 * C] * iter_x_data[2 * C];
+              x_mean_arr[3] += iter_x_data[3 * C];
+              x_var_arr[3] += iter_x_data[3 * C] * iter_x_data[3 * C];
+              x_mean_arr[4] += iter_x_data[4 * C];
+              x_var_arr[4] += iter_x_data[4 * C] * iter_x_data[4 * C];
+              x_mean_arr[5] += iter_x_data[5 * C];
+              x_var_arr[5] += iter_x_data[5 * C] * iter_x_data[5 * C];
+              x_mean_arr[6] += iter_x_data[6 * C];
+              x_var_arr[6] += iter_x_data[6 * C] * iter_x_data[6 * C];
+              x_mean_arr[7] += iter_x_data[7 * C];
+              x_var_arr[7] += iter_x_data[7 * C] * iter_x_data[7 * C];
+            }
+            x_mean =
+                std::accumulate(x_mean_arr.cbegin(), x_mean_arr.cend(), x_mean);
+            x_var =
+                std::accumulate(x_var_arr.cbegin(), x_var_arr.cend(), x_var);
+            std::fill(x_mean_arr.begin(), x_mean_arr.end(), T(0));
+            std::fill(x_var_arr.begin(), x_var_arr.end(), T(0));
+            for (; imid < imsize; imid++, iter_x_data += C) {
               x_mean += iter_x_data[0];
               x_var += iter_x_data[0] * iter_x_data[0];
             }
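A detail that is easy to miss in the hunk above: `std::accumulate` is seeded with the running `x_mean` / `x_var`, so each channel's lane totals are folded into the scalar accumulators before the arrays are re-zeroed for the next channel, and the trailing scalar loop then covers the remaining `imsize % M` elements. A quick sanity check of the pattern, reusing the hypothetical `SumAndSquareSum` helper from the earlier sketch (the data and tolerance are illustrative):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

int main() {
  std::vector<float> x(1003);  // deliberately not a multiple of 8
  for (std::size_t i = 0; i < x.size(); ++i) {
    x[i] = 0.001f * static_cast<float>(i);
  }

  auto [sum, sq] = SumAndSquareSum(x.data(), x.size());

  float naive_sum = 0.0f, naive_sq = 0.0f;
  for (float v : x) {
    naive_sum += v;
    naive_sq += v * v;
  }
  // Lane-wise summation rounds differently (usually more accurately) than a
  // single running sum, so compare with a tolerance, not exact equality.
  assert(std::fabs(sum - naive_sum) < 0.1f);
  assert(std::fabs(sq - naive_sq) < 0.1f);
  return 0;
}
```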
@@ -101,8 +168,8 @@ class GroupNormKernel : public framework::OpKernel<T> {
 
         x_mean /= number * imsize;
         x_var /= number * imsize;
-        x_var = x_var - x_mean * x_mean;
-        T var_inv = 1.0 / sqrt(x_var + epsilon);
+        x_var = std::max(x_var - x_mean * x_mean, T(0));
+        T var_inv = T(1) / std::sqrt(x_var + epsilon);
         mean_data[bid * groups + gid] = x_mean;
         var_data[bid * groups + gid] = x_var;
 
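The two replaced lines fix a numerical hazard and an implicit type promotion. Mathematically `E[x^2] - E[x]^2` is non-negative, but computed in floating point the two terms are rounded independently, so the difference can land below zero; for large-magnitude activations the cancellation error can exceed `epsilon`, which would make `sqrt(x_var + epsilon)` NaN without the clamp. Separately, the unqualified `1.0 / sqrt(...)` silently promoted the division to `double`, while `T(1) / std::sqrt(...)` keeps the arithmetic in `T`. A small standalone illustration of the clamp (the constant and epsilon are illustrative):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  // A constant channel: the true variance is exactly zero, but E[x^2] and
  // E[x]^2 are rounded independently in float, so their difference may come
  // out as a small negative number.
  const float v = 10000.1f;
  const int n = 7;
  float sum = 0.0f, sq = 0.0f;
  for (int i = 0; i < n; ++i) {
    sum += v;
    sq += v * v;
  }
  const float mean = sum / n;
  float var = sq / n - mean * mean;  // may be slightly below zero
  std::printf("raw var = %g\n", var);
  var = std::max(var, 0.0f);  // the patch's clamp
  const float var_inv = 1.0f / std::sqrt(var + 1e-5f);  // now always finite
  std::printf("var_inv = %g\n", var_inv);
  return 0;
}
```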