@@ -160,50 +160,54 @@ template <typename DeviceContext, typename T>
160160class DropoutGradKernel : public framework ::OpKernel<T> {
161161 public:
162162 void Compute (const framework::ExecutionContext& context) const override {
163- PADDLE_ENFORCE_EQ (!context.Attr <bool >(" is_test" ), true ,
164- platform::errors::PreconditionNotMet (
165- " GradOp is only callable when is_test is false" ));
166-
167163 auto * grad_x = context.Output <Tensor>(framework::GradVarName (" X" ));
168164 auto * grad_y = context.Input <Tensor>(framework::GradVarName (" Out" ));
169165 auto * mask = context.Input <Tensor>(" Mask" );
170166 grad_x->mutable_data <T>(context.GetPlace ());
171167 auto size = grad_x->numel ();
172168
173- auto M = EigenVector<uint8_t >::Flatten (*mask);
174169 auto dX = EigenVector<T>::Flatten (*grad_x);
175170 auto dY = EigenVector<T>::Flatten (*grad_y);
176171
177172 auto & place =
178173 *context.template device_context <DeviceContext>().eigen_device ();
179174 auto & dropout_implementation =
180175 context.Attr <std::string>(" dropout_implementation" );
181- if (dropout_implementation == " upscale_in_train" ) {
182- float dropout_prob = context.Attr <float >(" dropout_prob" );
183- if (dropout_prob == 1 .0f ) {
184- dX.device (place) = static_cast <T>(0 ) * dY;
176+ if (context.Attr <bool >(" is_test" ) == true ) {
177+ if (dropout_implementation == " upscale_in_train" ) {
178+ dX.device (place) = static_cast <T>(1 ) * dY;
185179 } else {
186- int vec_size = VectorizedSize<T>(grad_y->data <T>());
187- if (platform::is_gpu_place (context.GetPlace ()) && vec_size == 4 &&
188- size % 4 == 0 ) {
180+ float dropout_prob = context.Attr <float >(" dropout_prob" );
181+ dX.device (place) = dY * static_cast <T>(1 .0f - dropout_prob);
182+ }
183+ } else {
184+ auto M = EigenVector<uint8_t >::Flatten (*mask);
185+ if (dropout_implementation == " upscale_in_train" ) {
186+ float dropout_prob = context.Attr <float >(" dropout_prob" );
187+ if (dropout_prob == 1 .0f ) {
188+ dX.device (place) = static_cast <T>(0 ) * dY;
189+ } else {
190+ int vec_size = VectorizedSize<T>(grad_y->data <T>());
191+ if (platform::is_gpu_place (context.GetPlace ()) && vec_size == 4 &&
192+ size % 4 == 0 ) {
189193#if defined(__NVCC__) || defined(__HIPCC__)
190- auto factor = static_cast <T>(1 .0f / (1 .0f - dropout_prob));
191- auto stream = context.cuda_device_context ().stream ();
192- platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D (
193- context.cuda_device_context (), size);
194- DropoutGradCUDAKernel<
195- T, uint8_t ,
196- 4 ><<<config.block_per_grid , config.thread_per_block , 0 , stream>>>(
197- grad_y->data <T>(), mask->data <uint8_t >(), factor, size,
198- grad_x->data <T>());
194+ auto factor = static_cast <T>(1 .0f / (1 .0f - dropout_prob));
195+ auto stream = context.cuda_device_context ().stream ();
196+ platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D (
197+ context.cuda_device_context (), size);
198+ DropoutGradCUDAKernel<T, uint8_t , 4 ><<<
199+ config.block_per_grid , config.thread_per_block , 0 , stream>>>(
200+ grad_y->data <T>(), mask->data <uint8_t >(), factor, size,
201+ grad_x->data <T>());
199202#endif
200- } else {
201- dX.device (place) =
202- dY * M.cast <T>() / static_cast <T>(1 .0f - dropout_prob);
203+ } else {
204+ dX.device (place) =
205+ dY * M.cast <T>() / static_cast <T>(1 .0f - dropout_prob);
206+ }
203207 }
208+ } else {
209+ dX.device (place) = dY * M.cast <T>();
204210 }
205- } else {
206- dX.device (place) = dY * M.cast <T>();
207211 }
208212 }
209213};
0 commit comments