Skip to content

Commit 490bb60

Browse files
committed
Disable softmax skipping for short sequences
Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
1 parent 2202e65 commit 490bb60

File tree

3 files changed

+11
-7
lines changed

3 files changed

+11
-7
lines changed

cpp/kernels/xqa/mha_sm90.cu

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,8 @@ CUBIN_EXPORT __global__
786786
static_assert(multiBlockMinNbTiles >= multiBlockMinNbTilesPerCta * 2);
787787
assert(isMultiBlockMode == (nbSubSeq > 1));
788788
#if SKIP_SOFTMAX_ATTN
789-
float const skipSoftmaxThreshold = skipSoftmaxThresholdScaleFactor / cacheSeqLen;
789+
bool const disableSkipForShortSeq = (cacheSeqLen < skipSoftmaxThresholdScaleFactor);
790+
float const skipSoftmaxThreshold = disableSkipForShortSeq ? 0.0f : skipSoftmaxThresholdScaleFactor / cacheSeqLen;
790791
#endif
791792
if (idxSubSeq >= nbSubSeq)
792793
{
@@ -1001,7 +1002,7 @@ CUBIN_EXPORT __global__
10011002
auto& skipSoftmaxXBar = smem.skipSoftmaxXBar[idxXBuf];
10021003
skipSoftmaxXBar.consumed.arrive_and_wait();
10031004

1004-
bool const maybeSkip = idxIter != 0;
1005+
bool const maybeSkip = !disableSkipForShortSeq && idxIter != 0;
10051006
RegColWiseVec const colMax = computeWarpGrpColMax_sync(smem.gemm0WarpGrpBar, smem.gemm0CurrentSeqMax, acc,
10061007
skipSoftmaxThreshold, &smem.skipSoftmaxVotesGemm0ToV[idxXBuf], maybeSkip);
10071008
bool const shouldSkipSoftmaxAttn = static_cast<bool>(smem.skipSoftmaxVotesGemm0ToV[idxXBuf]);
@@ -2142,7 +2143,8 @@ __device__ inline RegColWiseVec computeWarpGrpColMax_sync(
21422143
{
21432144
*smemSkipVote = maybeSkip ? 1U : 0U; // will sync before vote
21442145
}
2145-
float const lnThreshold = log(skipSoftmaxThreshold);
2146+
float const lnThreshold
2147+
= log(skipSoftmaxThreshold); // -inf when the threshold is 0 (e.g. skipping disabled for short sequences); should be safe because it is only used for comparison
21462148
#endif
21472149

21482150
auto colMax = RegColWiseVec::filled(Vec<float, 2>::filled(safeInitRowMax));

cpp/kernels/xqa/test/refAttention.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
6969
{
7070
skipRowMaxs[i].fill(-INFINITY);
7171
}
72-
float skipSoftmaxThreshold = skipSoftmaxThresholdScaleFactor / seqLen;
72+
bool const disableSkipForShortSeq = (seqLen < skipSoftmaxThresholdScaleFactor);
73+
float const skipSoftmaxThreshold = disableSkipForShortSeq ? 0.0f : skipSoftmaxThresholdScaleFactor / seqLen;
7374

7475
for (uint32_t idxTile = idxTileBeg; idxTile < nbTiles; idxTile++)
7576
{
@@ -103,7 +104,7 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
103104
auto const prevSkipRowMax = skipRowMaxs[idxTile % nbSubSeq];
104105
skipRowMaxs[idxTile % nbSubSeq] = localRowMax.cwiseMax(skipRowMaxs[idxTile % nbSubSeq]).eval();
105106

106-
if (skipSoftmaxThreshold > 0)
107+
if (!disableSkipForShortSeq && skipSoftmaxThreshold > 0)
107108
{
108109
*totalBlockCount += 1;
109110
auto const skipSoftmaxMask = ((localRowMax - prevSkipRowMax).array() < std::log(skipSoftmaxThreshold));

cpp/kernels/xqa/test/test.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1312,10 +1312,11 @@ TEST(RefCheck, llama_V2_70b)
13121312
#endif
13131313
#if SKIP_SOFTMAX_ATTN
13141314
runTest<1>(32, 2048, false, true, false, false, false, ~0U, 1U << 30, 0.f);
1315-
runTest<4>(32, 1538, false, true, false, false, false, ~0U, 1U << 30, 55.f);
1315+
runTest<4>(32, 1538, false, true, false, false, false, ~0U, 1U << 30, 1280.f);
13161316
runTest<2>(32, 4096, false, true, false, false, false, ~0U, 1U << 30, 125.f);
13171317
runTest<4>(32, 300, false, true, false, false, false, ~0U, 1U << 30, 80.f);
1318-
runTest<4>(32, 500, false, true, false, false, false, ~0U, 1U << 30, 455.f);
1318+
runTest<4>(32, 500, false, true, false, false, false, ~0U, 1U << 30, 501.0f);
1319+
runTest<4>(32, 500, false, true, false, false, false, ~0U, 1U << 30, 500.f);
13191320
#endif
13201321
runTest<8>(120, 367, false, true);
13211322
runTest<8>(1792, 2048, false, true);

0 commit comments

Comments
 (0)