
Commit 95862a5

Author: chengduo
Merge pull request #7014 from chengduoZH/profiling/refine_drop_out
Refine drop_out_op and batch_norm
2 parents 94096ae + 67e47e6 commit 95862a5

3 files changed: +13 −13 lines


paddle/operators/batch_norm_op.cc

Lines changed: 6 additions & 5 deletions
```diff
@@ -50,10 +50,6 @@ class BatchNormOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasOutput("SavedMean"), "");
     PADDLE_ENFORCE(ctx->HasOutput("SavedVariance"), "");
 
-    const float epsilon = ctx->Attrs().Get<float>("epsilon");
-    PADDLE_ENFORCE_GE(epsilon, 0.0, "epsilon should be larger than 0");
-    PADDLE_ENFORCE_LE(epsilon, 0.001, "epsilon should not be too large");
-
     // make sure Mean/MeanOut and Variance/VarianceOut share memory in Python
     PADDLE_ENFORCE_EQ(ctx->Inputs("Mean")[0], ctx->Outputs("MeanOut")[0],
                       "Mean and MeanOut should share the same memory");
@@ -91,7 +87,12 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddAttr<bool>("is_test", "").SetDefault(false);
     AddAttr<float>("momentum", "").SetDefault(0.9);
-    AddAttr<float>("epsilon", "").SetDefault(1e-5);
+    AddAttr<float>("epsilon", "")
+        .SetDefault(1e-5)
+        .AddCustomChecker([](const float &epsilon) {
+          PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f,
+                         "'epsilon' should be between 0.0 and 0.001.");
+        });
     AddAttr<std::string>("data_layout", "").SetDefault("NCHW");
     AddInput("X", "The input tensor");
     AddInput("Scale",
```

paddle/operators/dropout_op.cc

Lines changed: 5 additions & 5 deletions
```diff
@@ -25,8 +25,6 @@ class DropoutOp : public framework::OperatorWithKernel {
 
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
-    PADDLE_ENFORCE_GE(ctx->Attrs().Get<float>("dropout_prob"), 0);
-    PADDLE_ENFORCE_LE(ctx->Attrs().Get<float>("dropout_prob"), 1);
 
     auto x_dims = ctx->GetInputDim("X");
     ctx->SetOutputDim("Out", x_dims);
@@ -47,7 +45,11 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("Mask", "The random sampled dropout mask.").AsIntermediate();
 
     AddAttr<float>("dropout_prob", "Probability of setting units to zero.")
-        .SetDefault(.5f);
+        .SetDefault(.5f)
+        .AddCustomChecker([](const float& drop_p) {
+          PADDLE_ENFORCE(drop_p >= 0.0f && drop_p <= 1.0f,
+                         "'dropout_prob' must be between 0.0 and 1.0.");
+        });
     AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false);
     AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
 
@@ -78,8 +80,6 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                    "Input(Out@GRAD) must not be null.");
 
-    PADDLE_ENFORCE_GE(ctx->Attrs().Get<float>("dropout_prob"), 0);
-    PADDLE_ENFORCE_LE(ctx->Attrs().Get<float>("dropout_prob"), 1);
     auto x_dims = ctx->GetInputDim("X");
     auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
     PADDLE_ENFORCE_EQ(x_dims, out_dims,
```
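
Worth noting: `DropoutOpGrad::InferShape` drops its duplicated `dropout_prob` checks as well. Since the custom checker above validates the attribute when the op is constructed, re-checking it during the gradient op's shape inference added work on every backward pass without catching anything new; the branch name `profiling/refine_drop_out` suggests this cost showed up in profiling.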

paddle/operators/dropout_op.cu

Lines changed: 2 additions & 3 deletions
```diff
@@ -30,16 +30,15 @@ struct MaskGenerator {
   __host__ __device__ MaskGenerator(AttrType dropout_prob, int seed)
       : dropout_prob(dropout_prob), seed(seed) {}
 
-  __host__ __device__ T operator()(const unsigned int n) const {
+  inline __host__ __device__ T operator()(const unsigned int n) const {
     thrust::minstd_rand rng;
     rng.seed(seed);
     thrust::uniform_real_distribution<AttrType> dist(0, 1);
     rng.discard(n);
     if (dist(rng) < dropout_prob) {
       return static_cast<T>(0);
-    } else {
-      return static_cast<T>(1);
     }
+    return static_cast<T>(1);
   }
 };
```
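On the CUDA side the change is mechanical: `operator()` is marked `inline` and the `else` branch becomes an early return, with identical behavior. The diff doesn't show how `MaskGenerator` is driven; a minimal, hypothetical harness (not taken from `dropout_op.cu`) would map element indices through it with `thrust::transform` over a counting iterator, which is why the functor calls `rng.discard(n)` so element `n`'s draw is independent of launch configuration:

```cpp
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>

// Same shape as the functor in the diff: each element index n gets its own
// deterministic draw from a seeded RNG, so the mask is reproducible.
template <typename T, typename AttrType>
struct MaskGenerator {
  AttrType dropout_prob;
  int seed;

  __host__ __device__ MaskGenerator(AttrType dropout_prob, int seed)
      : dropout_prob(dropout_prob), seed(seed) {}

  inline __host__ __device__ T operator()(const unsigned int n) const {
    thrust::minstd_rand rng;
    rng.seed(seed);
    thrust::uniform_real_distribution<AttrType> dist(0, 1);
    rng.discard(n);  // advance the stream to this element's position
    if (dist(rng) < dropout_prob) {
      return static_cast<T>(0);  // drop this unit
    }
    return static_cast<T>(1);  // keep this unit
  }
};

int main() {
  const unsigned int size = 8;
  thrust::device_vector<float> mask(size);
  // Map indices 0..size-1 through the functor to fill the mask on the GPU.
  thrust::transform(thrust::counting_iterator<unsigned int>(0),
                    thrust::counting_iterator<unsigned int>(size),
                    mask.begin(), MaskGenerator<float, float>(0.5f, 42));
  thrust::host_vector<float> host_mask = mask;  // copy back for inspection
  return 0;
}
```

Seeding one `minstd_rand` per element and discarding `n` draws trades a little extra arithmetic for determinism: for a given seed, the mask is the same no matter how the work is partitioned across threads.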
