Commit 0a9937d

improve group norm cpu precision and performance (#33176)

* improve group norm cpu precision and performance
* add unit test to group norm

1 parent 387f227 commit 0a9937d

File tree

3 files changed: +82 -5 lines changed

paddle/fluid/operators/group_norm_op.h

Lines changed: 71 additions & 4 deletions
@@ -14,6 +14,8 @@ limitations under the License. */
 #pragma once
 #include <algorithm>
+#include <array>
+#include <numeric>
 #include <string>
 #include "paddle/fluid/framework/data_layout.h"
 #include "paddle/fluid/framework/eigen.h"
@@ -73,6 +75,11 @@ class GroupNormKernel : public framework::OpKernel<T> {
     auto* iter_y_data = y_data;
     for (int bid = 0; bid < x_dims[0]; bid++) {
       for (int gid = 0; gid < groups; gid++) {
+        const int64_t M = 8;
+        std::array<T, M> x_mean_arr;
+        std::array<T, M> x_var_arr;
+        std::fill(x_mean_arr.begin(), x_mean_arr.end(), T(0));
+        std::fill(x_var_arr.begin(), x_var_arr.end(), T(0));
         T x_mean = 0, x_var = 0;
         int number =
             std::min(group_size, static_cast<int>(C - gid * group_size));
@@ -83,15 +90,75 @@ class GroupNormKernel : public framework::OpKernel<T> {

         if (data_layout == DataLayout::kNCHW) {
           for (int cid = 0; cid < number; cid++) {
-            for (int imid = 0; imid < imsize; imid++, iter_x_data++) {
+            int imid;
+            for (imid = 0; imid < imsize - (imsize % M);
+                 imid += M, iter_x_data += M) {
+              // TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used
+              // in template class/function, before we complete high
+              // performance cpu vector extension, temporarily unrolling
+              // loop to get high precision and performance
+              x_mean_arr[0] += iter_x_data[0];
+              x_var_arr[0] += iter_x_data[0] * iter_x_data[0];
+              x_mean_arr[1] += iter_x_data[1];
+              x_var_arr[1] += iter_x_data[1] * iter_x_data[1];
+              x_mean_arr[2] += iter_x_data[2];
+              x_var_arr[2] += iter_x_data[2] * iter_x_data[2];
+              x_mean_arr[3] += iter_x_data[3];
+              x_var_arr[3] += iter_x_data[3] * iter_x_data[3];
+              x_mean_arr[4] += iter_x_data[4];
+              x_var_arr[4] += iter_x_data[4] * iter_x_data[4];
+              x_mean_arr[5] += iter_x_data[5];
+              x_var_arr[5] += iter_x_data[5] * iter_x_data[5];
+              x_mean_arr[6] += iter_x_data[6];
+              x_var_arr[6] += iter_x_data[6] * iter_x_data[6];
+              x_mean_arr[7] += iter_x_data[7];
+              x_var_arr[7] += iter_x_data[7] * iter_x_data[7];
+            }
+            x_mean =
+                std::accumulate(x_mean_arr.cbegin(), x_mean_arr.cend(), x_mean);
+            x_var =
+                std::accumulate(x_var_arr.cbegin(), x_var_arr.cend(), x_var);
+            std::fill(x_mean_arr.begin(), x_mean_arr.end(), T(0));
+            std::fill(x_var_arr.begin(), x_var_arr.end(), T(0));
+            for (; imid < imsize; imid++, iter_x_data++) {
               x_mean += iter_x_data[0];
               x_var += iter_x_data[0] * iter_x_data[0];
             }
           }
         } else {
           for (int cid = 0; cid < number; cid++) {
             iter_x_data = tmp_x + cid;
-            for (int imid = 0; imid < imsize; imid++, iter_x_data += C) {
+            int imid;
+            for (imid = 0; imid < imsize - (imsize % M);
+                 imid += M, iter_x_data += M * C) {
+              // TODO(gaoxiang): Because AVX/AVX2/AVX512 can not directly used
+              // in template class/function, before we complete high
+              // performance cpu vector extension, temporarily unrolling
+              // loop to get high precision and performance
+              x_mean_arr[0] += iter_x_data[0 * C];
+              x_var_arr[0] += iter_x_data[0 * C] * iter_x_data[0 * C];
+              x_mean_arr[1] += iter_x_data[1 * C];
+              x_var_arr[1] += iter_x_data[1 * C] * iter_x_data[1 * C];
+              x_mean_arr[2] += iter_x_data[2 * C];
+              x_var_arr[2] += iter_x_data[2 * C] * iter_x_data[2 * C];
+              x_mean_arr[3] += iter_x_data[3 * C];
+              x_var_arr[3] += iter_x_data[3 * C] * iter_x_data[3 * C];
+              x_mean_arr[4] += iter_x_data[4 * C];
+              x_var_arr[4] += iter_x_data[4 * C] * iter_x_data[4 * C];
+              x_mean_arr[5] += iter_x_data[5 * C];
+              x_var_arr[5] += iter_x_data[5 * C] * iter_x_data[5 * C];
+              x_mean_arr[6] += iter_x_data[6 * C];
+              x_var_arr[6] += iter_x_data[6 * C] * iter_x_data[6 * C];
+              x_mean_arr[7] += iter_x_data[7 * C];
+              x_var_arr[7] += iter_x_data[7 * C] * iter_x_data[7 * C];
+            }
+            x_mean =
+                std::accumulate(x_mean_arr.cbegin(), x_mean_arr.cend(), x_mean);
+            x_var =
+                std::accumulate(x_var_arr.cbegin(), x_var_arr.cend(), x_var);
+            std::fill(x_mean_arr.begin(), x_mean_arr.end(), T(0));
+            std::fill(x_var_arr.begin(), x_var_arr.end(), T(0));
+            for (; imid < imsize; imid++, iter_x_data += C) {
               x_mean += iter_x_data[0];
               x_var += iter_x_data[0] * iter_x_data[0];
             }
@@ -101,8 +168,8 @@ class GroupNormKernel : public framework::OpKernel<T> {

         x_mean /= number * imsize;
         x_var /= number * imsize;
-        x_var = x_var - x_mean * x_mean;
-        T var_inv = 1.0 / sqrt(x_var + epsilon);
+        x_var = std::max(x_var - x_mean * x_mean, T(0));
+        T var_inv = T(1) / std::sqrt(x_var + epsilon);
         mean_data[bid * groups + gid] = x_mean;
         var_data[bid * groups + gid] = x_var;
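
For context, here is a minimal standalone sketch of the accumulation scheme the hunk above introduces; it is not the Paddle kernel itself, and MeanVarUnrolled is a hypothetical name used only for illustration. The sum and sum of squares are gathered into M = 8 independent lanes, reduced with std::accumulate, the remainder is handled by a scalar tail loop, and the variance is clamped at zero before taking the reciprocal square root. The lane body is written as a small inner loop for brevity; the kernel unrolls it explicitly because, per the TODO, AVX intrinsics cannot be used directly inside the templated class.

// Hypothetical standalone sketch of the unrolled mean/variance pass; not the
// Paddle kernel, just the same accumulation idea in isolation.
#include <algorithm>
#include <array>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <vector>

template <typename T>
void MeanVarUnrolled(const T* x, int n, T epsilon, T* mean, T* var_inv) {
  constexpr int M = 8;           // number of accumulation lanes, as in the kernel
  std::array<T, M> sum_arr{};    // per-lane running sum
  std::array<T, M> sqsum_arr{};  // per-lane running sum of squares
  T sum = 0, sqsum = 0;
  int i = 0;
  // Main loop: eight independent partial sums shorten the dependency chain
  // and keep each partial sum smaller, which reduces rounding error.
  for (; i < n - (n % M); i += M) {
    for (int k = 0; k < M; ++k) {
      sum_arr[k] += x[i + k];
      sqsum_arr[k] += x[i + k] * x[i + k];
    }
  }
  sum = std::accumulate(sum_arr.cbegin(), sum_arr.cend(), sum);
  sqsum = std::accumulate(sqsum_arr.cbegin(), sqsum_arr.cend(), sqsum);
  // Scalar tail loop for the remaining n % M elements.
  for (; i < n; ++i) {
    sum += x[i];
    sqsum += x[i] * x[i];
  }
  T m = sum / n;
  // E[x^2] - E[x]^2 can come out slightly negative in floating point, so
  // clamp it at zero before the rsqrt, as the new kernel code does.
  T v = std::max(sqsum / n - m * m, T(0));
  *mean = m;
  *var_inv = T(1) / std::sqrt(v + epsilon);
}

int main() {
  std::vector<float> x(1003, 1.5f);  // deliberately not a multiple of 8
  float mean = 0.f, var_inv = 0.f;
  MeanVarUnrolled(x.data(), static_cast<int>(x.size()), 1e-5f, &mean, &var_inv);
  std::printf("mean=%f var_inv=%f\n", mean, var_inv);  // mean = 1.5, var ~ 0
  return 0;
}

Splitting the accumulation across eight lanes is both a performance and a precision change: the compiler can vectorize or pipeline the independent partial sums, and each partial sum stays smaller than a single running total, so fewer low-order bits are lost before the final reduction.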

python/paddle/fluid/tests/unittests/test_group_norm_op_v2.py

Lines changed: 10 additions & 0 deletions
@@ -53,6 +53,15 @@ def test_weight_bias_false():
                         weight_attr=False,
                         bias_attr=False)

+            def test_nn_exception():
+                with fluid.dygraph.guard(p):
+
+                    def attr_data_format():
+                        out = paddle.nn.GroupNorm(
+                            num_groups=2, num_channels=2, data_format="NHWC")
+
+                    self.assertRaises(ValueError, attr_data_format)
+
             x = np.random.randn(*shape).astype("float32")
             y1 = compute_v1(x)
             y2 = compute_v2(x)
@@ -61,6 +70,7 @@ def test_weight_bias_false():
                 print("y1:", y1, "\ty2:", y2)
             self.assertTrue(result)
             test_weight_bias_false()
+            test_nn_exception()

    def test_static(self):
        places = [fluid.CPUPlace()]

python/paddle/nn/layer/norm.py

Lines changed: 1 addition & 1 deletion
@@ -375,7 +375,7 @@ def __init__(self,
        self._num_channels = num_channels
        self._num_groups = num_groups
        if data_format != 'NCHW':
-            raise ValueError("unsupported data layout:" + data_layout)
+            raise ValueError("unsupported data layout:" + data_format)

        param_shape = [self._num_channels]