@@ -50,10 +50,6 @@ class BatchNormOp : public framework::OperatorWithKernel {
5050 PADDLE_ENFORCE (ctx->HasOutput (" SavedMean" ), " " );
5151 PADDLE_ENFORCE (ctx->HasOutput (" SavedVariance" ), " " );
5252
53- const float epsilon = ctx->Attrs ().Get <float >(" epsilon" );
54- PADDLE_ENFORCE_GE (epsilon, 0.0 , " epsilon should be larger than 0" );
55- PADDLE_ENFORCE_LE (epsilon, 0.001 , " epsilon should not be too large" );
56-
5753 // make sure Mean/MeanOut and Variance/VarianceOut share memory in Python
5854 PADDLE_ENFORCE_EQ (ctx->Inputs (" Mean" )[0 ], ctx->Outputs (" MeanOut" )[0 ],
5955 " Mean and MeanOut should share the same memory" );
@@ -91,7 +87,12 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
9187 : OpProtoAndCheckerMaker(proto, op_checker) {
9288 AddAttr<bool >(" is_test" , " " ).SetDefault (false );
9389 AddAttr<float >(" momentum" , " " ).SetDefault (0.9 );
94- AddAttr<float >(" epsilon" , " " ).SetDefault (1e-5 );
90+ AddAttr<float >(" epsilon" , " " )
91+ .SetDefault (1e-5 )
92+ .AddCustomChecker ([](const float &epsilon) {
93+ PADDLE_ENFORCE (epsilon >= 0 .0f && epsilon <= 0 .001f ,
94+ " 'epsilon' should be between 0.0 and 0.001." );
95+ });
9596 AddAttr<std::string>(" data_layout" , " " ).SetDefault (" NCHW" );
9697 AddInput (" X" , " The input tensor" );
9798 AddInput (" Scale" ,
0 commit comments