@@ -868,7 +868,7 @@ CUBIN_EXPORT __global__
868868#endif
869869
870870#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
871- uint32_t local_skipped_block_count = 0 ;
871+ uint32_t localSkippedBlockCount = 0 ;
872872#endif
873873
874874 // QK gemm
@@ -1014,7 +1014,7 @@ CUBIN_EXPORT __global__
10141014 {
10151015 smem.skipSoftmaxVotesGemm0ToGemm1 [idxXBuf] = 1U ;
10161016#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
1017- local_skipped_block_count ++;
1017+ localSkippedBlockCount ++;
10181018#endif
10191019 }
10201020 asm volatile (" fence.proxy.async.shared::cta;\n " ); // maybe not used
@@ -1081,9 +1081,9 @@ CUBIN_EXPORT __global__
10811081 unused (xBar.produced .arrive ());
10821082 }
10831083#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
1084- if (threadIdx .x == 0 && skipped_block_count != nullptr && total_block_count != nullptr )
1084+ if (threadIdx .x == 0 && skippedBlockCount != nullptr && totalBlockCount != nullptr )
10851085 {
1086- atomicAdd (skippedBlockCount, local_skipped_block_count );
1086+ atomicAdd (skippedBlockCount, localSkippedBlockCount );
10871087 atomicAdd (totalBlockCount, nbIters);
10881088 }
10891089#endif
@@ -1670,7 +1670,6 @@ CUBIN_EXPORT __global__
16701670 {
16711671 return ;
16721672 }
1673- // todo: skip_softmax_attn: fix multiblockmode
16741673 bool & smemIsLastCta = smem.isLastCta ;
16751674 if (threadIdx .x == gemm1NbThrds - 1U && threadIdx .z == 0 )
16761675 {
@@ -3486,7 +3485,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
34863485#if SKIP_SOFTMAX_ATTN
34873486 float const skipSoftmaxThresholdScaleFactor,
34883487#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
3489- uint32_t * __restrict__ skipped_block_count , uint32_t * __restrict__ total_block_count ,
3488+ uint32_t * __restrict__ skippedBlockCount , uint32_t * __restrict__ totalBlockCount ,
34903489#endif
34913490#endif
34923491 uint32_t * semaphores, void * scratch, cudaStream_t stream)
@@ -3515,8 +3514,6 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
35153514 // gridDim.z == nbKHeads * batchSize && gridDim.y == nbSubSeqPerSeq && gridDim.x == nbInputSeqSplit
35163515 dim3 const dimGrid{divUp (qSeqLen, inputTokensPerCta), nbSubSeqPerSeq, nbKHeads * batchSize};
35173516 dim3 const dimCta{warp_size * gmmaWarpsPerGrp, 1 , 3 };
3518- // printf("dimGrid: %d, %d, %d\n", dimGrid.x, dimGrid.y, dimGrid.z);
3519- // printf("dimCta: %d, %d, %d\n", dimCta.x, dimCta.y, dimCta.z);
35203517 auto const launchCfg = makeLaunchConfig (dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0 );
35213518#if USE_PAGED_KV_CACHE
35223519 uint32_t const maxNbPagesPerSeq = exactDiv (maxSeqLen, tokensPerPage);
@@ -3582,7 +3579,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
35823579#if SKIP_SOFTMAX_ATTN
35833580 skipSoftmaxThresholdScaleFactor,
35843581#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
3585- skipped_block_count, total_block_count ,
3582+ skippedBlockCount, totalBlockCount ,
35863583#endif
35873584#endif
35883585 semaphores, scratch);
0 commit comments