Skip to content

Commit f1275fb

Browse files
authored
Support dropout backward in eval mode (#35122)
* Support dropout backward in eval mode
* add downscale case
* minor fix
* minor fix
1 parent e7df47e commit f1275fb

File tree

2 files changed

+30
-30
lines changed

2 files changed

+30
-30
lines changed

paddle/fluid/operators/dropout_op.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,6 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
117117
using framework::OperatorWithKernel::OperatorWithKernel;
118118

119119
void InferShape(framework::InferShapeContext* ctx) const override {
120-
PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_test"), false,
121-
platform::errors::InvalidArgument(
122-
"GradOp is only callable when is_test is false"));
123-
124120
OP_INOUT_CHECK(ctx->HasInput("Mask"), "Input", "Mask", "DropoutGrad");
125121
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
126122
framework::GradVarName("Out"), "DropoutGrad");

paddle/fluid/operators/dropout_op.h

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -160,50 +160,54 @@ template <typename DeviceContext, typename T>
160160
class DropoutGradKernel : public framework::OpKernel<T> {
161161
public:
162162
void Compute(const framework::ExecutionContext& context) const override {
163-
PADDLE_ENFORCE_EQ(!context.Attr<bool>("is_test"), true,
164-
platform::errors::PreconditionNotMet(
165-
"GradOp is only callable when is_test is false"));
166-
167163
auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
168164
auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
169165
auto* mask = context.Input<Tensor>("Mask");
170166
grad_x->mutable_data<T>(context.GetPlace());
171167
auto size = grad_x->numel();
172168

173-
auto M = EigenVector<uint8_t>::Flatten(*mask);
174169
auto dX = EigenVector<T>::Flatten(*grad_x);
175170
auto dY = EigenVector<T>::Flatten(*grad_y);
176171

177172
auto& place =
178173
*context.template device_context<DeviceContext>().eigen_device();
179174
auto& dropout_implementation =
180175
context.Attr<std::string>("dropout_implementation");
181-
if (dropout_implementation == "upscale_in_train") {
182-
float dropout_prob = context.Attr<float>("dropout_prob");
183-
if (dropout_prob == 1.0f) {
184-
dX.device(place) = static_cast<T>(0) * dY;
176+
if (context.Attr<bool>("is_test") == true) {
177+
if (dropout_implementation == "upscale_in_train") {
178+
dX.device(place) = static_cast<T>(1) * dY;
185179
} else {
186-
int vec_size = VectorizedSize<T>(grad_y->data<T>());
187-
if (platform::is_gpu_place(context.GetPlace()) && vec_size == 4 &&
188-
size % 4 == 0) {
180+
float dropout_prob = context.Attr<float>("dropout_prob");
181+
dX.device(place) = dY * static_cast<T>(1.0f - dropout_prob);
182+
}
183+
} else {
184+
auto M = EigenVector<uint8_t>::Flatten(*mask);
185+
if (dropout_implementation == "upscale_in_train") {
186+
float dropout_prob = context.Attr<float>("dropout_prob");
187+
if (dropout_prob == 1.0f) {
188+
dX.device(place) = static_cast<T>(0) * dY;
189+
} else {
190+
int vec_size = VectorizedSize<T>(grad_y->data<T>());
191+
if (platform::is_gpu_place(context.GetPlace()) && vec_size == 4 &&
192+
size % 4 == 0) {
189193
#if defined(__NVCC__) || defined(__HIPCC__)
190-
auto factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
191-
auto stream = context.cuda_device_context().stream();
192-
platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D(
193-
context.cuda_device_context(), size);
194-
DropoutGradCUDAKernel<
195-
T, uint8_t,
196-
4><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
197-
grad_y->data<T>(), mask->data<uint8_t>(), factor, size,
198-
grad_x->data<T>());
194+
auto factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
195+
auto stream = context.cuda_device_context().stream();
196+
platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D(
197+
context.cuda_device_context(), size);
198+
DropoutGradCUDAKernel<T, uint8_t, 4><<<
199+
config.block_per_grid, config.thread_per_block, 0, stream>>>(
200+
grad_y->data<T>(), mask->data<uint8_t>(), factor, size,
201+
grad_x->data<T>());
199202
#endif
200-
} else {
201-
dX.device(place) =
202-
dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
203+
} else {
204+
dX.device(place) =
205+
dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
206+
}
203207
}
208+
} else {
209+
dX.device(place) = dY * M.cast<T>();
204210
}
205-
} else {
206-
dX.device(place) = dY * M.cast<T>();
207211
}
208212
}
209213
};

0 commit comments

Comments (0)