@@ -924,6 +924,7 @@ void FusedAttentionInferMeta(const MetaTensor& x,
   }
 
   out->set_dims(x.dims());
+  out->set_dtype(x.dtype());
 }
 
 void FusedAttentionGradInferMeta(const MetaTensor& out_grad,
@@ -998,19 +999,19 @@ void FusedAttentionGradInferMeta(const MetaTensor& out_grad,
                         "GradOp is only callable when is_test is false"));
 
   if (!pre_layer_norm) {
-    if (ln_scale_2_grad) {
+    if (ln_scale_2_grad && ln_scale_2) {
       ln_scale_2_grad->set_dims(ln_scale_2.dims());
     }
-    if (ln_bias_2_grad) {
+    if (ln_bias_2_grad && ln_bias_2) {
       ln_bias_2_grad->set_dims(ln_bias_2.dims());
     }
   }
 
-  if (pre_layer_norm) {
+  if (pre_layer_norm && ln_scale) {
     if (ln_scale_grad) {
       ln_scale_grad->set_dims(ln_scale.dims());
     }
-    if (ln_bias_grad) {
+    if (ln_bias_grad && ln_bias) {
       ln_bias_grad->set_dims(ln_bias.dims());
     }
   }
@@ -1019,7 +1020,7 @@ void FusedAttentionGradInferMeta(const MetaTensor& out_grad,
     x_grad->set_dims(x.dims());
   }
 
-  if (out_linear_bias_grad) {
+  if (out_linear_bias_grad && out_linear_bias) {
     out_linear_bias_grad->set_dims(out_linear_bias.dims());
   }
 
@@ -1031,7 +1032,7 @@ void FusedAttentionGradInferMeta(const MetaTensor& out_grad,
     qkv_weight_grad->set_dims(qkv_weight.dims());
   }
 
-  if (qkv_bias_grad) {
+  if (qkv_bias_grad && qkv_bias) {
     qkv_bias_grad->set_dims(qkv_bias.dims());
   }
 
@@ -1040,7 +1041,7 @@ void FusedAttentionGradInferMeta(const MetaTensor& out_grad,
       ln_out_grad->set_dims(ln_out.dims());
     }
   } else {
-    if (bias_dropout_residual_out_grad) {
+    if (bias_dropout_residual_out_grad && bias_dropout_residual_out) {
       bias_dropout_residual_out_grad->set_dims(
           bias_dropout_residual_out.dims());
     }
@@ -1556,36 +1557,36 @@ void FusedFeedForwardGradInferMeta(const MetaTensor& out_grad,
                                    bool add_residual,
                                    int ring_id,
                                    MetaTensor* x_grad,
-                                   MetaTensor* ln1_scale_grad,
-                                   MetaTensor* ln1_bias_grad,
-                                   MetaTensor* ln2_scale_grad,
-                                   MetaTensor* ln2_bias_grad,
                                    MetaTensor* linear1_weight_grad,
                                    MetaTensor* linear1_bias_grad,
                                    MetaTensor* linear2_weight_grad,
-                                   MetaTensor* linear2_bias_grad) {
+                                   MetaTensor* linear2_bias_grad,
+                                   MetaTensor* ln1_scale_grad,
+                                   MetaTensor* ln1_bias_grad,
+                                   MetaTensor* ln2_scale_grad,
+                                   MetaTensor* ln2_bias_grad) {
   auto d_out_dim = out_grad.dims();
   x_grad->set_dims(d_out_dim);
-  if (ln1_scale_grad) {
+  if (ln1_scale_grad && ln1_scale) {
     ln1_scale_grad->set_dims(ln1_scale.dims());
   }
-  if (ln1_bias_grad) {
+  if (ln1_bias_grad && ln1_bias) {
     ln1_bias_grad->set_dims(ln1_bias.dims());
   }
-  if (ln2_scale_grad) {
+  if (ln2_scale_grad && ln2_scale) {
     ln2_scale_grad->set_dims(ln2_scale.dims());
   }
-  if (ln2_bias_grad) {
+  if (ln2_bias_grad && ln2_bias) {
     ln2_bias_grad->set_dims(ln2_bias.dims());
   }
 
   linear1_weight_grad->set_dims(linear1_weight.dims());
-  if (linear1_bias_grad) {
+  if (linear1_bias_grad && linear1_bias) {
     linear1_bias_grad->set_dims(linear1_bias.dims());
   }
 
   linear2_weight_grad->set_dims(linear2_weight.dims());
-  if (linear2_bias_grad) {
+  if (linear2_bias_grad && linear2_bias) {
     linear2_bias_grad->set_dims(linear2_bias.dims());
   }
 }
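
The same defensive pattern repeats across every hunk above: a grad MetaTensor pointer may be null when the caller does not request that gradient, and the corresponding forward-side input may itself be an absent optional, so both are checked before `.dims()` is read. A minimal standalone sketch of that idea, using `std::optional` as a stand-in for Paddle's optional MetaTensor inputs (all names here are illustrative, not Paddle's API):

// Sketch only: `Meta` and `InferOptionalGrad` are hypothetical stand-ins.
#include <cstdint>
#include <optional>
#include <vector>

using Dims = std::vector<int64_t>;

struct Meta {
  Dims dims;
  void set_dims(const Dims& d) { dims = d; }
};

// Mirrors checks such as `if (ln_bias_grad && ln_bias)`: the gradient
// output may be null and the forward input may be absent, so guard both
// before dereferencing either.
void InferOptionalGrad(Meta* grad, const std::optional<Meta>& input) {
  if (grad && input) {  // silently skip when either side is missing
    grad->set_dims(input->dims);
  }
}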