Skip to content

Commit 0fa5933

Browse files
Revert num_splits in flash_bwd_kernel.h for large model (#21)
1 parent b74460b commit 0fa5933

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

csrc/flash_attn/src/flash_bwd_kernel.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1585,15 +1585,15 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
15851585
// The block index for the head.
15861586
const int bidh = blockIdx.z;
15871587
constexpr int kBlockN = Kernel_traits::kBlockN;
1588-
if (params.num_splits == 1) { // means grid.x = 1, blockIdx.x = 0;
1589-
int loop_step_x = 0;
1590-
for(int i = 0; i < params.seqlen_k; i+= kBlockN) {
1591-
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false, false, Is_attn_mask, /*Seq_parallel=*/true>(params, bidb, bidh, loop_step_x);
1592-
loop_step_x += 1;
1593-
}
1594-
} else {
1595-
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false, false, Is_attn_mask, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
1596-
}
1588+
//if (params.num_splits == 1) { // means grid.x = 1, blockIdx.x = 0;
1589+
// int loop_step_x = 0;
1590+
// for(int i = 0; i < params.seqlen_k; i+= kBlockN) {
1591+
// compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false, false, Is_attn_mask, /*Seq_parallel=*/true>(params, bidb, bidh, loop_step_x);
1592+
// loop_step_x += 1;
1593+
// }
1594+
//} else {
1595+
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false, false, Is_attn_mask, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
1596+
//}
15971597
}
15981598

15991599
////////////////////////////////////////////////////////////////////////////////////////////////////

0 commit comments

Comments
 (0)