@@ -328,9 +328,206 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
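+// Grad op of fused_attention: checks the tensors saved by the forward pass and
+// infers the shape of every gradient output.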
+class FusedAttentionGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
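+    // The backward pass replays the attention dropout through its saved mask,
+    // so it is only valid when the forward pass ran in training mode.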
+    PADDLE_ENFORCE_EQ(
+        ctx->Attrs().Get<bool>("attn_dropout_is_test"), false,
+        platform::errors::InvalidArgument(
+            "GradOp is only callable when attn_dropout_is_test is false"));
+
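+    // Statistics of the second (post-attention) layer norm.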
+    OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean",
+                   "FusedAttentionGrad");
+    OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance",
+                   "FusedAttentionGrad");
+    if (ctx->HasOutput(framework::GradVarName("Ln2Scale"))) {
+      ctx->SetOutputDim(framework::GradVarName("Ln2Scale"),
+                        ctx->GetInputDim("Ln2Scale"));
+    }
+    if (ctx->HasOutput(framework::GradVarName("Ln2Bias"))) {
+      ctx->SetOutputDim(framework::GradVarName("Ln2Bias"),
+                        ctx->GetInputDim("Ln2Bias"));
+    }
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad");
+    OP_INOUT_CHECK(ctx->HasInput("LnMean"), "Input", "LnMean",
+                   "FusedAttentionGrad");
+    OP_INOUT_CHECK(ctx->HasInput("LnVariance"), "Input", "LnVariance",
+                   "FusedAttentionGrad");
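+    // LnOut is only produced by the forward pass when pre_layer_norm is on.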
+    if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
+      OP_INOUT_CHECK(ctx->HasInput("LnOut"), "Input", "LnOut",
+                     "FusedAttentionGrad");
+    }
+    OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW",
+                   "FusedAttentionGrad");
+    OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
+                   "FusedAttentionGrad");
+    OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask",
+                   "FusedAttentionGrad");
+    OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW",
+                   "FusedAttentionGrad");
+    OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias",
+                   "FusedAttentionGrad");
+
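+    // Gradients of the optional layer-norm parameters and of the input X are
+    // only set when the corresponding grad var exists.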
+    if (ctx->HasOutput(framework::GradVarName("LnScale"))) {
+      ctx->SetOutputDim(framework::GradVarName("LnScale"),
+                        ctx->GetInputDim("LnScale"));
+    }
+    if (ctx->HasOutput(framework::GradVarName("LnBias"))) {
+      ctx->SetOutputDim(framework::GradVarName("LnBias"),
+                        ctx->GetInputDim("LnBias"));
+    }
+    if (ctx->HasOutput(framework::GradVarName("X"))) {
+      ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+    }
+
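+    // Weight and bias gradients are always produced.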
+    ctx->SetOutputDim(framework::GradVarName("OutLinearBias"),
+                      ctx->GetInputDim("OutLinearBias"));
+    ctx->SetOutputDim(framework::GradVarName("OutLinearW"),
+                      ctx->GetInputDim("OutLinearW"));
+    ctx->SetOutputDim(framework::GradVarName("QKVW"), ctx->GetInputDim("QKVW"));
+    ctx->SetOutputDim(framework::GradVarName("QKVBias"),
+                      ctx->GetInputDim("QKVBias"));
+
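+    // Gradients of the intermediate activations saved by the forward pass;
+    // each one shares the shape of its forward counterpart.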
+    ctx->SetOutputDim(framework::GradVarName("LnOut"),
+                      ctx->GetInputDim("LnOut"));
+    ctx->SetOutputDim(framework::GradVarName("FMHAOut"),
+                      ctx->GetInputDim("FMHAOut"));
+    ctx->SetOutputDim(framework::GradVarName("QKTVOut"),
+                      ctx->GetInputDim("QKTVOut"));
+    ctx->SetOutputDim(framework::GradVarName("TransposeOut2"),
+                      ctx->GetInputDim("TransposeOut2"));
+    ctx->SetOutputDim(framework::GradVarName("QKOut"),
+                      ctx->GetInputDim("QKOut"));
+    ctx->SetOutputDim(framework::GradVarName("SoftmaxOut"),
+                      ctx->GetInputDim("SoftmaxOut"));
+    ctx->SetOutputDim(framework::GradVarName("AttnDropoutOut"),
+                      ctx->GetInputDim("AttnDropoutOut"));
+    ctx->SetOutputDim(framework::GradVarName("SrcMaskOut"),
+                      ctx->GetInputDim("SrcMaskOut"));
+    ctx->SetOutputDim(framework::GradVarName("QKVOut"),
+                      ctx->GetInputDim("QKVOut"));
+    ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"),
+                      ctx->GetInputDim("QKVBiasOut"));
+    ctx->SetOutputDim(framework::GradVarName("OutLinearOut"),
+                      ctx->GetInputDim("OutLinearOut"));
+    ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"),
+                      ctx->GetInputDim("BiasDropoutResidualOut"));
+  }
+
+ protected:
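+  // The grad kernel runs on the same place and with the same data type as the
+  // forward input X.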
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    auto input = ctx.Input<Tensor>("X");
+    auto input_data_type = input->type();
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
+};
+
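+// Grad op maker: wires the forward inputs, outputs, and their gradients into
+// fused_attention_grad for both static graph and dygraph modes.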
+template <typename T>
+class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> op) const override {
+    op->SetType("fused_attention_grad");
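+    // dY is the incoming gradient of the forward output Y.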
+    op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
+
+    // Forward input X and the learnable parameters; their gradients are set
+    // further below.
+    op->SetInput("X", this->Input("X"));
+    op->SetInput("QKVW", this->Input("QKVW"));
+    op->SetInput("QKVBias", this->Input("QKVBias"));
+    op->SetInput("SrcMask", this->Input("SrcMask"));
+    op->SetInput("OutLinearW", this->Input("OutLinearW"));
+    op->SetInput("OutLinearBias", this->Input("OutLinearBias"));
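+    // LnScale/LnBias (pre-LN) and Ln2Scale/Ln2Bias (post-LN) are optional;
+    // forward them and their gradients only when present.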
+    if (this->HasInput("LnScale")) {
+      op->SetInput("LnScale", this->Input("LnScale"));
+      op->SetOutput(framework::GradVarName("LnScale"),
+                    this->InputGrad("LnScale"));
+    }
+    if (this->HasInput("LnBias")) {
+      op->SetInput("LnBias", this->Input("LnBias"));
+      op->SetOutput(framework::GradVarName("LnBias"),
+                    this->InputGrad("LnBias"));
+    }
+    if (this->HasInput("Ln2Scale")) {
+      op->SetInput("Ln2Scale", this->Input("Ln2Scale"));
+      op->SetOutput(framework::GradVarName("Ln2Scale"),
+                    this->InputGrad("Ln2Scale"));
+    }
+    if (this->HasInput("Ln2Bias")) {
+      op->SetInput("Ln2Bias", this->Input("Ln2Bias"));
+      op->SetOutput(framework::GradVarName("Ln2Bias"),
+                    this->InputGrad("Ln2Bias"));
+    }
+
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetOutput(framework::GradVarName("QKVW"), this->InputGrad("QKVW"));
+    op->SetOutput(framework::GradVarName("QKVBias"),
+                  this->InputGrad("QKVBias"));
+    op->SetOutput(framework::GradVarName("OutLinearBias"),
+                  this->InputGrad("OutLinearBias"));
+    op->SetOutput(framework::GradVarName("OutLinearW"),
+                  this->InputGrad("OutLinearW"));
+
+    // Use forward outputs as backward inputs.
+    op->SetInput("LnOut", this->Output("LnOut"));
+    op->SetInput("LnMean", this->Output("LnMean"));
+    op->SetInput("LnVariance", this->Output("LnVariance"));
+    op->SetInput("QKVOut", this->Output("QKVOut"));
+    op->SetInput("QKVBiasOut", this->Output("QKVBiasOut"));
+    op->SetInput("TransposeOut2", this->Output("TransposeOut2"));
+    op->SetInput("QKOut", this->Output("QKOut"));
+    op->SetInput("QKTVOut", this->Output("QKTVOut"));
+    op->SetInput("SoftmaxOut", this->Output("SoftmaxOut"));
+    op->SetInput("AttnDropoutMaskOut", this->Output("AttnDropoutMaskOut"));
+    op->SetInput("AttnDropoutOut", this->Output("AttnDropoutOut"));
+    op->SetInput("SrcMaskOut", this->Output("SrcMaskOut"));
+    op->SetInput("FMHAOut", this->Output("FMHAOut"));
+    op->SetInput("OutLinearOut", this->Output("OutLinearOut"));
+
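+    // Statistics and masks from the second layer norm and the residual
+    // dropout.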
+    op->SetInput("Ln2Mean", this->Output("Ln2Mean"));
+    op->SetInput("Ln2Variance", this->Output("Ln2Variance"));
+    op->SetInput("DropoutMaskOut", this->Output("DropoutMaskOut"));
+    op->SetInput("BiasDropoutResidualOut",
+                 this->Output("BiasDropoutResidualOut"));
+
+    // Backward outputs: gradients of the intermediate forward results.
+    op->SetOutput(framework::GradVarName("LnOut"), this->OutputGrad("LnOut"));
+    op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut"));
+    op->SetOutput(framework::GradVarName("QKVBiasOut"),
+                  this->OutputGrad("QKVBiasOut"));
+    op->SetOutput(framework::GradVarName("QKTVOut"),
+                  this->OutputGrad("QKTVOut"));
+    op->SetOutput(framework::GradVarName("TransposeOut2"),
+                  this->OutputGrad("TransposeOut2"));
+    op->SetOutput(framework::GradVarName("QKOut"), this->OutputGrad("QKOut"));
+    op->SetOutput(framework::GradVarName("SoftmaxOut"),
+                  this->OutputGrad("SoftmaxOut"));
+    op->SetOutput(framework::GradVarName("AttnDropoutOut"),
+                  this->OutputGrad("AttnDropoutOut"));
+    op->SetOutput(framework::GradVarName("SrcMaskOut"),
+                  this->OutputGrad("SrcMaskOut"));
+    op->SetOutput(framework::GradVarName("FMHAOut"),
+                  this->OutputGrad("FMHAOut"));
+    op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"),
+                  this->OutputGrad("BiasDropoutResidualOut"));
+    op->SetOutput(framework::GradVarName("OutLinearOut"),
+                  this->OutputGrad("OutLinearOut"));
+
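+    // Propagate all forward attributes to the grad op.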
+    op->SetAttrMap(this->Attrs());
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
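+// Register fused_attention with grad makers for both static graph (OpDesc) and
+// dygraph (OpBase) modes, then register the grad op itself.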
 REGISTER_OPERATOR(fused_attention, ops::FusedAttentionOp,
-                  ops::FusedAttentionOpMaker);
+                  ops::FusedAttentionOpMaker,
+                  ops::FusedAttentionGradOpMaker<paddle::framework::OpDesc>,
+                  ops::FusedAttentionGradOpMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(fused_attention_grad, ops::FusedAttentionGradOp);