Skip to content

Commit dbbd4c3

Browse files
committed
add support for hopper xqa skip softmax kernel
Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
1 parent 5d8eaed commit dbbd4c3

File tree

9 files changed

+81
-13
lines changed

9 files changed

+81
-13
lines changed

cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,11 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
298298
xqaParams.use_sparse_attention = useTllmGenSparseAttention();
299299
// Skip softmax threshold.
300300
xqaParams.skip_softmax_threshold_scale_factor = mSkipSoftmaxThresholdScaleFactorDecode;
301+
#ifdef SKIP_SOFTMAX_STAT
302+
// Statistics of skip-softmax, pointers of device memory for output
303+
xqaParams.skip_softmax_total_blocks = mSkipSoftmaxTotalBlocks;
304+
xqaParams.skip_softmax_skipped_blocks = mSkipSoftmaxSkippedBlocks;
305+
#endif
301306
// Cross attention parameters.
302307
xqaParams.encoder_input_lengths = generationsParams.encoder_input_lengths;
303308

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/compileEngine.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,8 @@ CubinObj CompileEngine::compile() const
105105
// scratch in this case.
106106
/*use_input_kv=*/applyRoPEInXqaKernel,
107107
/*rope_style=*/ropeStyle,
108-
/*is_spec_dec_tree=*/mXqaParams.is_spec_dec_tree};
108+
/*is_spec_dec_tree=*/mXqaParams.is_spec_dec_tree,
109+
/*use_skip_softmax_attn=*/mXqaParams.skip_softmax_threshold_scale_factor != 0};
109110
if (context.kernel_type == TLLM_XQA_JIT_MLA)
110111
{
111112
auto const& c = context;

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
232232
jit::CubinObj const* const cubinObj = mResource->getCubinObjRegistry()->getCubin(key);
233233
TLLM_CHECK(cubinObj != nullptr && cubinObj->isInitialized());
234234
bool const isSpecDec = xqaParams.multi_query_tokens;
235+
bool const isSkipSoftmax = xqaParams.skip_softmax_threshold_scale_factor != 0;
235236
bool const isHMMAKernel = (cubinObj->getKernelType() == XQAKernelType::kAMPERE_WARP_SPECIALIZED);
236237
bool const isGMMAKernel = (cubinObj->getKernelType() == XQAKernelType::kHOPPER_WARP_SPECIALIZED);
237238
bool const isMLAKernel = (cubinObj->getKernelType() == XQAKernelType::kSM120_MLA);
@@ -514,6 +515,15 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
514515
appendParam(&specDecParams);
515516
specDecBlocks = divUp(specDecParams.qSeqLen, 64 / num_q_heads_over_kv);
516517
}
518+
if (isSkipSoftmax)
519+
{
520+
TLLM_CHECK_WITH_INFO(isGMMAKernel, "skip softmax is only supported for GMMA kernel in JIT path for now.");
521+
appendParam(&xqaParams.skip_softmax_threshold_scale_factor);
522+
#ifdef SKIP_SOFTMAX_STAT
523+
appendParam(&xqaParams.skip_softmax_total_blocks);
524+
appendParam(&xqaParams.skip_softmax_skipped_blocks);
525+
#endif
526+
}
517527
appendParam(&launchParams.semaphores);
518528
appendParam(&launchParams.scratch);
519529
kernelParams[idxNextParam] = nullptr; // one extra nullptr at end as guard.

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/kernelUtils.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,16 @@ bool supportConfigQGMMA(XQAParams const& xqaParams, int SM, bool forConfigurePlu
9696
{
9797
return false;
9898
}
99-
if (xqaParams.kv_cache_data_type != DATA_TYPE_E4M3)
99+
if (!contains({DATA_TYPE_FP16, DATA_TYPE_BF16, DATA_TYPE_E4M3}, xqaParams.kv_cache_data_type))
100100
{
101101
return false;
102102
}
103+
bool const is_skip_softmax = xqaParams.skip_softmax_threshold_scale_factor != 0;
104+
if (!is_skip_softmax && xqaParams.kv_cache_data_type != DATA_TYPE_E4M3)
105+
{
106+
// Only use hopper kernel with fp16/bf16 kv cache data type when skip softmax is enabled
107+
return false;
108+
}
103109
if (xqaParams.beam_width != 1)
104110
{
105111
return false;
@@ -168,6 +174,11 @@ bool supportConfigHMMA(XQAParams const& xqaParams, int SM, bool forConfigurePlug
168174
{
169175
return false;
170176
}
177+
bool const is_skip_softmax = xqaParams.skip_softmax_threshold_scale_factor != 0;
178+
if (is_skip_softmax)
179+
{
180+
return false;
181+
}
171182
return true;
172183
}
173184

@@ -201,6 +212,11 @@ bool supportConfigMLA(XQAParams const& xqaParams, int SM, bool forConfigurePlugi
201212
{
202213
return false;
203214
}
215+
bool const is_skip_softmax = xqaParams.skip_softmax_threshold_scale_factor != 0;
216+
if (is_skip_softmax)
217+
{
218+
return false;
219+
}
204220
return true;
205221
}
206222

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/include/nvrtcWrapper.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ extern "C"
6666

6767
bool is_spec_dec_tree
6868
= true; // useful only when multi_query_tokens, should be true unless using linear tree in spec-dec.
69+
bool use_skip_softmax_attn;
6970
} tllmXqaJitContext;
7071

7172
// tllmXqaJitProgram is an opaque handle for a program.

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/src/nvrtcWrapper.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,10 @@ tllmXqaJitStatus getMacroFlags(tllmXqaJitContext const* context, std::vector<std
215215
macros["USE_INPUT_KV"] = context->use_input_kv ? "1" : "0";
216216
macros["ROPE_STYLE"] = std::to_string(int(context->rope_style));
217217
macros["IS_SPEC_DEC_TREE"] = context->is_spec_dec_tree ? "1" : "0";
218+
macros["SKIP_SOFTMAX_ATTN"] = context->use_skip_softmax_attn ? "1" : "0";
219+
#ifdef SKIP_SOFTMAX_STAT
220+
macros["SKIP_SOFTMAX_ATTN_BLOCK_STATS"] = context->use_skip_softmax_attn ? "1" : "0";
221+
#endif
218222

219223
// Without these macros, NVRTC uses precompiled headers for cuda_fp16.h etc.
220224
// Linking might fail due to ABI incompatibility.

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,10 @@ bool DecoderXQAImplPrecompiled::shouldUse(XQAParams const& xqaParams, bool forCo
493493
{
494494
SUPPORT_RETURN_FALSE("streaming-llm");
495495
}
496+
if (xqaParams.skip_softmax_threshold_scale_factor != 0)
497+
{
498+
SUPPORT_RETURN_FALSE("skip_softmax_threshold_scale_factor");
499+
}
496500

497501
// OPTIMIZE: For the standard generation-phase MHA, there are still extra limitations.
498502
// NOTE: Medusa mode = Multi_query_tokens > 1.

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/tensorMapUtils.cpp

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,21 @@ CUtensorMapSwizzle getSwizzleMode(uint32_t partBytes)
6464
}
6565
};
6666

67+
CUtensorMapDataType_enum getDataTypeFromXqaParams(XQAParams const& xqaParams)
68+
{
69+
if (xqaParams.kv_cache_data_type == DATA_TYPE_BF16)
70+
{
71+
return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
72+
}
73+
else if (xqaParams.kv_cache_data_type == DATA_TYPE_FP16)
74+
{
75+
return CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
76+
}
77+
TLLM_CHECK(xqaParams.kv_cache_data_type == DATA_TYPE_E4M3 || xqaParams.kv_cache_data_type == DATA_TYPE_E5M2
78+
|| xqaParams.kv_cache_data_type == DATA_TYPE_INT8);
79+
return CU_TENSOR_MAP_DATA_TYPE_UINT8;
80+
}
81+
6782
CUtensorMap makeTensorMapForQ(std::shared_ptr<CUDADriverWrapper> const& driver, void const* addr,
6883
CUtensorMapDataType_enum dataType, uint32_t headElems, uint32_t totalNbHeads, uint32_t partElems, uint32_t boxHeads)
6984
{
@@ -131,24 +146,26 @@ CUtensorMap makeTensorMapForHopperXqaKVCache(
131146
if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
132147
{
133148
uint32_t const headElems = xqaParams.head_size;
134-
uint32_t const elemBytes = getElemBytes(CU_TENSOR_MAP_DATA_TYPE_UINT8);
149+
CUtensorMapDataType_enum const dataType = getDataTypeFromXqaParams(xqaParams);
150+
uint32_t const elemBytes = getElemBytes(dataType);
135151
TLLM_CHECK(headElems <= 256);
136152
uint32_t const paddedHeadElems = headElems <= 64 ? 64 : (headElems <= 128 ? 128 : 256);
137153
uint32_t const partElems = std::min(elemBytes * paddedHeadElems, 128U) / elemBytes;
138-
return makeTensorMapForPagedKVCache(driver, kv_cache_buffer.mPrimaryPoolPtr, CU_TENSOR_MAP_DATA_TYPE_UINT8,
139-
xqaParams.head_size, xqaParams.num_kv_heads, xqaParams.tokens_per_block, partElems);
154+
return makeTensorMapForPagedKVCache(driver, kv_cache_buffer.mPrimaryPoolPtr, dataType, xqaParams.head_size,
155+
xqaParams.num_kv_heads, xqaParams.tokens_per_block, partElems);
140156
}
141157
else
142158
{
143159
static_assert(std::is_same_v<KVCacheBuffer, KVLinearBuffer>);
144160
uint32_t const headElems = xqaParams.head_size;
145-
uint32_t const elemBytes = getElemBytes(CU_TENSOR_MAP_DATA_TYPE_UINT8);
161+
CUtensorMapDataType_enum const dataType = getDataTypeFromXqaParams(xqaParams);
162+
uint32_t const elemBytes = getElemBytes(dataType);
146163
TLLM_CHECK(headElems <= 256);
147164
uint32_t const paddedHeadElems = headElems <= 64 ? 64 : (headElems <= 128 ? 128 : 256);
148165
uint32_t const partElems = std::min(elemBytes * paddedHeadElems, 128U) / elemBytes;
149-
return makeTensorMapForContiguousKVCache(driver, kv_cache_buffer.data, CU_TENSOR_MAP_DATA_TYPE_UINT8,
150-
xqaParams.head_size, xqaParams.num_kv_heads, xqaParams.max_attention_window_size, xqaParams.beam_width,
151-
xqaParams.batch_size, partElems);
166+
return makeTensorMapForContiguousKVCache(driver, kv_cache_buffer.data, dataType, xqaParams.head_size,
167+
xqaParams.num_kv_heads, xqaParams.max_attention_window_size, xqaParams.beam_width, xqaParams.batch_size,
168+
partElems);
152169
}
153170
}
154171

@@ -161,11 +178,12 @@ template <typename KVCacheBuffer>
161178
CUtensorMap makeTensorMapForXqaMlaKVCache(std::shared_ptr<tensorrt_llm::common::CUDADriverWrapper> const& driver,
162179
XQAParams const& xqaParams, KVCacheBuffer const& kv_cache_buffer, bool forK)
163180
{
181+
CUtensorMapDataType_enum const dataType = getDataTypeFromXqaParams(xqaParams);
164182
uint32_t const partElems = (forK ? 64 : 128);
165183
if constexpr (std::is_same_v<KVCacheBuffer, KVBlockArray>)
166184
{
167-
return makeTensorMapForPagedKVCache(driver, kv_cache_buffer.mPrimaryPoolPtr, CU_TENSOR_MAP_DATA_TYPE_UINT8,
168-
xqaParams.head_size, xqaParams.num_kv_heads, xqaParams.tokens_per_block, partElems);
185+
return makeTensorMapForPagedKVCache(driver, kv_cache_buffer.mPrimaryPoolPtr, dataType, xqaParams.head_size,
186+
xqaParams.num_kv_heads, xqaParams.tokens_per_block, partElems);
169187
}
170188
else
171189
{
@@ -183,7 +201,7 @@ CUtensorMap makeTensorMapForXqaMlaQ(
183201
std::shared_ptr<tensorrt_llm::common::CUDADriverWrapper> const& driver, XQAParams const& xqaParams, void const* q)
184202
{
185203
uint32_t const partElems = 64;
186-
return makeTensorMapForQ(driver, q, CU_TENSOR_MAP_DATA_TYPE_UINT8, xqaParams.head_size,
204+
return makeTensorMapForQ(driver, q, getDataTypeFromXqaParams(xqaParams), xqaParams.head_size,
187205
xqaParams.num_q_heads * xqaParams.total_num_input_tokens, partElems, xqaParams.num_q_heads);
188206
}
189207
} // namespace kernels

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,12 @@ struct XQAParams
119119
bool use_sparse_attention = false;
120120

121121
// Skip softmax threshold.
122-
float skip_softmax_threshold_scale_factor = 0.0f;
122+
float skip_softmax_threshold_scale_factor = 0;
123+
124+
#ifdef SKIP_SOFTMAX_STAT
125+
uint32_t* skip_softmax_total_blocks = nullptr;
126+
uint32_t* skip_softmax_skipped_blocks = nullptr;
127+
#endif
123128

124129
cudaStream_t stream = 0;
125130
// layer index
@@ -199,6 +204,10 @@ struct XQAParams
199204
<< "sparse_params: " << sparse_params.toString() << std::endl
200205
<< "use_sparse_attention :" << (use_sparse_attention ? "true" : "false") << std ::endl
201206
<< "skip_softmax_threshold_scale_factor :" << skip_softmax_threshold_scale_factor << std ::endl
207+
#ifdef SKIP_SOFTMAX_STAT
208+
<< "skip_softmax_total_blocks :" << skip_softmax_total_blocks << std ::endl
209+
<< "skip_softmax_skipped_blocks :" << skip_softmax_skipped_blocks << std ::endl
210+
#endif
202211
<< "stream :" << stream;
203212

204213
return ss.str();

0 commit comments

Comments
 (0)