Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions paddle/fluid/operators/dropout_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,6 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_test"), false,
platform::errors::InvalidArgument(
"GradOp is only callable when is_test is false"));

OP_INOUT_CHECK(ctx->HasInput("Mask"), "Input", "Mask", "DropoutGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "DropoutGrad");
Expand Down
56 changes: 30 additions & 26 deletions paddle/fluid/operators/dropout_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,50 +160,54 @@ template <typename DeviceContext, typename T>
class DropoutGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE_EQ(!context.Attr<bool>("is_test"), true,
platform::errors::PreconditionNotMet(
"GradOp is only callable when is_test is false"));

auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
auto* mask = context.Input<Tensor>("Mask");
grad_x->mutable_data<T>(context.GetPlace());
auto size = grad_x->numel();

auto M = EigenVector<uint8_t>::Flatten(*mask);
auto dX = EigenVector<T>::Flatten(*grad_x);
auto dY = EigenVector<T>::Flatten(*grad_y);

auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto& dropout_implementation =
context.Attr<std::string>("dropout_implementation");
if (dropout_implementation == "upscale_in_train") {
float dropout_prob = context.Attr<float>("dropout_prob");
if (dropout_prob == 1.0f) {
dX.device(place) = static_cast<T>(0) * dY;
if (context.Attr<bool>("is_test") == true) {
if (dropout_implementation == "upscale_in_train") {
dX.device(place) = static_cast<T>(1) * dY;
} else {
int vec_size = VectorizedSize<T>(grad_y->data<T>());
if (platform::is_gpu_place(context.GetPlace()) && vec_size == 4 &&
size % 4 == 0) {
float dropout_prob = context.Attr<float>("dropout_prob");
dX.device(place) = dY / static_cast<T>(1.0f - dropout_prob);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里应该是dY * static_cast<T>(1.0f - dropout_prob);

}
} else {
auto M = EigenVector<uint8_t>::Flatten(*mask);
if (dropout_implementation == "upscale_in_train") {
float dropout_prob = context.Attr<float>("dropout_prob");
if (dropout_prob == 1.0f) {
dX.device(place) = static_cast<T>(0) * dY;
} else {
int vec_size = VectorizedSize<T>(grad_y->data<T>());
if (platform::is_gpu_place(context.GetPlace()) && vec_size == 4 &&
size % 4 == 0) {
#if defined(__NVCC__) || defined(__HIPCC__)
auto factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
auto stream = context.cuda_device_context().stream();
platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D(
context.cuda_device_context(), size);
DropoutGradCUDAKernel<
T, uint8_t,
4><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
grad_y->data<T>(), mask->data<uint8_t>(), factor, size,
grad_x->data<T>());
auto factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
auto stream = context.cuda_device_context().stream();
platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D(
context.cuda_device_context(), size);
DropoutGradCUDAKernel<T, uint8_t, 4><<<
config.block_per_grid, config.thread_per_block, 0, stream>>>(
grad_y->data<T>(), mask->data<uint8_t>(), factor, size,
grad_x->data<T>());
#endif
} else {
dX.device(place) =
dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
} else {
dX.device(place) =
dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
}
}
} else {
dX.device(place) = dY * M.cast<T>();
}
} else {
dX.device(place) = dY * M.cast<T>();
}
}
};
Expand Down