@@ -34,20 +34,20 @@ class GroupNormComputeTest : public arena::TestCase {
3434 DDim dims_{{4 , 5 , 19 , 19 }};
3535 float epsilon_ = 1e-5f ;
3636 int groups_ = 1 ;
37- int channels_ = dims_[ 1 ] ;
37+ std::string data_layout_str_ = " NCHW " ;
3838
3939 public:
4040 GroupNormComputeTest (const Place& place,
4141 const std::string& alias,
4242 DDim dims,
4343 float epsilon,
4444 int groups,
45- int channels )
45+ std::string data_layout_str )
4646 : TestCase(place, alias),
4747 dims_ (dims),
4848 epsilon_(epsilon),
4949 groups_(groups),
50- channels_(channels ) {}
50+ data_layout_str_(data_layout_str ) {}
5151
5252 void RunBaseline (Scope* scope) override {
5353 auto x = scope->FindTensor (x_);
@@ -59,7 +59,7 @@ class GroupNormComputeTest : public arena::TestCase {
5959 CHECK (y);
6060 CHECK (saved_mean);
6161 CHECK (saved_variance);
62- DDim saved_dim ({dims_[0 ] * groups_});
62+ DDim saved_dim ({dims_[0 ], groups_});
6363 y->Resize (dims_);
6464 saved_mean->Resize (saved_dim);
6565 saved_variance->Resize (saved_dim);
@@ -68,49 +68,82 @@ class GroupNormComputeTest : public arena::TestCase {
6868 auto scale_data = scale->data <float >();
6969 auto bias_data = bias->data <float >();
7070 auto y_data = y->mutable_data <float >();
71- auto saved_mean_data = saved_mean->mutable_data <float >();
72- auto saved_variance_data = saved_variance->mutable_data <float >();
73-
74- int n = x->dims ()[0 ];
75- int ch_per_group = channels_ / groups_;
76- CHECK_EQ (x->dims ()[1 ], channels_);
77- int spatial_size = ch_per_group * x->dims ()[2 ] * x->dims ()[3 ];
78- // compute mean
79- for (int i = 0 ; i < n * groups_; ++i) {
80- const float * x_ptr = x_data + i * spatial_size;
81- float sum = 0 .f ;
82- for (int j = 0 ; j < spatial_size; ++j) {
83- sum += x_ptr[j];
84- }
85- saved_mean_data[i] = sum / spatial_size;
86- }
87- // compute variance
88- for (int i = 0 ; i < n * groups_; ++i) {
89- const float * x_ptr = x_data + i * spatial_size;
90- float sum = 0 .f ;
91- for (int j = 0 ; j < spatial_size; ++j) {
92- sum +=
93- (x_ptr[j] - saved_mean_data[i]) * (x_ptr[j] - saved_mean_data[i]);
94- }
95- saved_variance_data[i] = 1 .f / sqrtf (sum / spatial_size + epsilon_);
96- }
97- int in_size = x->dims ()[2 ] * x->dims ()[3 ];
98- // compute out
99- for (int i = 0 ; i < n * groups_; ++i) {
100- const float * x_ptr = x_data + i * spatial_size;
101- float * y_ptr = y_data + i * spatial_size;
102- int c_num = i % groups_;
103- for (int c = 0 ; c < ch_per_group; c++) {
104- int chin = c_num * ch_per_group + c;
105- float scale_val = scale_data[chin];
106- float bias_val = bias_data[chin];
107- const float * x_ch_ptr = x_ptr + c * in_size;
108- float * y_ch_ptr = y_ptr + c * in_size;
109- for (int j = 0 ; j < in_size; j++) {
110- y_ch_ptr[j] = scale_val * (x_ch_ptr[j] - saved_mean_data[i]) *
111- saved_variance_data[i] +
112- bias_val;
71+ auto mean_data = saved_mean->mutable_data <float >();
72+ auto var_data = saved_variance->mutable_data <float >();
73+
74+ auto x_dims = x->dims ();
75+ int groups = groups_;
76+ int channels =
77+ (data_layout_str_ == " NCHW" ) ? x_dims[1 ] : x_dims[x_dims.size () - 1 ];
78+ int group_size = (channels - 1 ) / groups + 1 ;
79+ int imsize = (data_layout_str_ == " NCHW" ) ? (x_dims[2 ] * x_dims[3 ])
80+ : (x_dims[1 ] * x_dims[2 ]);
81+
82+ auto * iter_x_data = x_data;
83+ auto * iter_y_data = y_data;
84+ for (int bid = 0 ; bid < x_dims[0 ]; bid++) {
85+ for (int gid = 0 ; gid < groups; gid++) {
86+ float x_mean = 0 ;
87+ float x_var = 0 ;
88+ int number =
89+ std::min (group_size, static_cast <int >(channels - gid * group_size));
90+ auto * tmp_x = iter_x_data;
91+ auto * x_src_data = iter_x_data;
92+ auto * tmp_y = iter_y_data;
93+ auto * y_src_data = iter_y_data;
94+
95+ if (data_layout_str_ == " NCHW" ) {
96+ for (int cid = 0 ; cid < number; cid++) {
97+ for (int imid = 0 ; imid < imsize; imid++, iter_x_data++) {
98+ x_mean += iter_x_data[0 ];
99+ x_var += iter_x_data[0 ] * iter_x_data[0 ];
100+ }
101+ }
102+ } else {
103+ for (int cid = 0 ; cid < number; cid++) {
104+ iter_x_data = tmp_x + cid;
105+ for (int imid = 0 ; imid < imsize; imid++, iter_x_data += channels) {
106+ x_mean += iter_x_data[0 ];
107+ x_var += iter_x_data[0 ] * iter_x_data[0 ];
108+ }
109+ }
110+ iter_x_data = tmp_x + group_size;
113111 }
112+
113+ x_mean /= number * imsize;
114+ x_var /= number * imsize;
115+ x_var = x_var - x_mean * x_mean;
116+ float var_inv = 1.0 / std::sqrt (x_var + epsilon_);
117+ mean_data[bid * groups + gid] = x_mean;
118+ var_data[bid * groups + gid] = x_var;
119+
120+ if (data_layout_str_ == " NCHW" ) {
121+ for (int cid = 0 ; cid < number; cid++) {
122+ for (int imid = 0 ; imid < imsize; imid++, tmp_x++, iter_y_data++) {
123+ float val = (tmp_x[0 ] - x_mean) * var_inv;
124+ if (scale_data) val *= scale_data[gid * group_size + cid];
125+ if (bias_data) val += bias_data[gid * group_size + cid];
126+ iter_y_data[0 ] = val;
127+ }
128+ }
129+ } else {
130+ for (int cid = 0 ; cid < number; cid++) {
131+ tmp_x = x_src_data + cid;
132+ iter_y_data = y_src_data + cid;
133+ for (int imid = 0 ; imid < imsize;
134+ imid++, tmp_x += channels, iter_y_data += channels) {
135+ float val = (tmp_x[0 ] - x_mean) * var_inv;
136+ if (scale_data) val *= scale_data[gid * group_size + cid];
137+ if (bias_data) val += bias_data[gid * group_size + cid];
138+ iter_y_data[0 ] = val;
139+ }
140+ }
141+ iter_y_data = tmp_y + group_size;
142+ }
143+ }
144+ if (data_layout_str_ == " NCHW" ) {
145+ iter_x_data = x_data + (bid + 1 ) * channels * imsize;
146+ iter_y_data = y_data + (bid + 1 ) * channels * imsize;
114147 }
115148 }
116149 }
@@ -125,7 +158,7 @@ class GroupNormComputeTest : public arena::TestCase {
125158 op_desc->SetOutput (" Variance" , {saved_variance_});
126159 op_desc->SetAttr (" epsilon" , epsilon_);
127160 op_desc->SetAttr (" groups" , groups_);
128- op_desc->SetAttr (" channels " , channels_ );
161+ op_desc->SetAttr (" data_layout " , data_layout_str_ );
129162 }
130163
131164 void PrepareData () override {
@@ -148,7 +181,7 @@ void TestGroupNorm(Place place,
148181 float abs_error = 6e-5 ,
149182 std::vector<std::string> ignored_outs = {}) {
150183 for (auto & n : {1 , 3 , 16 }) {
151- for (auto & c : {1 }) {
184+ for (auto & c : {1 , 2 }) {
152185 for (auto & h : {1 , 16 , 33 , 56 }) {
153186 for (auto & w : {1 , 17 , 55 }) {
154187 for (auto & groups : {1 , 2 , 4 }) {
@@ -158,7 +191,7 @@ void TestGroupNorm(Place place,
158191 DDim dim_in ({n, c, h, w});
159192 float epsilon = 1e-5f ;
160193 std::unique_ptr<arena::TestCase> tester (new GroupNormComputeTest (
161- place, " def" , dim_in, epsilon, groups, c ));
194+ place, " def" , dim_in, epsilon, groups, " NCHW " ));
162195#ifdef LITE_WITH_ARM
163196 if (place == TARGET (kARM )) {
164197 auto & ctx = tester->context ()->As <ARMContext>();
0 commit comments