@@ -239,17 +239,23 @@ void TrtllmGenBatchedGemmRunner::run(
239239 int32_t multiProcessorCount;
240240 cudaDeviceGetAttribute (&multiProcessorCount, cudaDevAttrMultiProcessorCount, device);
241241
242+ // FIXME: this is a WAR to solve the perf regression and should be removed once
243+ // trtllm-gen fixes the issue.
244+ auto myConfig = config;
245+ myConfig.mOptions .mValidK = k;
246+ myConfig.mOptions .mValidN = gemmData.mProblemDimensions .mN ;
247+ myConfig.mOptions .mValidM = gemmData.mProblemDimensions .mM ;
242248 // FIXME once we start using all-reduce in the epilogue of the bmm this can be moved elsewhere
243- bmm.runInitBeforeWorldSync (config , gemmData, static_cast <void *>(stream));
249+ bmm.runInitBeforeWorldSync (myConfig , gemmData, static_cast <void *>(stream));
244250
245- auto const err = bmm.run (config , workspace, gemmData, static_cast <void *>(stream),
251+ auto const err = bmm.run (myConfig , workspace, gemmData, static_cast <void *>(stream),
246252 multiProcessorCount, enable_pdl, globalTrtllmGenBatchedGemmModuleCache);
247253
248254 FLASHINFER_CHECK (err == 0 ,
249255 " Error occurred when running GEMM!"
250256 " (numBatches: " ,
251- numBatches, " , GemmMNK: " , m, " " , n, " " , k, " , Kernel: " , config. mFunctionName ,
252- " )" );
257+ numBatches, " , GemmMNK: " , m, " " , n, " " , k,
258+ " , Kernel: " , myConfig. mFunctionName , " )" );
253259}
254260
255261void TrtllmGenBatchedGemmRunner::run (int32_t m, int32_t n, int32_t k,
@@ -387,8 +393,13 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(
387393 // Filter out invalid configs.
388394 std::vector<int64_t > validConfigIndices;
389395 for (auto const & configIndex : prioritizedIndices) {
390- auto const & config = configs[configIndex];
391- auto isValidConfig = bmm.isValidConfig (config, gemmData);
396+ // FIXME: this is a WAR to solve the perf regression and should be removed once
397+ // trtllm-gen fixes the issue.
398+ auto myConfig = configs[configIndex];
399+ myConfig.mOptions .mValidK = k;
400+ myConfig.mOptions .mValidN = gemmData.mProblemDimensions .mN ;
401+ myConfig.mOptions .mValidM = gemmData.mProblemDimensions .mM ;
402+ auto isValidConfig = bmm.isValidConfig (myConfig, gemmData);
392403 if (isValidConfig) {
393404 validConfigIndices.push_back (configIndex);
394405 }
0 commit comments