@@ -575,7 +575,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
     // SavedVariance have been reverted in forward operator
     const auto *saved_inv_variance = ctx.Input<Tensor>("SavedVariance");
     const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
-    const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
+    bool use_global_stats = ctx.Attr<bool>("use_global_stats");
     const bool is_test = ctx.Attr<bool>("is_test");
     const float epsilon = ctx.Attr<float>("epsilon");
     const DataLayout data_layout =
@@ -585,6 +585,8 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
     auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
     auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

+    use_global_stats = is_test || use_global_stats;
+
     // batch_norm with inplace as false will take X as grad input, which
     // is same as cuDNN batch_norm backward calculation, batch_norm
     // with inplace as true only take Y as input and X should be calculate
@@ -605,13 +607,6 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
605607 " X@GRAD and Y@GRAD inplaced in non-inplace mode" ));
606608 }
607609
608- PADDLE_ENFORCE_EQ (
609- is_test, false ,
610- platform::errors::InvalidArgument (
611- " `is_test = True` CANNOT be used in train program. If "
612- " you want to use global status in pre_train model, "
613- " please set `use_global_stats = True`" ));
614-
615610 // Get the size for each dimension.
616611 // NCHW [batch_size, in_channels, in_height, in_width]
617612 const auto &x_dims = x->dims ();
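
A minimal standalone sketch of the behavioural change above (illustration only, not part of the patch): the removed PADDLE_ENFORCE_EQ used to reject `is_test = true` in the backward kernel, whereas the kernel now folds it into `use_global_stats`, so the gradient path falls back to the saved global statistics instead. The helper name below is hypothetical.

    // Sketch: how the effective flag is derived after this change.
    #include <iostream>

    // Hypothetical helper mirroring the new line
    //   use_global_stats = is_test || use_global_stats;
    static bool EffectiveUseGlobalStats(bool is_test, bool use_global_stats) {
      // Inference mode now implies global statistics rather than
      // triggering an InvalidArgument error in the backward kernel.
      return is_test || use_global_stats;
    }

    int main() {
      std::cout << EffectiveUseGlobalStats(/*is_test=*/true,
                                           /*use_global_stats=*/false)
                << "\n";  // prints 1: global stats are used
      std::cout << EffectiveUseGlobalStats(/*is_test=*/false,
                                           /*use_global_stats=*/false)
                << "\n";  // prints 0: batch statistics are used
      return 0;
    }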