Commit dd23316

remove is_test=True in grad

1 parent 7a73692

File tree

paddle/fluid/operators/batch_norm_op.cc
paddle/fluid/operators/batch_norm_op.cu

2 files changed, 5 insertions(+), 15 deletions(-)
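This drops the hard failure on is_test = True in the batch-norm gradient kernels: instead of raising InvalidArgument, test mode is now folded into use_global_stats, so the backward pass runs against the running (global) mean and variance. A minimal standalone sketch of the new control flow; ResolveUseGlobalStats is a hypothetical helper name for illustration (the real kernels inline the assignment):

    // Sketch only: illustrates the is_test -> use_global_stats fold
    // introduced by this commit; not part of the Paddle codebase.
    #include <iostream>

    bool ResolveUseGlobalStats(bool is_test, bool use_global_stats) {
      // Before this commit the grad kernels rejected is_test == true;
      // now test mode simply selects the global-statistics path.
      return is_test || use_global_stats;
    }

    int main() {
      std::cout << ResolveUseGlobalStats(/*is_test=*/true,
                                         /*use_global_stats=*/false);
      // prints 1: gradient computed with global mean/variance
    }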

paddle/fluid/operators/batch_norm_op.cc

Lines changed: 3 additions & 8 deletions
@@ -575,7 +575,7 @@ class BatchNormGradKernel&lt;platform::CPUDeviceContext, T&gt;
     // SavedVariance have been reverted in forward operator
     const auto *saved_inv_variance = ctx.Input<Tensor>("SavedVariance");
     const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
-    const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
+    bool use_global_stats = ctx.Attr<bool>("use_global_stats");
     const bool is_test = ctx.Attr<bool>("is_test");
     const float epsilon = ctx.Attr<float>("epsilon");
     const DataLayout data_layout =
@@ -585,6 +585,8 @@ class BatchNormGradKernel&lt;platform::CPUDeviceContext, T&gt;
     auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
     auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

+    use_global_stats = is_test || use_global_stats;
+
     // batch_norm with inplace as false will take X as grad input, which
     // is same as cuDNN batch_norm backward calculation, batch_norm
     // with inplace as true only take Y as input and X should be calculate
@@ -605,13 +607,6 @@ class BatchNormGradKernel&lt;platform::CPUDeviceContext, T&gt;
           "X@GRAD and Y@GRAD inplaced in non-inplace mode"));
     }

-    PADDLE_ENFORCE_EQ(
-        is_test, false,
-        platform::errors::InvalidArgument(
-            "`is_test = True` CANNOT be used in train program. If "
-            "you want to use global status in pre_train model, "
-            "please set `use_global_stats = True`"));
-
     // Get the size for each dimension.
     // NCHW [batch_size, in_channels, in_height, in_width]
     const auto &x_dims = x->dims();

paddle/fluid/operators/batch_norm_op.cu

Lines changed: 2 additions & 7 deletions
@@ -817,7 +817,7 @@ class BatchNormGradKernel&lt;platform::CUDADeviceContext, T&gt;
                       platform::errors::InvalidArgument("It must use CUDAPlace."));
     double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
     const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
-    const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
+    bool use_global_stats = ctx.Attr<bool>("use_global_stats");

     const DataLayout data_layout =
         framework::StringToDataLayout(data_layout_str);
@@ -850,12 +850,7 @@ class BatchNormGradKernel&lt;platform::CUDADeviceContext, T&gt;
     }

     const bool is_test = ctx.Attr<bool>("is_test");
-    PADDLE_ENFORCE_EQ(
-        is_test, false,
-        platform::errors::InvalidArgument(
-            "`is_test = True` CANNOT be used in train program. If "
-            "you want to use global status in pre_train model, "
-            "please set `use_global_stats = True`"));
+    use_global_stats = is_test || use_global_stats;

     const auto &x_dims = x->dims();
