Skip to content

Commit a5f9585

Browse files
committed
update trtllm-gen to dd8b
Signed-off-by: Siyuan Fu <[email protected]>
1 parent 0e68514 commit a5f9585

File tree

6 files changed

+73
-36
lines changed

6 files changed

+73
-36
lines changed

flashinfer/jit/fused_moe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,11 +233,12 @@ def gen_trtllm_gen_fused_moe_sm100_module() -> JitSpec:
233233
],
234234
extra_cuda_cflags=[
235235
"-DTLLM_GEN_EXPORT_INTERFACE",
236+
"-DTLLM_GEN_EXPORT_FLASHINFER",
236237
"-DTLLM_ENABLE_CUDA",
237238
"-DENABLE_BF16",
238239
"-DENABLE_FP8",
239240
"-DENABLE_FP4",
240-
f'-DTLLM_GEN_BMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_BMM}\\"',
241+
f'-DTLLM_GEN_GEMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_BMM}\\"',
241242
]
242243
+ nvcc_flags,
243244
extra_include_paths=[

flashinfer/jit/gemm/core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,7 @@ def gen_trtllm_gen_gemm_module() -> JitSpec:
381381
],
382382
extra_cuda_cflags=[
383383
"-DTLLM_GEN_EXPORT_INTERFACE",
384+
"-DTLLM_GEN_EXPORT_FLASHINFER",
384385
"-DTLLM_ENABLE_CUDA",
385386
f'-DTLLM_GEN_GEMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_GEMM}\\"',
386387
]
@@ -531,6 +532,7 @@ def gen_trtllm_low_latency_gemm_module() -> JitSpec:
531532
],
532533
extra_cuda_cflags=[
533534
"-DTLLM_GEN_EXPORT_INTERFACE",
535+
"-DTLLM_GEN_EXPORT_FLASHINFER",
534536
"-DTLLM_ENABLE_CUDA",
535537
f'-DTLLM_GEN_GEMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_GEMM}\\"',
536538
]

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ enum class RouteImpl {
3131
// Use LDGSTS to do the routing
3232
Ldgsts = 1,
3333
// Use UTMALDG.GATHER4 to do the routing
34-
Tma = 2
34+
Tma = 2,
35+
// Use LDG+STS to do the routing
36+
LdgPlusSts = 3
3537
};
3638

3739
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -48,6 +50,10 @@ inline bool doesRouteImplUseTma(RouteImpl mode) { return (mode == RouteImpl::Tma
4850

4951
////////////////////////////////////////////////////////////////////////////////////////////////////
5052

53+
// Tells whether the given routing implementation is RouteImpl::LdgPlusSts.
inline bool doesRouteImplUseLdgPlusSts(RouteImpl mode) {
  return RouteImpl::LdgPlusSts == mode;
}
54+
55+
////////////////////////////////////////////////////////////////////////////////////////////////////
56+
5157
} // namespace batchedGemm
5258

5359
////////////////////////////////////////////////////////////////////////////////////////////////////

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h

Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -24,29 +24,22 @@
2424
#include "trtllm/gen/CudaKernelLauncher.h"
2525

2626
#ifdef TLLM_GEN_EXPORT_INTERFACE
27+
#ifdef TLLM_GEN_EXPORT_FLASHINFER
2728
#include "flashinferMetaInfo.h"
28-
#endif // TLLM_GEN_EXPORT_INTERFACE
29-
30-
#include "flashinfer/trtllm/common.h"
31-
#ifdef TLLM_GEN_BMM_CUBIN_PATH
32-
static const std::string tllm_gen_bmm_cubin_path = std::string(TLLM_GEN_BMM_CUBIN_PATH);
3329
#else
34-
static_assert(false, "TLLM_GEN_BMM_CUBIN_PATH macro is not defined when compiling");
35-
#endif
36-
37-
namespace flashinfer::trtllm_cubin_loader {
38-
std::string getCubin(const std::string& kernelName, const std::string& sha256);
39-
}
30+
#include "KernelMetaInfo.h"
31+
#endif // TLLM_GEN_EXPORT_FLASHINFER
32+
#endif // TLLM_GEN_EXPORT_INTERFACE
4033

4134
namespace batchedGemm {
4235

4336
namespace batchedGemm {
4437

45-
//////////////////////////////////////////////////////////////////////////////////////////////////
38+
////////////////////////////////////////////////////////////////////////////////////////////////////
4639
//
4740
// BatchedGemmData
4841
//
49-
//////////////////////////////////////////////////////////////////////////////////////////////////
42+
////////////////////////////////////////////////////////////////////////////////////////////////////
5043

5144
struct BatchedGemmData {
5245
struct ProblemDimensions {
@@ -448,11 +441,11 @@ struct BatchedGemmData {
448441
OutputBuffers mOutputBuffers;
449442
};
450443

451-
//////////////////////////////////////////////////////////////////////////////////////////////////
444+
////////////////////////////////////////////////////////////////////////////////////////////////////
452445
//
453446
// BatchedGemmInterface
454447
//
455-
//////////////////////////////////////////////////////////////////////////////////////////////////
448+
////////////////////////////////////////////////////////////////////////////////////////////////////
456449

457450
class BatchedGemmInterface {
458451
public:
@@ -530,18 +523,12 @@ class BatchedGemmInterface {
530523
if (config.mData == nullptr) {
531524
batchedGemmConfig = generateAndCompileKernel(batchedGemmConfig);
532525
}
526+
TLLM_CHECK_ERROR(batchedGemmConfig.mCudaRunner != nullptr, "CudaRunner is not set");
527+
batchedGemmConfig.mCudaRunner->run((void*)&kernelParams, (void*)cudaStream, grid,
528+
/* cluster */ {},
529+
/* instanceId */ batchedGemmConfig.mInstanceIdx);
530+
return 0;
533531
#endif
534-
auto fiModuleLoadData = [&](CUmodule* module) {
535-
const std::string sha256 = config.mHash ? config.mHash : "";
536-
std::string fname_cubin = config.mFunctionName;
537-
if (!fname_cubin.empty()) {
538-
fname_cubin[0] =
539-
static_cast<char>(std::toupper(static_cast<unsigned char>(fname_cubin[0])));
540-
}
541-
fname_cubin = tllm_gen_bmm_cubin_path + "/" + fname_cubin + ".cubin";
542-
std::string cubin = flashinfer::trtllm_cubin_loader::getCubin(fname_cubin, sha256);
543-
cuErrCheck(cuModuleLoadData(module, cubin.c_str()));
544-
};
545532

546533
CUmodule cuModule;
547534
CUfunction cuFunction;
@@ -567,12 +554,12 @@ class BatchedGemmInterface {
567554
if (module != moduleCacheRef.end()) {
568555
cuFunction = std::get<1>(module->second);
569556
} else {
570-
fiModuleLoadData(&cuModule);
557+
gemm::loadCubinData(&cuModule, batchedGemmConfig);
571558
cuModuleGetFunction(&cuFunction, cuModule, batchedGemmConfig.mFunctionName);
572559
moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction)));
573560
}
574561
} else {
575-
fiModuleLoadData(&cuModule);
562+
gemm::loadCubinData(&cuModule, batchedGemmConfig);
576563
cuModuleGetFunction(&cuFunction, cuModule, batchedGemmConfig.mFunctionName);
577564
}
578565

@@ -808,10 +795,10 @@ class BatchedGemmInterface {
808795
int32_t mNumRotations;
809796
};
810797

811-
//////////////////////////////////////////////////////////////////////////////////////////////////
798+
////////////////////////////////////////////////////////////////////////////////////////////////////
812799

813800
} // namespace batchedGemm
814801

815-
//////////////////////////////////////////////////////////////////////////////////////////////////
802+
////////////////////////////////////////////////////////////////////////////////////////////////////
816803

817804
} // namespace batchedGemm

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,10 @@ inline bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool i
260260

261261
if (options.mRouteSfsImpl.has_value() && options.mRouteSfsImpl.value() != options.mRouteImpl) {
262262
TLLM_CHECK_ERROR(
263-
options.mRouteSfsImpl.value() == RouteImpl::Ldgsts && options.mRouteImpl == RouteImpl::Tma,
264-
"RouteSfsImpl must be equal to RouteImpl, or Ldgsts, when RouteImpl is Tma");
263+
(options.mRouteSfsImpl.value() == RouteImpl::Ldgsts ||
264+
options.mRouteSfsImpl.value() == RouteImpl::LdgPlusSts) &&
265+
options.mRouteImpl == RouteImpl::Tma,
266+
"RouteSfsImpl must be equal to RouteImpl, or Ldgsts/LdgPlusSts, when RouteImpl is Tma");
265267
} else if (!options.mRouteSfsImpl.has_value()) {
266268
if (updateOptions) {
267269
options.mRouteSfsImpl = options.mRouteImpl;
@@ -271,6 +273,15 @@ inline bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool i
271273
}
272274
}
273275

276+
TLLM_CHECK_ERROR(options.mRouteImpl != RouteImpl::LdgPlusSts,
277+
"LdgPlusSts does not support routing the tokens");
278+
279+
if (options.mRouteSfsImpl.has_value() && options.mRouteSfsImpl.value() == RouteImpl::LdgPlusSts) {
280+
TLLM_CHECK_ERROR(!batchM, "LdgPlusSts only supports batch N");
281+
TLLM_CHECK_ERROR(options.mTileK <= 512 && options.mTileK >= 128,
282+
"LdgPlusSts only supports 128 <= tileK <= 512");
283+
}
284+
274285
if (batchM) {
275286
if (options.mDtypeA == tg::Dtype::MxE2m1 && options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4) {
276287
TLLM_CHECK_ERROR(doesRouteImplUseNoRoute(options.mRouteImpl),

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,14 @@
3030
#include "trtllm/gen/CudaRunner.h"
3131
#include "trtllm/gen/GenCtx.h"
3232
#else
33+
#ifdef TLLM_GEN_EXPORT_FLASHINFER
34+
#include <string>
35+
namespace flashinfer::trtllm_cubin_loader {
36+
std::string getCubin(const std::string& kernelName, const std::string& sha256);
37+
}
38+
#endif // TLLM_GEN_EXPORT_FLASHINFER
3339
#include <iostream>
40+
namespace batchedGemm {
3441

3542
template <typename T>
3643
void printArgs(T arg) {
@@ -72,8 +79,6 @@ void printArgs(T first, Args... args) {
7279

7380
#endif // TLLM_GEN_EXPORT_INTERFACE
7481

75-
namespace batchedGemm {
76-
7782
namespace trtllm {
7883
namespace gen {
7984
class CudaRunner;
@@ -1471,6 +1476,31 @@ inline bool getKernelDoesScaleC(tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dt
14711476

14721477
////////////////////////////////////////////////////////////////////////////////////////////////////
14731478

1479+
// Loads the cubin that belongs to `config` into `*module` with cuModuleLoadData and returns the
// CUresult of that call.
//
// Two build flavours are selected by the preprocessor:
//  - TLLM_GEN_EXPORT_FLASHINFER defined: the cubin is obtained through
//    flashinfer::trtllm_cubin_loader::getCubin, using the path
//    TLLM_GEN_GEMM_CUBIN_PATH + "/" + <function name with its first character upper-cased> + ".cubin"
//    and config.mHash as the sha256 argument. TLLM_GEN_GEMM_CUBIN_PATH must be defined in this
//    flavour, otherwise the static_assert below stops the build.
//  - otherwise: the image is taken directly from config.mData.
//
// `Config` only needs the members read here: mFunctionName and mHash in the first flavour,
// mData in the second.
//
// NOTE(review): the call sites visible in BatchedGemmInterface.h invoke this function as a plain
// statement and drop the returned CUresult, while the lambda it replaces wrapped cuModuleLoadData
// in cuErrCheck — confirm that ignoring a failed load is intended.
// NOTE(review): std::toupper is declared in <cctype>; only <string> and <iostream> are visibly
// included near the top of this header — confirm <cctype> arrives transitively.
template <typename Config>
1480+
inline CUresult loadCubinData(CUmodule* module, Config const& config) {
1481+
// Trtllm links the cubin into the executable while Flashinfer loads the cubin from storage.
1482+
#ifdef TLLM_GEN_EXPORT_FLASHINFER
1483+
#ifdef TLLM_GEN_GEMM_CUBIN_PATH
1484+
static const std::string tllm_gen_gemm_cubin_path = std::string(TLLM_GEN_GEMM_CUBIN_PATH);
1485+
const std::string sha256 = config.mHash ? config.mHash : "";  // empty string when mHash is null
1486+
std::string fileName = config.mFunctionName;
1487+
if (!fileName.empty()) {
1488+
// The cubin file name is the function name with its first character upper-cased.
fileName[0] = static_cast<char>(std::toupper(static_cast<unsigned char>(fileName[0])));
1489+
}
1490+
// getCubin is declared to return std::string by value; binding the result to a const reference
// keeps that temporary alive for the rest of this function, so data.c_str() stays valid below.
const std::string& data = flashinfer::trtllm_cubin_loader::getCubin(
1491+
tllm_gen_gemm_cubin_path + "/" + fileName + ".cubin", sha256);
1492+
CUresult result = cuModuleLoadData(module, data.c_str());
1493+
#else
1494+
static_assert(false, "TLLM_GEN_GEMM_CUBIN_PATH macro is not defined when compiling");
1495+
#endif // TLLM_GEN_GEMM_CUBIN_PATH
1496+
#else
1497+
CUresult result = cuModuleLoadData(module, config.mData);
1498+
#endif // TLLM_GEN_EXPORT_FLASHINFER
1499+
return result;  // status of cuModuleLoadData from whichever branch was compiled in
1500+
}
1501+
1502+
////////////////////////////////////////////////////////////////////////////////////////////////////
1503+
14741504
} // namespace gemm
14751505

14761506
#ifdef TLLM_GEN_EXPORT_INTERFACE

0 commit comments

Comments
 (0)