Skip to content

Commit 8864579

Browse files
authored
Merge branch 'vllm-project:main' into apo
2 parents 07f01ba + 57b4e68 commit 8864579

1 file changed

Lines changed: 2 additions & 5 deletions

File tree

hopper/flash_fwd_combine_kernel.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,6 @@ class FlashAttnFwdCombine {
292292

293293
switch (choose_scheduling_algo(args)) {
294294
case SchedulingAlgo::STANDARD: {
295-
unsigned int num_blocks_k = cute::ceil_div(args.dv, kBlockK);
296295
unsigned int num_blocks_m = cute::ceil_div(args.seqlen_q * args.num_heads, kBlockM);
297296
return {num_blocks_m, num_blocks_k, static_cast<unsigned int>(args.b)};
298297
}
@@ -426,15 +425,13 @@ class FlashAttnFwdCombine {
426425
*params.semaphore_to_reset = 0;
427426
}
428427

428+
if (batch >= params.b) { return; }
429429
flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused};
430430
int const offset = seqlen_info.offset;
431431
int const seqlen = seqlen_info.seqlen;
432432
int max_idx = seqlen * get<2>(params.shape_LSE_partial);
433433

434-
bool block_coord_valid =
435-
block_coord.block_m < cute::ceil_div(max_idx, Int<kBlockM>{}) &&
436-
block_coord.bidb < params.b;
437-
if (!block_coord_valid) { return; }
434+
if (m_block >= cute::ceil_div(max_idx, Int<kBlockM>{})) { return; }
438435

439436
int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial);
440437
if (num_splits <= 1) { return; }

0 commit comments

Comments
 (0)