Skip to content

Commit c0e25e5

Browse files
[TRTLLM-10022][feat] Add hopper xqa decode support for skip softmax attention (#10264)
Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
1 parent c5d5af9 commit c0e25e5

File tree

18 files changed

+643
-217
lines changed

18 files changed

+643
-217
lines changed

cpp/kernels/xqa/defines.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,18 @@ static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN should only be used when SPEC_DEC is ena
129129
#define SLIDING_WINDOW 0
130130
#endif
131131

132+
#ifndef SKIP_SOFTMAX_ATTN
133+
#define SKIP_SOFTMAX_ATTN 0
134+
#endif
135+
136+
#ifndef SKIP_SOFTMAX_ATTN_BLOCK_STATS
137+
#define SKIP_SOFTMAX_ATTN_BLOCK_STATS 0
138+
#endif
139+
140+
#ifndef SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE
141+
#define SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE 1
142+
#endif
143+
132144
// 0 - no PDL
133145
// 1 - naive PDL
134146
// 2 - aggressive PDL (implemented only in mha_sm90.cu for now)

cpp/kernels/xqa/gmma.cuh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ __device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset,
106106
asm volatile("trap;\n");
107107
return 0;
108108
}();
109+
assert(__cvta_generic_to_shared(data) % baseAlign == 0);
109110
uint32_t const baseOffset = ((patternAddr % baseAlign == 0) ? 0U : ((patternAddr >> 0x7) & 0x7));
110111
return MatDesc{
111112
/*addr=*/MatDesc::encode(__cvta_generic_to_shared(data)),

cpp/kernels/xqa/mha.cu

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2734,6 +2734,25 @@ static constexpr auto kernel_mha = kernel_mha_impl;
27342734
#endif
27352735

27362736
#ifndef GENERATE_CUBIN
2737+
// Decide how many sub-sequences (CTAs along the sequence dimension) each
// sequence is split into for multi-block decode.
//
// Returns 1 when multi-block mode is compiled out. Otherwise the environment
// variable XQA_NB_SUB_SEQ, when set to a positive integer, overrides the
// heuristic. The heuristic spreads the SMs evenly over the (batch, K-head)
// work units, capped by the number of CTA tiles the longest sequence spans.
//
// NOTE(review): std::stoi throws on a non-numeric XQA_NB_SUB_SEQ value;
// that pre-existing behavior is deliberately preserved here.
uint32_t computeNbSubSeqPerSeqMHA(cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen)
{
    if (!allowMultiBlockMode)
    {
        return 1;
    }
    // Manual override via environment variable takes precedence over the heuristic.
    char const* const overrideStr = std::getenv("XQA_NB_SUB_SEQ");
    if (overrideStr != nullptr)
    {
        int32_t const requested = std::stoi(overrideStr);
        if (requested > 0)
        {
            return static_cast<uint32_t>(requested);
        }
    }
    // Heuristic: SMs available per sequence, never below 1 and never more
    // sub-sequences than there are CTA tiles covering the longest sequence.
    uint32_t const smPerSeq = std::max<uint32_t>(1U, prop.multiProcessorCount / (batchSize * nbKHeads));
    uint32_t const maxTiles = divUp(maxSeqLen, ctaTile.x);
    return std::min<uint32_t>(smPerSeq, maxTiles);
}
2755+
27372756
void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
27382757
#if SLIDING_WINDOW
27392758
uint32_t slidingWinSize,
@@ -2771,6 +2790,13 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
27712790
// int8/fp8 KV cache.
27722791
#if SPEC_DEC
27732792
SpecDecParams const& specDecParams,
2793+
#endif
2794+
#if SKIP_SOFTMAX_ATTN
2795+
float const skipSoftmaxThresholdScaleFactor, // for compatibility with mha_sm90.cu only
2796+
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
2797+
uint32_t* __restrict__ skippedBlockCount, // for compatibility with mha_sm90.cu only
2798+
uint32_t* __restrict__ totalBlockCount, // for compatibility with mha_sm90.cu only
2799+
#endif
27742800
#endif
27752801
uint32_t* semaphores, void* scratch, cudaStream_t stream)
27762802
{
@@ -2793,24 +2819,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
27932819
uint32_t const nbQHeads = nbKHeads * headGrpSize;
27942820

27952821
// const uint32_t nbSubSeqPerSeq = allowMultiBlockMode ? DBG_NB_CTAS_PER_SEQ : 1;
2796-
uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t
2797-
{
2798-
if (!allowMultiBlockMode)
2799-
{
2800-
return 1;
2801-
}
2802-
auto const env = std::getenv("XQA_NB_SUB_SEQ");
2803-
if (env != nullptr)
2804-
{
2805-
int32_t const val = std::stoi(env);
2806-
if (val > 0)
2807-
{
2808-
return val;
2809-
}
2810-
}
2811-
return std::min<uint32_t>(
2812-
std::max<uint32_t>(1U, prop.multiProcessorCount / (batchSize * nbKHeads)), divUp(maxSeqLen, ctaTile.x));
2813-
}();
2822+
uint32_t const nbSubSeqPerSeq = computeNbSubSeqPerSeqMHA(prop, batchSize, nbKHeads, maxSeqLen);
28142823
// gridDim.z == batchSize && gridDim.y == nbKHeads && gridDim.x == nbSubSeqPerSeq
28152824
#if SPEC_DEC
28162825
const uint32_t nbTokenBlocksPerGrp = divUp(qSeqLen * headGrpSize, rowsPerBlock);

cpp/kernels/xqa/mha.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ struct BeamSearchParams
9090
// match trt-llm API.
9191
};
9292

93+
uint32_t computeNbSubSeqPerSeqMHA(
94+
cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen);
95+
9396
void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
9497
#if SLIDING_WINDOW
9598
uint32_t slidingWinSize,
@@ -127,9 +130,18 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
127130
// int8/fp8 KV cache.
128131
#if SPEC_DEC
129132
SpecDecParams const& specDecParams,
133+
#endif
134+
#if SKIP_SOFTMAX_ATTN
135+
float const skipSoftmaxThresholdScaleFactor,
136+
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
137+
uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
138+
#endif
130139
#endif
131140
uint32_t* semaphores, void* scratch, cudaStream_t stream);
132141

142+
uint32_t computeNbSubSeqPerSeqHopperF8MHA(
143+
cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen);
144+
133145
void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
134146
#if SLIDING_WINDOW
135147
uint32_t slidingWinSize,
@@ -167,6 +179,12 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
167179
// int8/fp8 KV cache.
168180
#if SPEC_DEC
169181
SpecDecParams const& specDecParams,
182+
#endif
183+
#if SKIP_SOFTMAX_ATTN
184+
float const skipSoftmaxThresholdScaleFactor,
185+
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
186+
uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
187+
#endif
170188
#endif
171189
uint32_t* semaphores, void* scratch, cudaStream_t stream);
172190

0 commit comments

Comments
 (0)