Skip to content
Merged
12 changes: 12 additions & 0 deletions cpp/kernels/xqa/defines.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,18 @@ static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN should only be used when SPEC_DEC is ena
#define SLIDING_WINDOW 0
#endif

// Compile-time switch for the skip-softmax attention optimization (off by default).
// When enabled, launchMHA/launchHopperF8MHA take an extra skipSoftmaxThresholdScaleFactor
// parameter (see mha.h); blocks whose contribution falls below the threshold are skipped.
#ifndef SKIP_SOFTMAX_ATTN
#define SKIP_SOFTMAX_ATTN 0
#endif

// When enabled (together with SKIP_SOFTMAX_ATTN), the kernels additionally report
// skipped/total block counters via the skippedBlockCount/totalBlockCount pointers
// declared in mha.h. Off by default.
#ifndef SKIP_SOFTMAX_ATTN_BLOCK_STATS
#define SKIP_SOFTMAX_ATTN_BLOCK_STATS 0
#endif

// NOTE(review): presumably a bug-fix toggle clamping/correcting threshold scale factors
// greater than 1 — enabled by default; exact semantics live in the kernel sources,
// confirm there before changing.
#ifndef SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE
#define SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE 1
#endif

// 0 - no PDL
// 1 - naive PDL
// 2 - aggressive PDL (implemented only in mha_sm90.cu for now)
Expand Down
1 change: 1 addition & 0 deletions cpp/kernels/xqa/gmma.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ __device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset,
asm volatile("trap;\n");
return 0;
}();
assert(__cvta_generic_to_shared(data) % baseAlign == 0);
uint32_t const baseOffset = ((patternAddr % baseAlign == 0) ? 0U : ((patternAddr >> 0x7) & 0x7));
return MatDesc{
/*addr=*/MatDesc::encode(__cvta_generic_to_shared(data)),
Expand Down
45 changes: 27 additions & 18 deletions cpp/kernels/xqa/mha.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2734,6 +2734,25 @@ static constexpr auto kernel_mha = kernel_mha_impl;
#endif

#ifndef GENERATE_CUBIN
// Computes how many sub-sequences (CTAs along gridDim.x) each K/V sequence is split into
// for multi-block attention in mha.cu.
//
// Returns 1 when multi-block mode is disabled. Otherwise the environment variable
// XQA_NB_SUB_SEQ (a positive integer) overrides the heuristic; an unset, non-positive,
// or malformed value falls through to the default. The default heuristic spreads the
// (batchSize * nbKHeads) independent sequence/head groups across all SMs, capped so a
// sub-sequence is never smaller than one CTA tile (ctaTile.x rows).
//
// Preconditions: batchSize > 0 and nbKHeads > 0 (their product is a divisor below).
uint32_t computeNbSubSeqPerSeqMHA(cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen)
{
    if (!allowMultiBlockMode)
    {
        return 1;
    }
    if (char const* const env = std::getenv("XQA_NB_SUB_SEQ"))
    {
        // Use std::strtol rather than std::stoi: stoi throws std::invalid_argument /
        // std::out_of_range on a malformed env value, which would abort the process.
        // A bad debug override should be ignored, not fatal. strtol reports failure
        // by returning 0 (or clamping), which the val > 0 check filters out.
        long const val = std::strtol(env, nullptr, 10);
        if (val > 0)
        {
            return static_cast<uint32_t>(val);
        }
    }
    return std::min<uint32_t>(
        std::max<uint32_t>(1U, prop.multiProcessorCount / (batchSize * nbKHeads)), divUp(maxSeqLen, ctaTile.x));
}

void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#if SLIDING_WINDOW
uint32_t slidingWinSize,
Expand Down Expand Up @@ -2771,6 +2790,13 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
// int8/fp8 KV cache.
#if SPEC_DEC
SpecDecParams const& specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
float const skipSoftmaxThresholdScaleFactor, // for compatibility with mha_sm90.cu only
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
uint32_t* __restrict__ skippedBlockCount, // for compatibility with mha_sm90.cu only
uint32_t* __restrict__ totalBlockCount, // for compatibility with mha_sm90.cu only
#endif
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream)
{
Expand All @@ -2793,24 +2819,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
uint32_t const nbQHeads = nbKHeads * headGrpSize;

// const uint32_t nbSubSeqPerSeq = allowMultiBlockMode ? DBG_NB_CTAS_PER_SEQ : 1;
uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t
{
if (!allowMultiBlockMode)
{
return 1;
}
auto const env = std::getenv("XQA_NB_SUB_SEQ");
if (env != nullptr)
{
int32_t const val = std::stoi(env);
if (val > 0)
{
return val;
}
}
return std::min<uint32_t>(
std::max<uint32_t>(1U, prop.multiProcessorCount / (batchSize * nbKHeads)), divUp(maxSeqLen, ctaTile.x));
}();
uint32_t const nbSubSeqPerSeq = computeNbSubSeqPerSeqMHA(prop, batchSize, nbKHeads, maxSeqLen);
// gridDim.z == batchSize && gridDim.y == nbKHeads && gridDim.x == nbSubSeqPerSeq
#if SPEC_DEC
const uint32_t nbTokenBlocksPerGrp = divUp(qSeqLen * headGrpSize, rowsPerBlock);
Expand Down
18 changes: 18 additions & 0 deletions cpp/kernels/xqa/mha.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ struct BeamSearchParams
// match trt-llm API.
};

uint32_t computeNbSubSeqPerSeqMHA(
cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen);

void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
#if SLIDING_WINDOW
uint32_t slidingWinSize,
Expand Down Expand Up @@ -127,9 +130,18 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
// int8/fp8 KV cache.
#if SPEC_DEC
SpecDecParams const& specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
float const skipSoftmaxThresholdScaleFactor,
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
#endif
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream);

uint32_t computeNbSubSeqPerSeqHopperF8MHA(
cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen);

void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
#if SLIDING_WINDOW
uint32_t slidingWinSize,
Expand Down Expand Up @@ -167,6 +179,12 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
// int8/fp8 KV cache.
#if SPEC_DEC
SpecDecParams const& specDecParams,
#endif
#if SKIP_SOFTMAX_ATTN
float const skipSoftmaxThresholdScaleFactor,
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
#endif
#endif
uint32_t* semaphores, void* scratch, cudaStream_t stream);

Expand Down
Loading