@@ -1585,15 +1585,15 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
15851585 // The block index for the head.
15861586 const int bidh = blockIdx.z ;
15871587 constexpr int kBlockN = Kernel_traits::kBlockN ;
1588- if (params.num_splits == 1 ) { // means grid.x = 1, blockIdx.x = 0;
1589- int loop_step_x = 0 ;
1590- for (int i = 0 ; i < params.seqlen_k ; i+= kBlockN ) {
1591- compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false , false , Is_attn_mask, /* Seq_parallel=*/ true >(params, bidb, bidh, loop_step_x);
1592- loop_step_x += 1 ;
1593- }
1594- } else {
1595- compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false , false , Is_attn_mask, /* Seq_parallel=*/ true >(params, bidb, bidh, n_block);
1596- }
1588+ // if (params.num_splits == 1) { // means grid.x = 1, blockIdx.x = 0;
1589+ // int loop_step_x = 0;
1590+ // for(int i = 0; i < params.seqlen_k; i+= kBlockN) {
1591+ // compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false, false, Is_attn_mask, /*Seq_parallel=*/true>(params, bidb, bidh, loop_step_x);
1592+ // loop_step_x += 1;
1593+ // }
1594+ // } else {
1595+ compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false , false , Is_attn_mask, /* Seq_parallel=*/ true >(params, bidb, bidh, n_block);
1596+ // }
15971597}
15981598
15991599// //////////////////////////////////////////////////////////////////////////////////////////////////
0 commit comments