@@ -144,6 +144,10 @@ size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(
144144 gemmData.mProblemDimensions .mWorldSize = 1 ;
145145 gemmData.mProblemDimensions .mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;
146146
147+ gemmData.mProblemDimensions .mValidM = gemmData.mProblemDimensions .mM ;
148+ gemmData.mProblemDimensions .mValidN = gemmData.mProblemDimensions .mN ;
149+ gemmData.mProblemDimensions .mValidK = gemmData.mProblemDimensions .mK ;
150+
147151 auto bmm = BatchedGemmInterface ();
148152
149153 auto const configs = bmm.getBatchedGemmConfigs ();
@@ -239,23 +243,21 @@ void TrtllmGenBatchedGemmRunner::run(
239243 int32_t multiProcessorCount;
240244 cudaDeviceGetAttribute (&multiProcessorCount, cudaDevAttrMultiProcessorCount, device);
241245
242- // FIXME: this is a WAR to solve the perf regression and should be removed once
243- // trtllm-gen fixes the issue.
244- auto myConfig = config;
245- myConfig.mOptions .mValidK = k;
246- myConfig.mOptions .mValidN = gemmData.mProblemDimensions .mN ;
247- myConfig.mOptions .mValidM = gemmData.mProblemDimensions .mM ;
246+ gemmData.mProblemDimensions .mValidM = gemmData.mProblemDimensions .mM ;
247+ gemmData.mProblemDimensions .mValidN = gemmData.mProblemDimensions .mN ;
248+ gemmData.mProblemDimensions .mValidK = gemmData.mProblemDimensions .mK ;
249+
248250 // FIXME once we start using all-reduce in the epilogue of the bmm this can be moved elsewhere
249- bmm.runInitBeforeWorldSync (myConfig , gemmData, static_cast <void *>(stream));
251+ bmm.runInitBeforeWorldSync (config , gemmData, static_cast <void *>(stream));
250252
251- auto const err = bmm.run (myConfig , workspace, gemmData, static_cast <void *>(stream),
253+ auto const err = bmm.run (config , workspace, gemmData, static_cast <void *>(stream),
252254 multiProcessorCount, enable_pdl, globalTrtllmGenBatchedGemmModuleCache);
253255
254256 FLASHINFER_CHECK (err == 0 ,
255257 " Error occurred when running GEMM!"
256258 " (numBatches: " ,
257- numBatches, " , GemmMNK: " , m, " " , n, " " , k,
258- " , Kernel: " , myConfig. mFunctionName , " )" );
259+ numBatches, " , GemmMNK: " , m, " " , n, " " , k, " , Kernel: " , config. mFunctionName ,
260+ " )" );
259261}
260262
261263void TrtllmGenBatchedGemmRunner::run (int32_t m, int32_t n, int32_t k,
@@ -333,6 +335,10 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(
333335 gemmData.mProblemDimensions .mWorldSize = 1 ;
334336 gemmData.mProblemDimensions .mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;
335337
338+ gemmData.mProblemDimensions .mValidM = gemmData.mProblemDimensions .mM ;
339+ gemmData.mProblemDimensions .mValidN = gemmData.mProblemDimensions .mN ;
340+ gemmData.mProblemDimensions .mValidK = gemmData.mProblemDimensions .mK ;
341+
336342 auto cmpFunc = [&configs, &gemmData, &bmm, &multiProcessorCount](int64_t idx0, int64_t idx1) {
337343 auto const & optionsA = configs[idx0].mOptions ;
338344 auto const & optionsB = configs[idx1].mOptions ;
@@ -393,13 +399,7 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(
393399 // Filter out invalid configs.
394400 std::vector<int64_t > validConfigIndices;
395401 for (auto const & configIndex : prioritizedIndices) {
396- // FIXME: this is a WAR to solve the perf regression and should be removed once
397- // trtllm-gen fixes the issue.
398- auto myConfig = configs[configIndex];
399- myConfig.mOptions .mValidK = k;
400- myConfig.mOptions .mValidN = gemmData.mProblemDimensions .mN ;
401- myConfig.mOptions .mValidM = gemmData.mProblemDimensions .mM ;
402- auto isValidConfig = bmm.isValidConfig (myConfig, gemmData);
402+ auto isValidConfig = bmm.isValidConfig (configs[configIndex], gemmData);
403403 if (isValidConfig) {
404404 validConfigIndices.push_back (configIndex);
405405 }
0 commit comments