Skip to content

Commit 67e47e6

Browse files
committed
refine batch_norm
1 parent 52119d6 commit 67e47e6

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

paddle/operators/batch_norm_op.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,6 @@ class BatchNormOp : public framework::OperatorWithKernel {
5050
PADDLE_ENFORCE(ctx->HasOutput("SavedMean"), "");
5151
PADDLE_ENFORCE(ctx->HasOutput("SavedVariance"), "");
5252

53-
const float epsilon = ctx->Attrs().Get<float>("epsilon");
54-
PADDLE_ENFORCE_GE(epsilon, 0.0, "epsilon should be larger than 0");
55-
PADDLE_ENFORCE_LE(epsilon, 0.001, "epsilon should not be too large");
56-
5753
// make sure Mean/MeanOut and Variance/VarianceOut share memory in Python
5854
PADDLE_ENFORCE_EQ(ctx->Inputs("Mean")[0], ctx->Outputs("MeanOut")[0],
5955
"Mean and MeanOut should share the same memory");
@@ -91,7 +87,12 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
9187
: OpProtoAndCheckerMaker(proto, op_checker) {
9288
AddAttr<bool>("is_test", "").SetDefault(false);
9389
AddAttr<float>("momentum", "").SetDefault(0.9);
94-
AddAttr<float>("epsilon", "").SetDefault(1e-5);
90+
AddAttr<float>("epsilon", "")
91+
.SetDefault(1e-5)
92+
.AddCustomChecker([](const float &epsilon) {
93+
PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f,
94+
"'epsilon' should be between 0.0 and 0.001.");
95+
});
9596
AddAttr<std::string>("data_layout", "").SetDefault("NCHW");
9697
AddInput("X", "The input tensor");
9798
AddInput("Scale",

0 commit comments

Comments
 (0)