@@ -96,9 +96,11 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
 
     auto *x_data = input_x->data<T>();
     auto *qkv_weight_data = qkv_weight->data<T>();
-    auto *qkv_bias_data = qkv_bias->data<T>();
+    auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data<T>();
     auto *qkv_out_data = qkv_out->mutable_data<T>(ctx.GetPlace());
-    auto *qkv_bias_out_data = qkv_bias_out->mutable_data<T>(ctx.GetPlace());
+    auto *qkv_bias_out_data =
+        (qkv_bias == nullptr) ? nullptr
+                              : qkv_bias_out->mutable_data<T>(ctx.GetPlace());
 
     // get data ptr for FMHA.
     auto *transpose_out_2_data =
@@ -117,7 +119,8 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
 
     // get data ptr for out_linear.
     auto *out_linear_weight_data = out_linear_weight->data<T>();
-    auto *out_linear_bias_data = out_linear_bias->data<T>();
+    auto *out_linear_bias_data =
+        (out_linear_bias == nullptr) ? nullptr : out_linear_bias->data<T>();
     auto *out_linear_out_data = out_linear_out->mutable_data<T>(ctx.GetPlace());
 
     // get data ptr for bias+dropout+residual+layernorm
@@ -139,9 +142,15 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
 
     auto layer_norm_compute = AttnLayerNorm<T>(ctx.cuda_device_context(),
                                                epsilon, bsz_seq, dim_embed);
+
+    bool compute_bias = true;
+    if (qkv_bias == nullptr) {
+      compute_bias = false;
+    }
     // (transA, transB, compute_bias) = (false, true, true)
-    auto qkv_compute = AttnMatMul<T>(ctx.cuda_device_context(), false, true,
-                                     bsz_seq, output_size, input_size, true);
+    auto qkv_compute =
+        AttnMatMul<T>(ctx.cuda_device_context(), false, true, bsz_seq,
+                      output_size, input_size, compute_bias);
 
     AttnDropoutParam attn_dropout_param(
         is_test_1, dropout_implementation_1, attn_dropout_rate,
@@ -176,10 +185,17 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
       qkv_compute.ComputeForward(qkv_weight, input_x, qkv_bias, qkv_out,
                                  qkv_bias_out);
     }
-    fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2,
-                                    qk_out, src_mask_out, softmax_out,
-                                    attn_dropout_mask_out, attn_dropout_out,
-                                    qktv_out, fmha_out);
+    if (qkv_bias == nullptr) {
+      fmha_ref_compute.ComputeForward(*qkv_out, src_mask, transpose_out_2,
+                                      qk_out, src_mask_out, softmax_out,
+                                      attn_dropout_mask_out, attn_dropout_out,
+                                      qktv_out, fmha_out);
+    } else {
+      fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2,
+                                      qk_out, src_mask_out, softmax_out,
+                                      attn_dropout_mask_out, attn_dropout_out,
+                                      qktv_out, fmha_out);
+    }
 
     // fmha_out: [batch_size, seq_len, num_head, head_dim]
     // weight: [embed_dim, embed_dim]
@@ -249,9 +265,10 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
     auto *out_linear_bias = ctx.Input<Tensor>("OutLinearBias");
     auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data<T>());
     auto *qkv_weight_data = qkv_weight->data<T>();
-    auto *qkv_bias_data = qkv_bias->data<T>();
+    auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data<T>();
     auto *out_linear_weight_data = out_linear_weight->data<T>();
-    auto *out_linear_bias_data = out_linear_bias->data<T>();
+    auto *out_linear_bias_data =
+        (out_linear_bias == nullptr) ? nullptr : out_linear_bias->data<T>();
 
     // fw output
     auto *fmha_out = ctx.Input<Tensor>("FMHAOut");
@@ -299,8 +316,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
     auto *d_bias_dropout_residual_out =
         ctx.Output<Tensor>(framework::GradVarName("BiasDropoutResidualOut"));
     auto *d_x_data = d_x->mutable_data<T>(ctx.GetPlace());
-    auto *d_qkv_out_data = d_qkv_out->mutable_data<T>(ctx.GetPlace());
-    auto *d_qkv_bias_out_data = d_qkv_bias_out->mutable_data<T>(ctx.GetPlace());
+    // When qkv_bias is not nullptr, d_qkv_out is equal to d_qkv_bias_out, so
+    // the space can be reused.
+    auto *d_qkv_out_data = (d_qkv_bias_out != nullptr)
+                               ? nullptr
+                               : d_qkv_out->mutable_data<T>(ctx.GetPlace());
+    auto *d_qkv_bias_out_data =
+        (d_qkv_bias_out == nullptr)
+            ? nullptr
+            : d_qkv_bias_out->mutable_data<T>(ctx.GetPlace());
     auto *d_qktv_out_data = d_qktv_out->mutable_data<T>(ctx.GetPlace());
     auto *d_transpose_out_2_data =
         d_transpose_out_2->mutable_data<T>(ctx.GetPlace());
@@ -326,11 +350,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
     auto *d_ln_2_bias = ctx.Output<Tensor>(framework::GradVarName("Ln2Bias"));
 
     auto *d_qkv_weight_data = d_qkv_weight->mutable_data<T>(ctx.GetPlace());
-    auto *d_qkv_bias_data = d_qkv_bias->mutable_data<T>(ctx.GetPlace());
+    auto *d_qkv_bias_data = (d_qkv_bias == nullptr)
+                                ? nullptr
+                                : d_qkv_bias->mutable_data<T>(ctx.GetPlace());
     auto *d_out_linear_weight_data =
         d_out_linear_weight->mutable_data<T>(ctx.GetPlace());
     auto *d_out_linear_bias_data =
-        d_out_linear_bias->mutable_data<T>(ctx.GetPlace());
+        (d_out_linear_bias == nullptr)
+            ? nullptr
+            : d_out_linear_bias->mutable_data<T>(ctx.GetPlace());
 
     const auto input_x_dims = input_x->dims();
     const auto qkv_w_dims = qkv_weight->dims();
@@ -352,12 +380,15 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
 
     bool transA = false;
     bool transB = true;
-    bool compute_bias = true;
+    bool compute_qkv_bias = true;
+    if (qkv_bias == nullptr) {
+      compute_qkv_bias = false;
+    }
     auto layer_norm_compute = AttnLayerNorm<T>(ctx.cuda_device_context(),
                                                epsilon, bsz_seq, dim_embed);
     auto qkv_compute =
         AttnMatMul<T>(ctx.cuda_device_context(), transA, transB, bsz_seq,
-                      output_size, input_size, compute_bias);
+                      output_size, input_size, compute_qkv_bias);
     AttnDropoutParam attn_dropout_param(
         is_test_1, dropout_implementation_1, attn_dropout_prob,
         is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1);
@@ -367,7 +398,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
     output_size = hidden_size;
     transA = false;
     transB = false;
-    compute_bias = false;
+    bool compute_bias = false;
     auto out_linear_compute =
         AttnMatMul<T>(ctx.cuda_device_context(), transA, transB, bsz_seq,
                       output_size, input_size, compute_bias);
@@ -405,14 +436,19 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
                                        d_out_linear_out, d_fmha_out,
                                        d_out_linear_weight, nullptr);
 
-    fmha_ref_compute.ComputeBackward(
-        *transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out,
-        *attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
-        d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out,
-        d_transpose_out_2, nullptr, d_qkv_bias_out);
-    cudaMemcpyAsync(d_qkv_out_data, d_qkv_bias_out_data,
-                    bsz_seq * 3 * num_head * dim_head * sizeof(T),
-                    cudaMemcpyDeviceToDevice);
+    if (qkv_bias != nullptr) {
+      fmha_ref_compute.ComputeBackward(
+          *transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out,
+          *attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
+          d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out,
+          d_transpose_out_2, nullptr, d_qkv_bias_out);
+    } else {
+      fmha_ref_compute.ComputeBackward(
+          *transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out,
+          *attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
+          d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out,
+          d_transpose_out_2, nullptr, d_qkv_out);
+    }
 
     if (pre_layer_norm) {
       auto *ln_mean = ctx.Input<Tensor>("LnMean");
@@ -432,15 +468,24 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
       auto *d_ln_bias_data =
           (d_ln_bias == nullptr ? nullptr
                                 : d_ln_bias->mutable_data<U>(ctx.GetPlace()));
-
-      qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_bias_out, d_ln_out,
-                                  d_qkv_weight, d_qkv_bias);
+      if (qkv_bias != nullptr) {
+        qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_bias_out,
+                                    d_ln_out, d_qkv_weight, d_qkv_bias);
+      } else {
+        qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_out, d_ln_out,
+                                    d_qkv_weight, d_qkv_bias);
+      }
       layer_norm_compute.ComputeBackward(x_data, d_ln_out_data, ln_scale_data,
                                          ln_mean_data, ln_var_data, d_x_data,
                                          d_ln_scale_data, d_ln_bias_data);
     } else {
-      qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_bias_out, d_x,
-                                  d_qkv_weight, d_qkv_bias);
+      if (qkv_bias != nullptr) {
+        qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_bias_out, d_x,
+                                    d_qkv_weight, d_qkv_bias);
+      } else {
+        qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_out, d_x,
+                                    d_qkv_weight, d_qkv_bias);
+      }
     }
     // gradient accumulation
     std::vector<const Tensor *> ins;
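
The change above treats the QKV and out-linear biases as optional inputs: every bias pointer access is guarded by a null check, and the FMHA and matmul paths read `qkv_out` instead of `qkv_bias_out` when no bias tensor is provided. A minimal standalone sketch of that dispatch pattern, using a hypothetical `Tensor` stand-in rather than Paddle's actual types, is:

```cpp
#include <iostream>
#include <vector>

// Hypothetical stand-in for the kernel's tensors; not Paddle's Tensor API.
struct Tensor {
  std::vector<float> buf;
  const float *data() const { return buf.data(); }
};

// Mirrors the patched kernel's pattern: when `bias` is absent, downstream
// consumers read from `out` instead of the fused `bias_out` buffer.
const Tensor &SelectFmhaInput(const Tensor *bias, const Tensor &out,
                              const Tensor &bias_out) {
  return (bias == nullptr) ? out : bias_out;
}

int main() {
  Tensor qkv_out{{1.f, 2.f}}, qkv_bias_out{{3.f, 4.f}};
  const Tensor *qkv_bias = nullptr;  // bias input omitted

  // Guarded pointer access, as in the diff.
  const float *qkv_bias_data =
      (qkv_bias == nullptr) ? nullptr : qkv_bias->data();
  const Tensor &fmha_in = SelectFmhaInput(qkv_bias, qkv_out, qkv_bias_out);

  // Prints "1 1": bias pointer is null and the unbiased buffer is selected.
  std::cout << (qkv_bias_data == nullptr) << " " << fmha_in.data()[0] << "\n";
  return 0;
}
```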