@@ -24,6 +24,7 @@
 #include "flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h"
 #include "flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h"
 #include "flashinfer/trtllm/common.h"
+#include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/common/envUtils.h"

 namespace tensorrt_llm {
@@ -306,6 +307,8 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(
   auto const bmm = BatchedGemmInterface();
   auto const configs = bmm.getBatchedGemmConfigs();

+  int32_t multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
+
   BatchedGemmData gemmData;
   // Dims
   gemmData.mProblemDimensions.mNumBatches = numBatches;
@@ -322,67 +325,57 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(
   gemmData.mProblemDimensions.mWorldSize = 1;
   gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;

-  // Tier 0: K < tileK, prefer higher efficiency.
-  auto cmpTier0 = [&configs, &gemmData](int64_t idx0, int64_t idx1) {
+  auto cmpFunc = [&configs, &gemmData, &bmm, &multiProcessorCount](int64_t idx0, int64_t idx1) {
     auto const& optionsA = configs[idx0].mOptions;
     auto const& optionsB = configs[idx1].mOptions;
     int32_t sizeK = gemmData.mProblemDimensions.mK;
-    // Both waste computation, prefer higher efficiency.
-    if (sizeK <= optionsA.mTileK && sizeK <= optionsB.mTileK) {
-      double eff_a = (double)sizeK / optionsA.mTileK;
-      double eff_b = (double)sizeK / optionsB.mTileK;
-      return eff_a > eff_b;
-    }
-    // If either can be utilized, sort by tileK.
-    else {
-      return optionsA.mTileK > optionsB.mTileK;
+
+    // Tier 0: K < tileK, prefer higher efficiency.
+    if (optionsA.mTileK != optionsB.mTileK) {
+      // Both waste computation, prefer higher efficiency.
+      if (sizeK <= optionsA.mTileK && sizeK <= optionsB.mTileK) {
+        double eff_a = (double)sizeK / optionsA.mTileK;
+        double eff_b = (double)sizeK / optionsB.mTileK;
+        return eff_a > eff_b;
+      }
+      // If either can be utilized, sort by tileK.
+      else {
+        return optionsA.mTileK > optionsB.mTileK;
+      }
     }
-  };
-  // Tier 1: When tileK is the same, prefer unroll loop 2x for mma.
-  auto cmpTier1 = [&configs](int64_t idx0, int64_t idx1) {
-    auto const& optionsA = configs[idx0].mOptions;
-    auto const& optionsB = configs[idx1].mOptions;
-    if (optionsA.mTileK == optionsB.mTileK) {
+
+    // Tier 1: When tileK is the same, prefer unroll loop 2x for mma.
+    if (optionsA.mUseUnrollLoop2xForMma != optionsB.mUseUnrollLoop2xForMma) {
       return optionsA.mUseUnrollLoop2xForMma;
     }
-    return false;
-  };
-  // Tier 2+: When previous comparators are the same, prefer higher tileM.
-  auto cmpTier2 = [&configs](int64_t idx0, int64_t idx1) {
-    auto const& optionsA = configs[idx0].mOptions;
-    auto const& optionsB = configs[idx1].mOptions;
-    if (optionsA.mTileK == optionsB.mTileK &&
-        optionsA.mUseUnrollLoop2xForMma == optionsB.mUseUnrollLoop2xForMma) {
+
+    // Tier 2+: When previous comparators are the same, prefer higher tileM.
+    if (optionsA.mTileM != optionsB.mTileM) {
       return optionsA.mTileM > optionsB.mTileM;
     }
-    return false;
-  };
-  // Tier 2+: When previous comparators are the same, and when number of estimated CTAs is on the
-  // larger side, prefer persistent tile scheduler. The threshold is hardcoded as >148 CTAs at the
-  // moment.
-  auto cmpTier3 = [&configs, &gemmData](int64_t idx0, int64_t idx1) {
-    int32_t sizeM = gemmData.mProblemDimensions.mM;
-    int32_t sizeN = gemmData.mProblemDimensions.mN;
-    auto const& optionsA = configs[idx0].mOptions;
-    auto const& optionsB = configs[idx1].mOptions;
-    if (optionsA.mTileK == optionsB.mTileK &&
-        optionsA.mUseUnrollLoop2xForMma == optionsB.mUseUnrollLoop2xForMma &&
-        optionsA.mTileM == optionsB.mTileM) {
-      int64_t numTilesM = batchedGemm::gemm::divUp(sizeM, optionsA.mTileM);
-      int64_t numTilesN = batchedGemm::gemm::divUp(sizeN, optionsA.mTileN);
-      if (numTilesM * numTilesN > 148) {
+
+    // Tier 2+: When previous comparators are the same, prefer higher tileN.
+    if (optionsA.mTileN != optionsB.mTileN) {
+      return optionsA.mTileN > optionsB.mTileN;
+    }
+
+    // Tier 2+: When previous comparators are the same, and when the number of estimated CTAs is on
+    // the larger side, prefer persistent tile scheduler.
+    if (optionsA.mTileScheduler != optionsB.mTileScheduler) {
+      auto options = bmm.getOptionsFromConfigAndData(configs[idx0], gemmData);
+      auto numCtas = bmm.getNumCtas(options, gemmData.mProblemDimensions.mMaxNumCtasInTokenDim);
+      if (numCtas > multiProcessorCount) {
         return optionsA.mTileScheduler == batchedGemm::gemm::TileScheduler::Persistent;
+      } else {
+        return optionsB.mTileScheduler == batchedGemm::gemm::TileScheduler::Persistent;
       }
     }
+
     return false;
   };
-
   // Sort configs by options.
   std::vector<int64_t> sortedIndices = mPassingConfigIndices;
-  std::sort(sortedIndices.begin(), sortedIndices.end(), cmpTier0);
-  std::sort(sortedIndices.begin(), sortedIndices.end(), cmpTier1);
-  std::sort(sortedIndices.begin(), sortedIndices.end(), cmpTier2);
-  std::sort(sortedIndices.begin(), sortedIndices.end(), cmpTier3);
+  std::sort(sortedIndices.begin(), sortedIndices.end(), cmpFunc);

   // Special rules for corner cases, if applicable.
   std::vector<int64_t> prioritizedIndices =
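
The core of this change is collapsing the four chained `std::sort` passes (`cmpTier0` through `cmpTier3`) into one tiered comparator. `std::sort` makes no stability guarantee, so in the old code any elements that compared equivalent under a later pass (for example, two configs with *different* tileK, for which `cmpTier1` returned `false` in both directions) could be reordered arbitrarily, undoing the ordering established by an earlier pass. A single comparator that decides at the first differing tier and falls through on ties is a valid strict weak ordering and needs only one O(n log n) sort. A minimal sketch of the pattern, with a hypothetical `Config` struct standing in for the real kernel options:

```cpp
#include <algorithm>
#include <vector>

// Hypothetical stand-in for the real kernel config options.
struct Config {
  int tileK;
  bool unrollLoop2x;
  int tileM;
};

int main() {
  std::vector<Config> configs = {{128, false, 64}, {256, true, 128}, {256, false, 128}};

  // One strict weak ordering: each tier decides only when it differs;
  // otherwise control falls through to the next tie-breaker.
  std::sort(configs.begin(), configs.end(), [](Config const& a, Config const& b) {
    if (a.tileK != b.tileK) return a.tileK > b.tileK;             // Tier 0
    if (a.unrollLoop2x != b.unrollLoop2x) return a.unrollLoop2x;  // Tier 1
    if (a.tileM != b.tileM) return a.tileM > b.tileM;             // Tier 2
    return false;  // Equivalent keys must compare false both ways.
  });
  return 0;
}
```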
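Two behavioral changes ride along with the restructuring. First, a tileN tie-breaker tier is added. Second, the persistent-scheduler tier drops the hardcoded `numTilesM * numTilesN > 148` heuristic (148 matches the SM count of a B200, so the old constant baked in one specific GPU) in favor of comparing the launch's estimated CTA count from `bmm.getNumCtas` against the device's actual SM count from `tensorrt_llm::common::getMultiProcessorCount()`, declared in the newly included `cudaUtils.h`. The new tier also prefers the non-persistent scheduler when the grid does not fill the device, where the old `cmpTier3` expressed no preference. A helper of that kind typically wraps a CUDA runtime attribute query; a minimal sketch, not the actual `cudaUtils.h` implementation (which may cache the result and check errors):

```cpp
#include <cuda_runtime.h>

// Minimal sketch: query the number of SMs on the current device.
// Error handling omitted for brevity.
int multiProcessorCountSketch() {
  int device = 0;
  cudaGetDevice(&device);
  int smCount = 0;
  cudaDeviceGetAttribute(&smCount, cudaDevAttrMultiProcessorCount, device);
  return smCount;
}
```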