Skip to content

Commit 0e68514

Browse files
committed
Add WAR
1 parent 29c23cd commit 0e68514

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

csrc/trtllm_batched_gemm_runner.cu

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -239,17 +239,23 @@ void TrtllmGenBatchedGemmRunner::run(
239239
int32_t multiProcessorCount;
240240
cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, device);
241241

242+
// FIXME: this is a WAR to solve the perf regression and should be removed once
243+
// trtllm-gen fixes the issue.
244+
auto myConfig = config;
245+
myConfig.mOptions.mValidK = k;
246+
myConfig.mOptions.mValidN = gemmData.mProblemDimensions.mN;
247+
myConfig.mOptions.mValidM = gemmData.mProblemDimensions.mM;
242248
// FIXME once we start using all-reduce in the epilogue of the bmm this can be moved elsewhere
243-
bmm.runInitBeforeWorldSync(config, gemmData, static_cast<void*>(stream));
249+
bmm.runInitBeforeWorldSync(myConfig, gemmData, static_cast<void*>(stream));
244250

245-
auto const err = bmm.run(config, workspace, gemmData, static_cast<void*>(stream),
251+
auto const err = bmm.run(myConfig, workspace, gemmData, static_cast<void*>(stream),
246252
multiProcessorCount, enable_pdl, globalTrtllmGenBatchedGemmModuleCache);
247253

248254
FLASHINFER_CHECK(err == 0,
249255
"Error occurred when running GEMM!"
250256
" (numBatches: ",
251-
numBatches, ", GemmMNK: ", m, " ", n, " ", k, ", Kernel: ", config.mFunctionName,
252-
")");
257+
numBatches, ", GemmMNK: ", m, " ", n, " ", k,
258+
", Kernel: ", myConfig.mFunctionName, ")");
253259
}
254260

255261
void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k,
@@ -387,8 +393,13 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(
387393
// Filter out invalid configs.
388394
std::vector<int64_t> validConfigIndices;
389395
for (auto const& configIndex : prioritizedIndices) {
390-
auto const& config = configs[configIndex];
391-
auto isValidConfig = bmm.isValidConfig(config, gemmData);
396+
// FIXME: this is a WAR to solve the perf regression and should be removed once
397+
// trtllm-gen fixes the issue.
398+
auto myConfig = configs[configIndex];
399+
myConfig.mOptions.mValidK = k;
400+
myConfig.mOptions.mValidN = gemmData.mProblemDimensions.mN;
401+
myConfig.mOptions.mValidM = gemmData.mProblemDimensions.mM;
402+
auto isValidConfig = bmm.isValidConfig(myConfig, gemmData);
392403
if (isValidConfig) {
393404
validConfigIndices.push_back(configIndex);
394405
}

0 commit comments

Comments
 (0)