@@ -295,8 +295,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
295295 bool global_stats = test_mode || use_global_stats;
296296
297297 const std::string data_layout_str = ctx.Attr <std::string>(" data_layout" );
298- const DataLayout data_layout =
299- framework::StringToDataLayout (data_layout_str);
298+ DataLayout data_layout = framework::StringToDataLayout (data_layout_str);
300299
301300 const auto *x = ctx.Input <Tensor>(" X" );
302301 const auto &x_dims = x->dims ();
@@ -332,6 +331,12 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
332331 saved_mean->mutable_data <T>(ctx.GetPlace ());
333332 saved_variance->mutable_data <T>(ctx.GetPlace ());
334333
334+ // input dimension is 2 and the format is NCHW. The input can be regarded
335+ // as NHWC format
336+ if (x_dims.size () == 2 && data_layout == DataLayout::kNCHW ) {
337+ data_layout = DataLayout::kNHWC ;
338+ }
339+
335340 if (!global_stats) {
336341 // saved_xx is use just in this batch of data
337342 EigenVectorArrayMap<T> saved_mean_e (
@@ -578,8 +583,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
578583 bool use_global_stats = ctx.Attr <bool >(" use_global_stats" );
579584 const bool is_test = ctx.Attr <bool >(" is_test" );
580585 const float epsilon = ctx.Attr <float >(" epsilon" );
581- const DataLayout data_layout =
582- framework::StringToDataLayout (data_layout_str);
586+ DataLayout data_layout = framework::StringToDataLayout (data_layout_str);
583587
584588 auto *d_x = ctx.Output <Tensor>(framework::GradVarName (" X" ));
585589 auto *d_scale = ctx.Output <Tensor>(framework::GradVarName (" Scale" ));
@@ -633,6 +637,12 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
633637 : x_dims[x_dims.size () - 1 ]);
634638 const int sample_size = x->numel () / N / C;
635639
640+ // input dimension is 2 and the format is NCHW. The input can be regarded as
641+ // NHWC format
642+ if (x_dims.size () == 2 && data_layout == DataLayout::kNCHW ) {
643+ data_layout = DataLayout::kNHWC ;
644+ }
645+
636646 // init output
637647 if (d_x) {
638648 d_x->mutable_data <T>(ctx.GetPlace ());
0 commit comments