@@ -225,10 +225,9 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
225225 seqLen = (16U << 20 ) / gmemCacheHeadBytes; // 32MB per K+V head.
226226 }
227227 ctxLen = std::min (ctxLen, seqLen);
228- float skip_softmax_threshold_scale_factor = skipSoftmaxThresholdScaleFactor;
229- uint32_t skipped_block_count = 0 ;
230- uint32_t total_block_count = 0 ;
231- if (skip_softmax_threshold_scale_factor > 0 )
228+ uint32_t skippedBlockCount = 0 ;
229+ uint32_t totalBlockCount = 0 ;
230+ if (skipSoftmaxThresholdScaleFactor > 0 )
232231 {
233232 assert (useQGMMA);
234233 }
@@ -339,10 +338,10 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
339338 auto const ctxLenList = ManagedMemBuf<uint32_t [beamWidth]>(batchSize);
340339#if SKIP_SOFTMAX_ATTN
341340#ifdef SKIP_SOFTMAX_ATTN_BLOCK_STATS
342- auto const kernel_skipped_block_count = ManagedMemBuf<uint32_t >(1 );
343- auto const kernel_total_block_count = ManagedMemBuf<uint32_t >(1 );
344- kernel_skipped_block_count [0 ] = 0 ;
345- kernel_total_block_count [0 ] = 0 ;
341+ auto const kernelSkippedBlockCount = ManagedMemBuf<uint32_t >(1 );
342+ auto const kernelTotalBlockCount = ManagedMemBuf<uint32_t >(1 );
343+ kernelSkippedBlockCount [0 ] = 0 ;
344+ kernelTotalBlockCount [0 ] = 0 ;
346345#endif
347346#else
348347 EXPECT_EQ (skipSoftmaxThresholdScaleFactor, 0 .0f )
@@ -804,7 +803,7 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
804803#if SKIP_SOFTMAX_ATTN
805804 skipSoftmaxThresholdScaleFactor,
806805#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
807- kernel_skipped_block_count .get (), kernel_total_block_count .get (),
806+ kernelSkippedBlockCount .get (), kernelTotalBlockCount .get (),
808807#endif
809808#endif
810809 semaphores.get (), scratch, stream);
@@ -844,8 +843,8 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
844843 prefetchToDevice (cudaCpuDeviceId);
845844 checkCuda (cudaStreamSynchronize (stream));
846845#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
847- kernel_skipped_block_count [0 ] /= nbIters;
848- kernel_total_block_count [0 ] /= nbIters;
846+ kernelSkippedBlockCount [0 ] /= nbIters;
847+ kernelTotalBlockCount [0 ] /= nbIters;
849848#endif
850849 if (testPerf)
851850 {
@@ -885,7 +884,7 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
885884 float const dramSolRatio = dramSolTime / ms;
886885#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
887886 size_t const totalNbCacheLoadWithSkip = gmemCacheHeadBytes
888- * (nbKHeads + nbVHeads * (1 - 1 .0f * kernel_skipped_block_count [0 ] / kernel_total_block_count [0 ]))
887+ * (nbKHeads + nbVHeads * (1 - 1 .0f * kernelSkippedBlockCount [0 ] / kernelTotalBlockCount [0 ]))
889888 * nbLoadedCacheTokens;
890889 float const totalTrafficWithSkip
891890 = totalNbCacheLoadWithSkip + inputBytes + outputBytes; // we ignore page indices and beam search indices.
@@ -907,13 +906,9 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
907906 float const tops = headGrpSize * qSeqLen * float (seqLen) * (validElemsPerKHead + validElemsPerVHead) * 2
908907 * nbKHeads * batchSize / (ms * 1E-3F ) * 1E-12F ;
909908#if SKIP_SOFTMAX_ATTN && SKIP_SOFTMAX_ATTN_BLOCK_STATS
910- float const topsWithSkip = headGrpSize * qSeqLen * float (seqLen) * (validElemsPerKHead + validElemsPerVHead) * 2
911- * nbKHeads * batchSize / (ms * 1E-3F ) * 1E-12F ;
912- printf (" kernel skipped_block_count: %d/%d (%.2f%%)\n " , kernel_skipped_block_count[0 ],
913- kernel_total_block_count[0 ],
914- kernel_total_block_count[0 ] == 0 ? 0 .0f
915- : 100 .0f * kernel_skipped_block_count[0 ] / kernel_total_block_count[0 ]);
916- printf (" dramSolRatioWithSkip: %f%% (%f ms, TOPS = %f)\n " , dramSolRatioWithSkip * 100 , ms, topsWithSkip);
909+ printf (" kernel skippedBlockCount: %d/%d (%.2f%%)\n " , kernelSkippedBlockCount[0 ], kernelTotalBlockCount[0 ],
910+ kernelTotalBlockCount[0 ] == 0 ? 0 .0f : 100 .0f * kernelSkippedBlockCount[0 ] / kernelTotalBlockCount[0 ]);
911+ printf (" dramSolRatioWithSkip: %f%% (%f ms, TOPS = %f)\n " , dramSolRatioWithSkip * 100 , ms, tops);
917912#else
918913 printf (" dramSolRatio: %f%% (%f ms, TOPS = %f)\n " , dramSolRatio * 100 , ms, tops);
919914#endif
@@ -1138,8 +1133,7 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
11381133 {
11391134 refOutput = refFlashAttention<CacheElem, 64 >(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq ,
11401135 vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0 ], xScale, slidingWinSize, refAttentionSinks,
1141- skip_softmax_threshold_scale_factor, &skipped_block_count, &total_block_count,
1142- multiBlockNum);
1136+ skipSoftmaxThresholdScaleFactor, &skippedBlockCount, &totalBlockCount, multiBlockNum);
11431137 // refOutput = refAttention<CacheElem>(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq,
11441138 // vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize);
11451139 }
@@ -1187,13 +1181,11 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
11871181 }
11881182 }
11891183#if SKIP_SOFTMAX_ATTN
1190- printf (" host skipped_block_count : %d/%d (%.2f%%)\n " , skipped_block_count, total_block_count ,
1191- total_block_count == 0 ? 0 .0f : 100 .0f * skipped_block_count / total_block_count );
1184+ printf (" host skippedBlockCount : %d/%d (%.2f%%)\n " , skippedBlockCount, totalBlockCount ,
1185+ totalBlockCount == 0 ? 0 .0f : 100 .0f * skippedBlockCount / totalBlockCount );
11921186#if SKIP_SOFTMAX_ATTN_BLOCK_STATS
1193- printf (" kernel skipped_block_count: %d/%d (%.2f%%)\n " , kernel_skipped_block_count[0 ],
1194- kernel_total_block_count[0 ],
1195- kernel_total_block_count[0 ] == 0 ? 0 .0f
1196- : 100 .0f * kernel_skipped_block_count[0 ] / kernel_total_block_count[0 ]);
1187+ printf (" kernel skippedBlockCount: %d/%d (%.2f%%)\n " , kernelSkippedBlockCount[0 ], kernelTotalBlockCount[0 ],
1188+ kernelTotalBlockCount[0 ] == 0 ? 0 .0f : 100 .0f * kernelSkippedBlockCount[0 ] / kernelTotalBlockCount[0 ]);
11971189#endif
11981190#endif
11991191 if (saveData)
0 commit comments