@@ -25,8 +25,6 @@ class DropoutOp : public framework::OperatorWithKernel {
2525
2626 void InferShape (framework::InferShapeContext* ctx) const override {
2727 PADDLE_ENFORCE (ctx->HasInput (" X" ), " Input(X) must not be null." );
28- PADDLE_ENFORCE_GE (ctx->Attrs ().Get <float >(" dropout_prob" ), 0 );
29- PADDLE_ENFORCE_LE (ctx->Attrs ().Get <float >(" dropout_prob" ), 1 );
3028
3129 auto x_dims = ctx->GetInputDim (" X" );
3230 ctx->SetOutputDim (" Out" , x_dims);
@@ -47,7 +45,11 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
4745 AddOutput (" Mask" , " The random sampled dropout mask." ).AsIntermediate ();
4846
4947 AddAttr<float >(" dropout_prob" , " Probability of setting units to zero." )
50- .SetDefault (.5f );
48+ .SetDefault (.5f )
49+ .AddCustomChecker ([](const float & drop_p) {
50+ PADDLE_ENFORCE (drop_p >= 0 .0f && drop_p <= 1 .0f ,
51+ " 'dropout_prob' must be between 0.0 and 1.0." );
52+ });
5153 AddAttr<bool >(" is_test" , " True if in test phase." ).SetDefault (false );
5254 AddAttr<int >(" seed" , " Dropout random seed." ).SetDefault (0 );
5355
@@ -78,8 +80,6 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
7880 PADDLE_ENFORCE (ctx->HasInput (framework::GradVarName (" Out" )),
7981 " Input(Out@GRAD) must not be null." );
8082
81- PADDLE_ENFORCE_GE (ctx->Attrs ().Get <float >(" dropout_prob" ), 0 );
82- PADDLE_ENFORCE_LE (ctx->Attrs ().Get <float >(" dropout_prob" ), 1 );
8383 auto x_dims = ctx->GetInputDim (" X" );
8484 auto out_dims = ctx->GetInputDim (framework::GradVarName (" Out" ));
8585 PADDLE_ENFORCE_EQ (x_dims, out_dims,
0 commit comments