Skip to content

Commit 2202e65

Browse files
committed
fix spelling, kernel param number error and cleanup
Signed-off-by: Pengbo Wang <[email protected]>
1 parent dbbd4c3 commit 2202e65

4 files changed

Lines changed: 11 additions & 15 deletions

File tree

cpp/kernels/xqa/mha.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2794,8 +2794,8 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
27942794
#if SKIP_SOFTMAX_ATTN
27952795
float const skipSoftmaxThresholdScaleFactor, // for compatibility with mha_sm90.cu only
27962796
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
2797-
uint32_t* __restrict__ skipped_block_count, // for compatibility with mha_sm90.cu only
2798-
uint32_t* __restrict__ total_block_count, // for compatibility with mha_sm90.cu only
2797+
uint32_t* __restrict__ skippedBlockCount, // for compatibility with mha_sm90.cu only
2798+
uint32_t* __restrict__ totalBlockCount, // for compatibility with mha_sm90.cu only
27992799
#endif
28002800
#endif
28012801
uint32_t* semaphores, void* scratch, cudaStream_t stream)

cpp/kernels/xqa/mha_sm90.cu

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -868,7 +868,7 @@ CUBIN_EXPORT __global__
868868
#endif
869869

870870
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
871-
uint32_t local_skipped_block_count = 0;
871+
uint32_t localSkippedBlockCount = 0;
872872
#endif
873873

874874
// QK gemm
@@ -1014,7 +1014,7 @@ CUBIN_EXPORT __global__
10141014
{
10151015
smem.skipSoftmaxVotesGemm0ToGemm1[idxXBuf] = 1U;
10161016
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
1017-
local_skipped_block_count++;
1017+
localSkippedBlockCount++;
10181018
#endif
10191019
}
10201020
asm volatile("fence.proxy.async.shared::cta;\n"); // maybe not used
@@ -1081,9 +1081,9 @@ CUBIN_EXPORT __global__
10811081
unused(xBar.produced.arrive());
10821082
}
10831083
#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
1084-
if (threadIdx.x == 0 && skipped_block_count != nullptr && total_block_count != nullptr)
1084+
if (threadIdx.x == 0 && skippedBlockCount != nullptr && totalBlockCount != nullptr)
10851085
{
1086-
atomicAdd(skippedBlockCount, local_skipped_block_count);
1086+
atomicAdd(skippedBlockCount, localSkippedBlockCount);
10871087
atomicAdd(totalBlockCount, nbIters);
10881088
}
10891089
#endif
@@ -1670,7 +1670,6 @@ CUBIN_EXPORT __global__
16701670
{
16711671
return;
16721672
}
1673-
// todo: skip_softmax_attn: fix multiblockmode
16741673
bool& smemIsLastCta = smem.isLastCta;
16751674
if (threadIdx.x == gemm1NbThrds - 1U && threadIdx.z == 0)
16761675
{
@@ -3486,7 +3485,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
34863485
#if SKIP_SOFTMAX_ATTN
34873486
float const skipSoftmaxThresholdScaleFactor,
34883487
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
3489-
uint32_t* __restrict__ skipped_block_count, uint32_t* __restrict__ total_block_count,
3488+
uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
34903489
#endif
34913490
#endif
34923491
uint32_t* semaphores, void* scratch, cudaStream_t stream)
@@ -3515,8 +3514,6 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
35153514
// gridDim.z == nbKHeads * batchSize && gridDim.y == nbSubSeqPerSeq && gridDim.x == nbInputSeqSplit
35163515
dim3 const dimGrid{divUp(qSeqLen, inputTokensPerCta), nbSubSeqPerSeq, nbKHeads * batchSize};
35173516
dim3 const dimCta{warp_size * gmmaWarpsPerGrp, 1, 3};
3518-
// printf("dimGrid: %d, %d, %d\n", dimGrid.x, dimGrid.y, dimGrid.z);
3519-
// printf("dimCta: %d, %d, %d\n", dimCta.x, dimCta.y, dimCta.z);
35203517
auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0);
35213518
#if USE_PAGED_KV_CACHE
35223519
uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage);
@@ -3582,7 +3579,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
35823579
#if SKIP_SOFTMAX_ATTN
35833580
skipSoftmaxThresholdScaleFactor,
35843581
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
3585-
skipped_block_count, total_block_count,
3582+
skippedBlockCount, totalBlockCount,
35863583
#endif
35873584
#endif
35883585
semaphores, scratch);

cpp/kernels/xqa/test/refAttention.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
#include "refAttention.h"
1919
#include <cstdint>
20-
#include <cstdio>
2120

2221
template <typename T>
2322
Vec<float, validElemsPerHead> toF32Head(Vec<T, validElemsPerHead> const& src)
@@ -65,7 +64,6 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
6564
uint32_t const idxTileBeg = seqBeg / tileSize;
6665

6766
uint32_t const nbSubSeq = (multiBlockNum > 0 && nbTiles >= 2) ? mha::min(nbTiles, multiBlockNum) : 1;
68-
// uint32_t const nbSubSeq = 1;
6967
std::vector<Eigen::Vector<float, headGrpSize>> skipRowMaxs(nbSubSeq);
7068
for (uint32_t i = 0; i < nbSubSeq; i++)
7169
{

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
379379
.mask = reinterpret_cast<SpecDecParams::MaskType const*>(xqaParams.spec_decoding_packed_mask)};
380380
};
381381

382-
constexpr uint32_t kMAX_NB_KERNEL_PARAMS = 16;
382+
constexpr uint32_t kMAX_NB_KERNEL_PARAMS = 19;
383383
uint32_t idxNextParam = 0;
384384
void* kernelParams[kMAX_NB_KERNEL_PARAMS];
385385
auto appendParam = [&](auto* p) mutable
@@ -517,7 +517,8 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
517517
}
518518
if (isSkipSoftmax)
519519
{
520-
TLLM_CHECK_WITH_INFO(isGMMAKernel, "skip softmax is only supported for GMMA kernel in JIT path for now.");
520+
TLLM_CHECK_WITH_INFO(isGMMAKernel, "skip softmax is only supported for GMMA kernel for now.");
521+
TLLM_CHECK_WITH_INFO(!isSpecDec, "skip softmax is not supported with spec dec for now.");
521522
appendParam(&xqaParams.skip_softmax_threshold_scale_factor);
522523
#ifdef SKIP_SOFTMAX_STAT
523524
appendParam(&xqaParams.skip_softmax_total_blocks);

0 commit comments

Comments (0)