Skip to content

Commit 7dc3a59

Browse files
committed
fix valid mnk
Signed-off-by: Siyuan Fu <[email protected]>
1 parent 9b13c3b commit 7dc3a59

File tree

2 files changed

+25
-17
lines changed

2 files changed

+25
-17
lines changed

csrc/trtllm_batched_gemm_runner.cu

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,10 @@ size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(
144144
gemmData.mProblemDimensions.mWorldSize = 1;
145145
gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;
146146

147+
gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM;
148+
gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN;
149+
gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK;
150+
147151
auto bmm = BatchedGemmInterface();
148152

149153
auto const configs = bmm.getBatchedGemmConfigs();
@@ -239,23 +243,21 @@ void TrtllmGenBatchedGemmRunner::run(
239243
int32_t multiProcessorCount;
240244
cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, device);
241245

242-
// FIXME: this is a WAR to solve the perf regression and should be removed once
243-
// trtllm-gen fixes the issue.
244-
auto myConfig = config;
245-
myConfig.mOptions.mValidK = k;
246-
myConfig.mOptions.mValidN = gemmData.mProblemDimensions.mN;
247-
myConfig.mOptions.mValidM = gemmData.mProblemDimensions.mM;
246+
gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM;
247+
gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN;
248+
gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK;
249+
248250
// FIXME once we start using all-reduce in the epilogue of the bmm this can be moved elsewhere
249-
bmm.runInitBeforeWorldSync(myConfig, gemmData, static_cast<void*>(stream));
251+
bmm.runInitBeforeWorldSync(config, gemmData, static_cast<void*>(stream));
250252

251-
auto const err = bmm.run(myConfig, workspace, gemmData, static_cast<void*>(stream),
253+
auto const err = bmm.run(config, workspace, gemmData, static_cast<void*>(stream),
252254
multiProcessorCount, enable_pdl, globalTrtllmGenBatchedGemmModuleCache);
253255

254256
FLASHINFER_CHECK(err == 0,
255257
"Error occurred when running GEMM!"
256258
" (numBatches: ",
257-
numBatches, ", GemmMNK: ", m, " ", n, " ", k,
258-
", Kernel: ", myConfig.mFunctionName, ")");
259+
numBatches, ", GemmMNK: ", m, " ", n, " ", k, ", Kernel: ", config.mFunctionName,
260+
")");
259261
}
260262

261263
void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k,
@@ -333,6 +335,10 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(
333335
gemmData.mProblemDimensions.mWorldSize = 1;
334336
gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;
335337

338+
gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM;
339+
gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN;
340+
gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK;
341+
336342
auto cmpFunc = [&configs, &gemmData, &bmm, &multiProcessorCount](int64_t idx0, int64_t idx1) {
337343
auto const& optionsA = configs[idx0].mOptions;
338344
auto const& optionsB = configs[idx1].mOptions;
@@ -393,13 +399,7 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(
393399
// Filter out invalid configs.
394400
std::vector<int64_t> validConfigIndices;
395401
for (auto const& configIndex : prioritizedIndices) {
396-
// FIXME: this is a WAR to solve the perf regression and should be removed once
397-
// trtllm-gen fixes the issue.
398-
auto myConfig = configs[configIndex];
399-
myConfig.mOptions.mValidK = k;
400-
myConfig.mOptions.mValidN = gemmData.mProblemDimensions.mN;
401-
myConfig.mOptions.mValidM = gemmData.mProblemDimensions.mM;
402-
auto isValidConfig = bmm.isValidConfig(myConfig, gemmData);
402+
auto isValidConfig = bmm.isValidConfig(configs[configIndex], gemmData);
403403
if (isValidConfig) {
404404
validConfigIndices.push_back(configIndex);
405405
}

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,18 @@ struct BatchedGemmData {
7373
// The M dimension.
7474
// It is the total number of tokens if A is the activation matrix.
7575
// It is the total number of output channels if A is the weight matrix.
76+
// ValidM/N/K by default assumes to be full range of M/N/K respectively. If we pad M/N/K due to
77+
// alignment of other constraints, then we can specify ValidM/N/K to indicate the valid range.
7678
int32_t mM{0};
79+
int32_t mValidM{0};
7780
// The N dimension.
7881
// It is the total number of tokens if B is the activation matrix.
7982
// It is the total number of output channels if B is the weight matrix.
8083
int32_t mN{0};
84+
int32_t mValidN{0};
8185
// The K dimension. It is the hidden dimension of the input matrices.
8286
int32_t mK{0};
87+
int32_t mValidK{0};
8388
// The rank id of the current device in the multi-gpu space.
8489
int32_t mRank{0};
8590
// The number of devices in tensor-parallel group.
@@ -695,6 +700,9 @@ class BatchedGemmInterface {
695700
options.mM = data.mProblemDimensions.mM;
696701
options.mN = data.mProblemDimensions.mN;
697702
options.mK = data.mProblemDimensions.mK;
703+
options.mValidM = data.mProblemDimensions.mValidM;
704+
options.mValidN = data.mProblemDimensions.mValidN;
705+
options.mValidK = data.mProblemDimensions.mValidK;
698706
options.mBatchedM = data.mProblemDimensions.mBatchedM;
699707
options.mBatchedN = data.mProblemDimensions.mBatchedN;
700708
options.mBatchMode = data.mProblemDimensions.mBatchM ? BatchedGemmOptions::BatchMode::BatchM

0 commit comments

Comments
 (0)