Skip to content

Commit 5d8eaed

Browse files
committed
fix nvrtc compile and code style
Signed-off-by: Pengbo Wang <[email protected]>
1 parent 2b2d1ed commit 5d8eaed

File tree

7 files changed

+53
-53
lines changed

7 files changed

+53
-53
lines changed

cpp/kernels/xqa/defines.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,8 @@ static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN should only be used when SPEC_DEC is ena
137137
#define SKIP_SOFTMAX_ATTN_BLOCK_STATS 0
138138
#endif
139139

140-
#ifndef SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GRETAER_THAN_ONE
141-
#define SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GRETAER_THAN_ONE 1
140+
#ifndef SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE
141+
#define SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE 1
142142
#endif
143143

144144
// 0 - no PDL

cpp/kernels/xqa/mha.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads,
134134
#if SKIP_SOFTMAX_ATTN
135135
float const skipSoftmaxThresholdScaleFactor,
136136
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
137-
uint32_t* __restrict__ skipped_block_count, uint32_t* __restrict__ total_block_count,
137+
uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
138138
#endif
139139
#endif
140140
uint32_t* semaphores, void* scratch, cudaStream_t stream);
@@ -183,7 +183,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads,
183183
#if SKIP_SOFTMAX_ATTN
184184
float const skipSoftmaxThresholdScaleFactor,
185185
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
186-
uint32_t* __restrict__ skipped_block_count, uint32_t* __restrict__ total_block_count,
186+
uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
187187
#endif
188188
#endif
189189
uint32_t* semaphores, void* scratch, cudaStream_t stream);

cpp/kernels/xqa/mha_sm90.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -705,7 +705,7 @@ CUBIN_EXPORT __global__
705705
#if SKIP_SOFTMAX_ATTN
706706
float const skipSoftmaxThresholdScaleFactor,
707707
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
708-
uint32_t* __restrict__ skipped_block_count, uint32_t* __restrict__ total_block_count,
708+
uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
709709
#endif
710710
#endif
711711
uint32_t* __restrict__ const semaphores
@@ -1083,8 +1083,8 @@ CUBIN_EXPORT __global__
10831083
#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
10841084
if (threadIdx.x == 0 && skipped_block_count != nullptr && total_block_count != nullptr)
10851085
{
1086-
atomicAdd(skipped_block_count, local_skipped_block_count);
1087-
atomicAdd(total_block_count, nbIters);
1086+
atomicAdd(skippedBlockCount, local_skipped_block_count);
1087+
atomicAdd(totalBlockCount, nbIters);
10881088
}
10891089
#endif
10901090
unused(smem.qBar.consumed.arrive());
@@ -2395,7 +2395,7 @@ __device__ inline void storeGemm0AccToShm(
23952395
uint32_t const idxOctInsideHalf = idxInHalf / 8;
23962396
uint32_t const idxRowInsideOct = lane % 8;
23972397
uint32_t const warpBaseC = 16 * warpRank;
2398-
auto const toAccCoords = [](uint32_t const idxAccCoreMat) -> std::pair<uint32_t, uint32_t>
2398+
auto const toAccCoords = [](uint32_t const idxAccCoreMat) -> mha::pair<uint32_t, uint32_t>
23992399
{
24002400
uint32_t const accR = idxAccCoreMat / Gemm0Acc::cols;
24012401
uint32_t const accC = idxAccCoreMat % Gemm0Acc::cols;

cpp/kernels/xqa/mha_stdheaders.cuh

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,6 +1272,19 @@ using is_void = is_same<remove_cv_t<T>, void>;
12721272
template <typename T>
12731273
inline constexpr bool is_void_v = is_void<T>::value;
12741274
#endif
1275+
1276+
#ifndef GENERATE_CUBIN
1277+
template <typename T1, typename T2>
1278+
using pair = std::pair<T1, T2>;
1279+
#else
1280+
template <typename T1, typename T2>
1281+
struct pair
1282+
{
1283+
T1 first;
1284+
T2 second;
1285+
};
1286+
#endif
1287+
12751288
} // namespace mha
12761289

12771290
#if GENERATE_CUBIN

cpp/kernels/xqa/test/refAttention.cpp

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,8 @@ using Vector = Matrix<Type, Size, 1>;
5151
template <typename MathElem, uint32_t tileSize, bool isPaged, bool useBeamSearch>
5252
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAttention(IOHead const* q,
5353
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
54-
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks,
55-
float skip_softmax_threshold_scale_factor, uint32_t* skipped_block_count, uint32_t* total_block_count,
56-
uint32_t multi_block_num)
54+
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks, float skipSoftmaxThresholdScaleFactor,
55+
uint32_t* skippedBlockCount, uint32_t* totalBlockCount, uint32_t multiBlockNum)
5756
{
5857
uint32_t const nbTiles = divUp(seqLen, tileSize);
5958
auto gemm1Acc = Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor>::Zero().eval();
@@ -65,14 +64,14 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
6564
uint32_t const seqBeg = (seqLen < slidingWinSize ? 0 : seqLen - slidingWinSize);
6665
uint32_t const idxTileBeg = seqBeg / tileSize;
6766

68-
uint32_t const nbSubSeq = (multi_block_num > 0 && nbTiles >= 2) ? mha::min(nbTiles, multi_block_num) : 1;
67+
uint32_t const nbSubSeq = (multiBlockNum > 0 && nbTiles >= 2) ? mha::min(nbTiles, multiBlockNum) : 1;
6968
// uint32_t const nbSubSeq = 1;
7069
std::vector<Eigen::Vector<float, headGrpSize>> skipRowMaxs(nbSubSeq);
7170
for (uint32_t i = 0; i < nbSubSeq; i++)
7271
{
7372
skipRowMaxs[i].fill(-INFINITY);
7473
}
75-
float skip_softmax_threshold = skip_softmax_threshold_scale_factor / seqLen;
74+
float skipSoftmaxThreshold = skipSoftmaxThresholdScaleFactor / seqLen;
7675

7776
for (uint32_t idxTile = idxTileBeg; idxTile < nbTiles; idxTile++)
7877
{
@@ -106,17 +105,14 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
106105
auto const prevSkipRowMax = skipRowMaxs[idxTile % nbSubSeq];
107106
skipRowMaxs[idxTile % nbSubSeq] = localRowMax.cwiseMax(skipRowMaxs[idxTile % nbSubSeq]).eval();
108107

109-
// printf("\n===================\n");
110-
111-
// add skip softmax threshold here
112-
if (skip_softmax_threshold > 0)
108+
if (skipSoftmaxThreshold > 0)
113109
{
114-
*total_block_count += 1;
115-
auto const skip_softmax_mask = ((localRowMax - prevSkipRowMax).array() < std::log(skip_softmax_threshold));
116-
bool const skip_block = skip_softmax_mask.all() && ((idxTile - idxTileBeg) >= nbSubSeq);
117-
if (skip_block)
110+
*totalBlockCount += 1;
111+
auto const skipSoftmaxMask = ((localRowMax - prevSkipRowMax).array() < std::log(skipSoftmaxThreshold));
112+
bool const skipBlock = skipSoftmaxMask.all() && ((idxTile - idxTileBeg) >= nbSubSeq);
113+
if (skipBlock)
118114
{
119-
*skipped_block_count += 1;
115+
*skippedBlockCount += 1;
120116
continue;
121117
}
122118
}
@@ -170,8 +166,7 @@ Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAt
170166
refFlashAttention<prec, tileSize, isPaged, useBeamSearch>(IOHead const* q, \
171167
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, \
172168
float qScale, float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks, \
173-
float skip_softmax_threshold, uint32_t* skipped_block_count, uint32_t* total_block_count, \
174-
uint32_t multi_block_num)
169+
float skipSoftmaxThreshold, uint32_t* skippedBlockCount, uint32_t* totalBlockCount, uint32_t multiBlockNum)
175170

176171
INSTANTIATE_refFlashAttention(CacheElem, 64, false, false);
177172
INSTANTIATE_refFlashAttention(CacheElem, 64, false, true);

cpp/kernels/xqa/test/refAttention.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ struct CacheSeq<true, true>
8888
template <typename MathElem, uint32_t tileSize, bool isPaged, bool useBeamSearch>
8989
Eigen::Matrix<float, headGrpSize, validElemsPerHead, Eigen::RowMajor> refFlashAttention(IOHead const* q,
9090
CacheSeq<isPaged, useBeamSearch> const& k, CacheSeq<isPaged, useBeamSearch> const& v, uint32_t seqLen, float qScale,
91-
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks, float skip_softmax_threshold,
92-
uint32_t* skipped_block_count, uint32_t* total_block_count, uint32_t multi_block_num);
91+
float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks, float skipSoftmaxThresholdScaleFactor,
92+
uint32_t* skippedBlockCount, uint32_t* totalBlockCount, uint32_t multiBlockNum);
9393

9494
template <typename MathElem, bool isPaged, bool useBeamSearch>
9595
#if SPEC_DEC

cpp/kernels/xqa/test/test.cpp

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -225,10 +225,9 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
225225
seqLen = (16U << 20) / gmemCacheHeadBytes; // 32MB per K+V head.
226226
}
227227
ctxLen = std::min(ctxLen, seqLen);
228-
float skip_softmax_threshold_scale_factor = skipSoftmaxThresholdScaleFactor;
229-
uint32_t skipped_block_count = 0;
230-
uint32_t total_block_count = 0;
231-
if (skip_softmax_threshold_scale_factor > 0)
228+
uint32_t skippedBlockCount = 0;
229+
uint32_t totalBlockCount = 0;
230+
if (skipSoftmaxThresholdScaleFactor > 0)
232231
{
233232
assert(useQGMMA);
234233
}
@@ -339,10 +338,10 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
339338
auto const ctxLenList = ManagedMemBuf<uint32_t[beamWidth]>(batchSize);
340339
#if SKIP_SOFTMAX_ATTN
341340
#ifdef SKIP_SOFTMAX_ATTN_BLOCK_STATS
342-
auto const kernel_skipped_block_count = ManagedMemBuf<uint32_t>(1);
343-
auto const kernel_total_block_count = ManagedMemBuf<uint32_t>(1);
344-
kernel_skipped_block_count[0] = 0;
345-
kernel_total_block_count[0] = 0;
341+
auto const kernelSkippedBlockCount = ManagedMemBuf<uint32_t>(1);
342+
auto const kernelTotalBlockCount = ManagedMemBuf<uint32_t>(1);
343+
kernelSkippedBlockCount[0] = 0;
344+
kernelTotalBlockCount[0] = 0;
346345
#endif
347346
#else
348347
EXPECT_EQ(skipSoftmaxThresholdScaleFactor, 0.0f)
@@ -804,7 +803,7 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
804803
#if SKIP_SOFTMAX_ATTN
805804
skipSoftmaxThresholdScaleFactor,
806805
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
807-
kernel_skipped_block_count.get(), kernel_total_block_count.get(),
806+
kernelSkippedBlockCount.get(), kernelTotalBlockCount.get(),
808807
#endif
809808
#endif
810809
semaphores.get(), scratch, stream);
@@ -844,8 +843,8 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
844843
prefetchToDevice(cudaCpuDeviceId);
845844
checkCuda(cudaStreamSynchronize(stream));
846845
#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
847-
kernel_skipped_block_count[0] /= nbIters;
848-
kernel_total_block_count[0] /= nbIters;
846+
kernelSkippedBlockCount[0] /= nbIters;
847+
kernelTotalBlockCount[0] /= nbIters;
849848
#endif
850849
if (testPerf)
851850
{
@@ -885,7 +884,7 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
885884
float const dramSolRatio = dramSolTime / ms;
886885
#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
887886
size_t const totalNbCacheLoadWithSkip = gmemCacheHeadBytes
888-
* (nbKHeads + nbVHeads * (1 - 1.0f * kernel_skipped_block_count[0] / kernel_total_block_count[0]))
887+
* (nbKHeads + nbVHeads * (1 - 1.0f * kernelSkippedBlockCount[0] / kernelTotalBlockCount[0]))
889888
* nbLoadedCacheTokens;
890889
float const totalTrafficWithSkip
891890
= totalNbCacheLoadWithSkip + inputBytes + outputBytes; // we ignore page indices and beam search indices.
@@ -907,13 +906,9 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
907906
float const tops = headGrpSize * qSeqLen * float(seqLen) * (validElemsPerKHead + validElemsPerVHead) * 2
908907
* nbKHeads * batchSize / (ms * 1E-3F) * 1E-12F;
909908
#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
910-
float const topsWithSkip = headGrpSize * qSeqLen * float(seqLen) * (validElemsPerKHead + validElemsPerVHead) * 2
911-
* nbKHeads * batchSize / (ms * 1E-3F) * 1E-12F;
912-
printf("kernel skipped_block_count: %d/%d (%.2f%%)\n", kernel_skipped_block_count[0],
913-
kernel_total_block_count[0],
914-
kernel_total_block_count[0] == 0 ? 0.0f
915-
: 100.0f * kernel_skipped_block_count[0] / kernel_total_block_count[0]);
916-
printf("dramSolRatioWithSkip: %f%% (%f ms, TOPS = %f)\n", dramSolRatioWithSkip * 100, ms, topsWithSkip);
909+
printf("kernel skippedBlockCount: %d/%d (%.2f%%)\n", kernelSkippedBlockCount[0], kernelTotalBlockCount[0],
910+
kernelTotalBlockCount[0] == 0 ? 0.0f : 100.0f * kernelSkippedBlockCount[0] / kernelTotalBlockCount[0]);
911+
printf("dramSolRatioWithSkip: %f%% (%f ms, TOPS = %f)\n", dramSolRatioWithSkip * 100, ms, tops);
917912
#else
918913
printf("dramSolRatio: %f%% (%f ms, TOPS = %f)\n", dramSolRatio * 100, ms, tops);
919914
#endif
@@ -1138,8 +1133,7 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
11381133
{
11391134
refOutput = refFlashAttention<CacheElem, 64>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq,
11401135
vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize, refAttentionSinks,
1141-
skip_softmax_threshold_scale_factor, &skipped_block_count, &total_block_count,
1142-
multiBlockNum);
1136+
skipSoftmaxThresholdScaleFactor, &skippedBlockCount, &totalBlockCount, multiBlockNum);
11431137
// refOutput = refAttention<CacheElem>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq,
11441138
// vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize);
11451139
}
@@ -1187,13 +1181,11 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
11871181
}
11881182
}
11891183
#if SKIP_SOFTMAX_ATTN
1190-
printf("host skipped_block_count: %d/%d (%.2f%%)\n", skipped_block_count, total_block_count,
1191-
total_block_count == 0 ? 0.0f : 100.0f * skipped_block_count / total_block_count);
1184+
printf("host skippedBlockCount: %d/%d (%.2f%%)\n", skippedBlockCount, totalBlockCount,
1185+
totalBlockCount == 0 ? 0.0f : 100.0f * skippedBlockCount / totalBlockCount);
11921186
#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
1193-
printf("kernel skipped_block_count: %d/%d (%.2f%%)\n", kernel_skipped_block_count[0],
1194-
kernel_total_block_count[0],
1195-
kernel_total_block_count[0] == 0 ? 0.0f
1196-
: 100.0f * kernel_skipped_block_count[0] / kernel_total_block_count[0]);
1187+
printf("kernel skippedBlockCount: %d/%d (%.2f%%)\n", kernelSkippedBlockCount[0], kernelTotalBlockCount[0],
1188+
kernelTotalBlockCount[0] == 0 ? 0.0f : 100.0f * kernelSkippedBlockCount[0] / kernelTotalBlockCount[0]);
11971189
#endif
11981190
#endif
11991191
if (saveData)

0 commit comments

Comments
 (0)