@@ -171,6 +171,210 @@ class FusedFeedForwardKernel : public framework::OpKernel<T> {
171171 }
172172};
173173
174+ template <typename DeviceContext, typename T>
175+ class FusedFeedForwardGradKernel : public framework ::OpKernel<T> {
176+ public:
177+ void MatMulGrad (const platform::CUDADeviceContext& ctx,
178+ const framework::Tensor& d_out, const framework::Tensor& a,
179+ const framework::Tensor& b, framework::Tensor* d_a,
180+ framework::Tensor* d_b) const {
181+ auto blas = math::GetBlas<DeviceContext, T>(ctx);
182+ auto a_2d = FoldInitDims (a);
183+ auto b_2d = FoldInitDims (b);
184+ auto mat_dim_a = math::CreateMatrixDescriptor (a_2d.dims (), 0 , true );
185+ auto mat_dim_b = math::CreateMatrixDescriptor (b_2d.dims (), 0 , true );
186+ auto mat_dim_dout = math::CreateMatrixDescriptor (d_out.dims (), 0 , false );
187+ T alpha = static_cast <T>(1.0 );
188+ blas.MatMul (d_out, mat_dim_dout, b, mat_dim_b, alpha, d_a, T (0 ));
189+ blas.MatMul (a, mat_dim_a, d_out, mat_dim_dout, alpha, d_b, T (0 ));
190+ }
191+
192+ void FFNGrad (
193+ const framework::Tensor& d_out, const framework::Tensor& x,
194+ const framework::Tensor& dropout1_mask,
195+ const framework::Tensor& dropout2_mask,
196+ const framework::Tensor& linear1_out, const framework::Tensor& ln1_out,
197+ const framework::Tensor& dropout1_out,
198+ const framework::Tensor& dropout2_out,
199+ const framework::Tensor& linear1_weight,
200+ const framework::Tensor* linear1_bias,
201+ const framework::Tensor& linear2_weight,
202+ const framework::Tensor* ln1_gamma, const framework::Tensor* ln1_beta,
203+ const framework::Tensor& ln1_mean, const framework::Tensor& ln1_variance,
204+ const framework::Tensor* ln2_gamma, const framework::Tensor* ln2_beta,
205+ const framework::Tensor& ln2_mean, const framework::Tensor& ln2_variance,
206+ framework::Tensor* d_x, framework::Tensor* d_linear1_weight,
207+ framework::Tensor* d_linear1_bias, framework::Tensor* d_linear2_weight,
208+ framework::Tensor* d_linear2_bias, framework::Tensor* d_ln1_gamma,
209+ framework::Tensor* d_ln1_beta, framework::Tensor* d_ln2_gamma,
210+ framework::Tensor* d_ln2_beta, const int bsz_seq, const int d_model,
211+ const int dim_feedforward, const DropoutParam& dropout_param1,
212+ const DropoutParam& dropout_param2, const std::string& act_method,
213+ const bool pre_layer_norm, const float epsilon1, const float epsilon2,
214+ const platform::CUDADeviceContext& ctx) const {
215+ FusedDropoutLayerNormHelper<T, uint8_t > pre_layernorm_helper (
216+ bsz_seq, d_model, epsilon1);
217+ FusedDropoutHelper<T, uint8_t > fused_act_dropout_helper (
218+ ctx, bsz_seq, dim_feedforward, dropout_param1);
219+ FusedDropoutLayerNormHelper<T, uint8_t > fused_dropout_layernorm_helper (
220+ ctx, bsz_seq, d_model, dropout_param2, epsilon2);
221+
222+ auto place = ctx.GetPlace ();
223+ using U = LayerNormParamType<T>;
224+ const U* ln1_gamma_ptr =
225+ ln1_gamma == nullptr ? nullptr : ln1_gamma->data <U>();
226+ const U* ln1_beta_ptr = ln1_beta == nullptr ? nullptr : ln1_beta->data <U>();
227+ const U* ln2_gamma_ptr =
228+ ln2_gamma == nullptr ? nullptr : ln2_gamma->data <U>();
229+ const U* ln2_beta_ptr = ln2_beta == nullptr ? nullptr : ln2_beta->data <U>();
230+ const T* linear1_bias_ptr =
231+ linear1_bias == nullptr ? nullptr : linear1_bias->data <T>();
232+ T* d_linear1_bias_ptr =
233+ d_linear1_bias == nullptr ? nullptr : d_linear1_bias->data <T>();
234+ T* d_linear2_bias_ptr =
235+ d_linear2_bias == nullptr ? nullptr : d_linear2_bias->data <T>();
236+ U* d_ln1_gamma_ptr =
237+ d_ln1_gamma == nullptr ? nullptr : d_ln1_gamma->data <U>();
238+ U* d_ln1_beta_ptr = d_ln1_beta == nullptr ? nullptr : d_ln1_beta->data <U>();
239+ U* d_ln2_gamma_ptr =
240+ d_ln2_gamma == nullptr ? nullptr : d_ln2_gamma->data <U>();
241+ U* d_ln2_beta_ptr = d_ln2_beta == nullptr ? nullptr : d_ln2_beta->data <U>();
242+
243+ framework::Tensor d_linear2_out, d_dropout2_out, d_residual;
244+ d_linear2_out.mutable_data <T>({bsz_seq, d_model}, place);
245+ d_dropout2_out.mutable_data <T>({bsz_seq, d_model}, place);
246+ d_residual.mutable_data <T>({bsz_seq, d_model}, place);
247+
248+ if (pre_layer_norm) {
249+ fused_dropout_layernorm_helper.ResidualDropoutBiasGrad (
250+ ctx, d_out.data <T>(), dropout2_mask.data <uint8_t >(),
251+ d_linear2_out.data <T>(), d_residual.data <T>(), d_linear2_bias_ptr);
252+ } else {
253+ fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad (
254+ ctx, d_out.data <T>(), dropout2_out.data <T>(),
255+ dropout2_mask.data <uint8_t >(), ln2_gamma_ptr, ln2_mean.data <U>(),
256+ ln2_variance.data <U>(), d_dropout2_out.data <T>(), d_ln2_gamma_ptr,
257+ d_ln2_beta_ptr, d_linear2_out.data <T>(), d_linear2_bias_ptr,
258+ d_residual.data <T>());
259+ }
260+
261+ framework::Tensor d_dropout1_out;
262+ d_dropout1_out.mutable_data <T>({bsz_seq, dim_feedforward}, place);
263+ MatMulGrad (ctx, d_linear2_out, dropout1_out, linear2_weight,
264+ &d_dropout1_out, d_linear2_weight);
265+
266+ framework::Tensor d_linear1_out;
267+ d_linear1_out.mutable_data <T>({bsz_seq, dim_feedforward}, place);
268+ fused_act_dropout_helper.DropoutActBiasGrad (
269+ ctx, d_dropout1_out.data <T>(), linear1_out.data <T>(), linear1_bias_ptr,
270+ dropout1_mask.data <uint8_t >(), d_linear1_out.data <T>(),
271+ d_linear1_bias_ptr, act_method);
272+
273+ if (pre_layer_norm) {
274+ framework::Tensor d_ln1_out;
275+ d_ln1_out.mutable_data <T>({bsz_seq, d_model}, place);
276+ MatMulGrad (ctx, d_linear1_out, ln1_out, linear1_weight, &d_ln1_out,
277+ d_linear1_weight);
278+
279+ pre_layernorm_helper.LayerNormGrad (ctx, d_ln1_out.data <T>(), x.data <T>(),
280+ ln1_gamma_ptr, ln1_mean.data <U>(),
281+ ln1_variance.data <U>(), d_x->data <T>(),
282+ d_ln1_gamma_ptr, d_ln1_beta_ptr);
283+ } else {
284+ MatMulGrad (ctx, d_linear1_out, x, linear1_weight, d_x, d_linear1_weight);
285+ }
286+ }
287+
288+ void Compute (const framework::ExecutionContext& context) const override {
289+ using U = LayerNormParamType<T>;
290+ auto d_out =
291+ *context.Input <framework::Tensor>(framework::GradVarName (" Out" ));
292+ auto x = *context.Input <framework::Tensor>(" X" );
293+ auto dropout1_mask = *context.Input <framework::Tensor>(" Dropout1Mask" );
294+ auto dropout2_mask = *context.Input <framework::Tensor>(" Dropout2Mask" );
295+ auto linear1_out = *context.Input <framework::Tensor>(" Linear1Out" );
296+ auto ln1_out = *context.Input <framework::Tensor>(" Ln1Out" );
297+ auto dropout1_out = *context.Input <framework::Tensor>(" Dropout1Out" );
298+ auto dropout2_out = *context.Input <framework::Tensor>(" Dropout2Out" );
299+ auto linear1_weight = *context.Input <framework::Tensor>(" Linear1Weight" );
300+ auto * linear1_bias = context.Input <framework::Tensor>(" Linear1Bias" );
301+ auto linear2_weight = *context.Input <framework::Tensor>(" Linear2Weight" );
302+ auto ln1_mean = *context.Input <framework::Tensor>(" Ln1Mean" );
303+ auto ln1_variance = *context.Input <framework::Tensor>(" Ln1Variance" );
304+ auto * ln1_scale = context.Input <framework::Tensor>(" Ln1Scale" );
305+ auto * ln1_bias = context.Input <framework::Tensor>(" Ln1Bias" );
306+ auto ln2_mean = *context.Input <framework::Tensor>(" Ln2Mean" );
307+ auto ln2_variance = *context.Input <framework::Tensor>(" Ln2Variance" );
308+ auto * ln2_scale = context.Input <framework::Tensor>(" Ln2Scale" );
309+ auto * ln2_bias = context.Input <framework::Tensor>(" Ln2Bias" );
310+
311+ auto * d_x = context.Output <framework::Tensor>(framework::GradVarName (" X" ));
312+ auto * d_ln1_scale =
313+ context.Output <framework::Tensor>(framework::GradVarName (" Ln1Scale" ));
314+ auto * d_ln1_bias =
315+ context.Output <framework::Tensor>(framework::GradVarName (" Ln1Bias" ));
316+ auto * d_ln2_scale =
317+ context.Output <framework::Tensor>(framework::GradVarName (" Ln2Scale" ));
318+ auto * d_ln2_bias =
319+ context.Output <framework::Tensor>(framework::GradVarName (" Ln2Bias" ));
320+ auto * d_linear1_weight = context.Output <framework::Tensor>(
321+ framework::GradVarName (" Linear1Weight" ));
322+ auto * d_linear1_bias = context.Output <framework::Tensor>(
323+ framework::GradVarName (" Linear1Bias" ));
324+ auto * d_linear2_weight = context.Output <framework::Tensor>(
325+ framework::GradVarName (" Linear2Weight" ));
326+ auto * d_linear2_bias = context.Output <framework::Tensor>(
327+ framework::GradVarName (" Linear2Bias" ));
328+
329+ const float epsilon1 = context.Attr <float >(" ln1_epsilon" );
330+ const float epsilon2 = context.Attr <float >(" ln2_epsilon" );
331+ const bool pre_layer_norm = context.Attr <bool >(" pre_layer_norm" );
332+ const std::string act_method = context.Attr <std::string>(" act_method" );
333+ DropoutParam dropout_param1 (context, 1 );
334+ DropoutParam dropout_param2 (context, 2 );
335+
336+ auto place = context.GetPlace ();
337+ d_x->mutable_data <T>(place);
338+ if (d_ln1_scale) {
339+ d_ln1_scale->mutable_data <U>(place);
340+ }
341+ if (d_ln1_bias) {
342+ d_ln1_bias->mutable_data <U>(place);
343+ }
344+ if (d_ln2_scale) {
345+ d_ln2_scale->mutable_data <U>(place);
346+ }
347+ if (d_ln2_bias) {
348+ d_ln2_bias->mutable_data <U>(place);
349+ }
350+ if (d_linear1_bias) {
351+ d_linear1_bias->mutable_data <T>(place);
352+ }
353+ if (d_linear2_bias) {
354+ d_linear2_bias->mutable_data <T>(place);
355+ }
356+ d_linear1_weight->mutable_data <T>(place);
357+ d_linear2_weight->mutable_data <T>(place);
358+
359+ auto x_dim = x.dims ();
360+ auto mat_dim_x =
361+ math::CreateMatrixDescriptor (RowMatrixFromVector (x_dim), 0 , false );
362+
363+ auto linear1_weight_dim = linear1_weight.dims ();
364+ int d_model = linear1_weight_dim[0 ];
365+ int dim_feedforward = linear1_weight_dim[linear1_weight_dim.size () - 1 ];
366+ int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_ ;
367+
368+ FFNGrad (d_out, x, dropout1_mask, dropout2_mask, linear1_out, ln1_out,
369+ dropout1_out, dropout2_out, linear1_weight, linear1_bias,
370+ linear2_weight, ln1_scale, ln1_bias, ln1_mean, ln1_variance,
371+ ln2_scale, ln2_bias, ln2_mean, ln2_variance, d_x, d_linear1_weight,
372+ d_linear1_bias, d_linear2_weight, d_linear2_bias, d_ln1_scale,
373+ d_ln1_bias, d_ln2_scale, d_ln2_bias, bsz_seq, d_model,
374+ dim_feedforward, dropout_param1, dropout_param2, act_method,
375+ pre_layer_norm, epsilon1, epsilon2, context.cuda_device_context ());
376+ }
377+ };
174378} // namespace operators
175379} // namespace paddle
176380
@@ -181,3 +385,10 @@ REGISTER_OP_CUDA_KERNEL(
181385 ops::FusedFeedForwardKernel<paddle::platform::CUDADeviceContext, double >,
182386 ops::FusedFeedForwardKernel<paddle::platform::CUDADeviceContext,
183387 paddle::platform::float16>);
388+ REGISTER_OP_CUDA_KERNEL (
389+ fused_feedforward_grad,
390+ ops::FusedFeedForwardGradKernel<paddle::platform::CUDADeviceContext, float >,
391+ ops::FusedFeedForwardGradKernel<paddle::platform::CUDADeviceContext,
392+ double >,
393+ ops::FusedFeedForwardGradKernel<paddle::platform::CUDADeviceContext,
394+ paddle::platform::float16>);
0 commit comments