
Commit a1235ee

[feat] Adds optional module cache for TRT-LLM Gen Gemm interfaces (#5743)
Signed-off-by: David Clark <[email protected]>
Co-authored-by: Nikita Korobov <[email protected]>
1 parent 1191555 commit a1235ee
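
The commit threads an optional module cache through the TRT-LLM Gen GEMM, batched GEMM and gated-activation GEMM interfaces so that a cubin is loaded once per CUDA context and kernel instead of on every launch. A minimal sketch of the caching pattern used in the diffs below, written as a free-standing helper for illustration only (loadCachedFunction is not part of the TensorRT-LLM API, and error checking of the CUDA driver calls is elided, as in the diff):

// Sketch of the caching pattern introduced below; the helper name and free-standing form
// are illustrative only and not part of the TensorRT-LLM Gen interfaces.
#include <cuda.h>

#include <string>
#include <tuple>
#include <unordered_map>

using ModuleCache = std::unordered_map<std::string, std::tuple<CUmodule, CUfunction>>;

// Returns the kernel entry point for (current context, functionName), loading the cubin
// image at most once per context and keeping the CUmodule resident in the cache.
inline CUfunction loadCachedFunction(ModuleCache& cache, void const* cubinData, char const* functionName)
{
    // Modules are tied to a CUDA context, so the context id is part of the cache key.
    CUcontext ctx;
    unsigned long long ctxId;
    cuCtxGetCurrent(&ctx);
    cuCtxGetId(ctx, &ctxId);

    // The raw bytes of ctxId form a fixed-width prefix, so prefix + function name cannot collide.
    std::string key(reinterpret_cast<char const*>(&ctxId), sizeof(ctxId));
    key += functionName;

    auto it = cache.find(key);
    if (it != cache.end())
    {
        return std::get<1>(it->second);
    }

    // Cache miss: load the module, resolve the function, and keep both for later launches.
    CUmodule cuModule;
    CUfunction cuFunction;
    cuModuleLoadData(&cuModule, cubinData);
    cuModuleGetFunction(&cuFunction, cuModule, functionName);
    cache.emplace(key, std::make_tuple(cuModule, cuFunction));
    return cuFunction;
}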

File tree

6 files changed (+151, -16 lines)

6 files changed

+151
-16
lines changed

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.cpp

Lines changed: 4 additions & 1 deletion
@@ -30,6 +30,8 @@ using namespace batchedGemm::batchedGemm;
 using namespace batchedGemm::gemm;
 using namespace batchedGemm::trtllm::gen;
 
+static BatchedGemmInterface::ModuleCache globalTrtllmGenBatchedGemmModuleCache;
+
 std::vector<int64_t> prioritizePredefinedConfigs(int m, int n, int k, std::vector<int64_t> const& sortedIndices,
     batchedGemm::batchedGemm::BatchedGemmConfig const* configs)
 {
@@ -295,7 +297,8 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
     // FIXME once we start using all-reduce in the epilogue of the bmm this can be moved elsewhere
     bmm.runInitBeforeWorldSync(config, gemmData, static_cast<void*>(stream));
 
-    auto const err = bmm.run(config, workspace, gemmData, static_cast<void*>(stream), multiProcessorCount);
+    auto const err = bmm.run(config, workspace, gemmData, static_cast<void*>(stream), multiProcessorCount,
+        globalTrtllmGenBatchedGemmModuleCache);
 
     TLLM_CHECK_WITH_INFO(err == 0, "Error occurred when running GEMM!");
 }

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/BatchedGemmInterface.h

Lines changed: 47 additions & 5 deletions
@@ -17,6 +17,7 @@
 #pragma once
 
 #include <numeric>
+#include <optional>
 
 #include "BatchedGemmOptions.h"
 #include "KernelParams.h"
@@ -392,12 +393,14 @@ struct BatchedGemmData
 class BatchedGemmInterface
 {
 public:
+    using ModuleCache = std::unordered_map<std::string, std::tuple<CUmodule, CUfunction>>;
+
     BatchedGemmInterface() {}
 
     // Launch the cubin from the provided config. It calls all necessary memsets for internal buffers.
     // Provided config must be validated with isValidConfig before the call.
     int32_t run(BatchedGemmConfig const& config, void* workspace, BatchedGemmData const& options, void* cudaStream,
-        int32_t multiProcessorCount);
+        int32_t multiProcessorCount, std::optional<std::reference_wrapper<ModuleCache>> moduleCache = std::nullopt);
 
     // Initializes the buffers before the world sync. Must be called before run.
     int32_t runInitBeforeWorldSync(
@@ -579,9 +582,9 @@ std::vector<size_t> BatchedGemmInterface::getWorkspaceSizesInBytes(
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-
 int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspace,
-    BatchedGemmData const& batchedGemmData, void* cudaStream, int32_t /* multiProcessorCount */)
+    BatchedGemmData const& batchedGemmData, void* cudaStream, int32_t /* multiProcessorCount */,
+    std::optional<std::reference_wrapper<ModuleCache>> moduleCache)
 {
     // Get options from config and data.
     auto options = getOptionsFromConfigAndData(config, batchedGemmData);
@@ -652,8 +655,42 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspa
 #ifdef TLLM_GEN_EXPORT_INTERFACE
     CUmodule cuModule;
     CUfunction cuFunction;
-    cuModuleLoadData(&cuModule, config.mData);
-    cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
+    if (moduleCache.has_value())
+    {
+        ModuleCache& moduleCacheRef = moduleCache.value().get();
+
+        // Modules are associated with a specific context so include the ctxId in the key
+        CUcontext ctx;
+        unsigned long long ctxId;
+        cuCtxGetCurrent(&ctx);
+        cuCtxGetId(ctx, &ctxId);
+
+        // Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a string in decimal
+        // representation.
+        std::string const ctxName
+            = std::string(reinterpret_cast<char*>(&ctxId), sizeof(unsigned long long) / sizeof(char));
+        std::string const funcName = std::string(config.mFunctionName);
+        // As the ctxName is a fixed number of bytes, the two strings can just be appended without risk of a collision
+        auto const moduleKey = ctxName + funcName;
+        auto module = moduleCacheRef.find(moduleKey);
+
+        // Check if module exists in cache. Otherwise, load it
+        if (module != moduleCacheRef.end())
+        {
+            cuFunction = std::get<1>(module->second);
+        }
+        else
+        {
+            cuModuleLoadData(&cuModule, config.mData);
+            cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
+            moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction)));
+        }
+    }
+    else
+    {
+        cuModuleLoadData(&cuModule, config.mData);
+        cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
+    }
 
     // Prepare the grid/block.
     dim3 block3{static_cast<uint32_t>(config.mNumThreadsPerCTA), static_cast<uint32_t>(1), static_cast<uint32_t>(1)};
@@ -673,6 +710,11 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspa
     {
         return -1;
     }
+    // If a module cache has not been given, unload the module to avoid leaking
+    if (!moduleCache.has_value())
+    {
+        cuModuleUnload(cuModule);
+    }
 #else
     config.mCudaRunner->run((void*) &kernelParams, (void*) cudaStream, grid);
 #endif
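
The moduleCache parameter is optional and defaults to std::nullopt, so existing callers are unchanged; the KernelRunner changes in this commit opt in by passing a file-scope static cache. A hedged sketch of how a caller might wrap the new overload (the wrapper, its name and its include path are illustrative, not part of this commit):

// Hypothetical wrapper around the new run() overload; the include path and the way
// config/workspace/data are built are assumed to follow KernelRunner.cpp in this commit.
#include <cstdint>

#include "trtllmGen_bmm_export/BatchedGemmInterface.h"

using namespace batchedGemm::batchedGemm;

int32_t runWithProcessWideCache(BatchedGemmInterface& bmm, BatchedGemmConfig const& config, void* workspace,
    BatchedGemmData const& data, void* stream, int32_t multiProcessorCount)
{
    // One cache per process: each cubin is loaded once per (context, kernel) and stays resident.
    static BatchedGemmInterface::ModuleCache moduleCache;
    return bmm.run(config, workspace, data, stream, multiProcessorCount, moduleCache);
}

// Omitting the trailing argument keeps the previous behaviour: run() loads the module
// before the launch and unloads it immediately afterwards.
//     auto const err = bmm.run(config, workspace, data, stream, multiProcessorCount);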

cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp

Lines changed: 4 additions & 1 deletion
@@ -30,6 +30,8 @@ namespace kernels
 namespace tg = gemm::trtllm::gen;
 using namespace gemm::gemm;
 
+static GemmInterface::ModuleCache globalTrtllmGenGemmModuleCache;
+
 TrtllmGenGemmRunner::TrtllmGenGemmRunner(TrtllmGenGemmRunnerOptions const& options_)
     : mOptions(options_)
 {
@@ -111,7 +113,8 @@ void TrtllmGenGemmRunner::run(int32_t m, int32_t n, int32_t k, void const* a, fl
     // FIXME once we start using all-reduce in the epilogue of the gemm this can be moved elsewhere
     gemm.runInitBeforeWorldSync(config, gemmData, static_cast<void*>(stream));
 
-    auto const err = gemm.run(config, workspace, gemmData, static_cast<void*>(stream), multiProcessorCount);
+    auto const err = gemm.run(
+        config, workspace, gemmData, static_cast<void*>(stream), multiProcessorCount, globalTrtllmGenGemmModuleCache);
 
     TLLM_CHECK_WITH_INFO(err == 0, "Error occurred when running GEMM!");
 }

cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/trtllmGen_gemm_export/GemmInterface.h

Lines changed: 46 additions & 4 deletions
@@ -222,12 +222,15 @@ struct GemmData
 class GemmInterface
 {
 public:
+    using ModuleCache = std::unordered_map<std::string, std::tuple<CUmodule, CUfunction>>;
+
     GemmInterface() {}
 
     // Launch the cubin from the provided config. It calls all necessary memsets for internal buffers.
     // Provided config must be validated with isValidConfig before the call.
     int32_t run(GemmConfig const& config, void* workspace, GemmData const& options, void* cudaStream,
-        int32_t multiProcessorCount) const;
+        int32_t multiProcessorCount,
+        std::optional<std::reference_wrapper<ModuleCache>> moduleCache = std::nullopt) const;
 
     // Initializes the buffers before the world sync. Must be called before run.
     int32_t runInitBeforeWorldSync(GemmConfig const& config, GemmData const& data, void* cudaStream) const;
@@ -384,7 +387,7 @@ bool GemmInterface::isValidConfig(GemmConfig const& config, GemmData const& data
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData const& data, void* cudaStream,
-    int32_t multiProcessorCount) const
+    int32_t multiProcessorCount, std::optional<std::reference_wrapper<ModuleCache>> moduleCache) const
 {
     // Get options from config and data.
     auto options = getOptionsFromConfigAndData(config, data);
@@ -439,8 +442,42 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c
 #ifdef TLLM_GEN_EXPORT_INTERFACE
     CUmodule cuModule;
     CUfunction cuFunction;
-    cuModuleLoadData(&cuModule, config.mData);
-    cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
+    if (moduleCache.has_value())
+    {
+        ModuleCache& moduleCacheRef = moduleCache.value().get();
+
+        // Modules are associated with a specific context so include the ctxId in the key
+        CUcontext ctx;
+        unsigned long long ctxId;
+        cuCtxGetCurrent(&ctx);
+        cuCtxGetId(ctx, &ctxId);
+
+        // Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a string in decimal
+        // representation.
+        std::string const ctxName
+            = std::string(reinterpret_cast<char*>(&ctxId), sizeof(unsigned long long) / sizeof(char));
+        std::string const funcName = std::string(config.mFunctionName);
+        // As the ctxName is a fixed number of bytes, the two strings can just be appended without risk of a collision
+        auto const moduleKey = ctxName + funcName;
+        auto module = moduleCacheRef.find(moduleKey);
+
+        // Check if module exists in cache. Otherwise, load it
+        if (module != moduleCacheRef.end())
+        {
+            cuFunction = std::get<1>(module->second);
+        }
+        else
+        {
+            cuModuleLoadData(&cuModule, config.mData);
+            cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
+            moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction)));
+        }
+    }
+    else
+    {
+        cuModuleLoadData(&cuModule, config.mData);
+        cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
+    }
 
     // Prepare the grid/block.
     dim3 block3{static_cast<uint32_t>(config.mNumThreadsPerCTA), static_cast<uint32_t>(1), static_cast<uint32_t>(1)};
@@ -460,6 +497,11 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c
     {
         return -1;
     }
+    // If a module cache has not been given, unload the module to avoid leaking
+    if (!moduleCache.has_value())
+    {
+        cuModuleUnload(cuModule);
+    }
 #else
     config.mCudaRunner->run((void*) &kernelParams, (void*) cudaStream, grid);
 #endif

cpp/tensorrt_llm/kernels/trtllmGenKernels/gemmGatedAct/KernelRunner.cpp

Lines changed: 4 additions & 1 deletion
@@ -27,6 +27,8 @@ namespace tensorrt_llm
 namespace kernels
 {
 
+static gemmGatedAct::GemmGatedActInterface::ModuleCache globalTrtllmGenGemmGatedActModuleCache;
+
 TrtllmGenGemmGatedActRunner::TrtllmGenGemmGatedActRunner(TrtllmGenGemmGatedActRunnerOptions const& options_)
     : mOptions(options_)
 {
@@ -104,7 +106,8 @@ void TrtllmGenGemmGatedActRunner::run(int32_t m, int32_t n, int32_t k, void cons
     // FIXME once we start using all-reduce in the epilogue of the gemm this can be moved elsewhere
     gemm.runInitBeforeWorldSync(config, gemmData, static_cast<void*>(stream));
 
-    auto const err = gemm.run(config, workspace, gemmData, static_cast<void*>(stream), multiProcessorCount);
+    auto const err = gemm.run(config, workspace, gemmData, static_cast<void*>(stream), multiProcessorCount,
+        globalTrtllmGenGemmGatedActModuleCache);
 
     TLLM_CHECK_WITH_INFO(err == 0, "Error occurred when running GEMM!");
 }

cpp/tensorrt_llm/kernels/trtllmGenKernels/gemmGatedAct/trtllmGen_gatedAct_export/GemmGatedActInterface.h

Lines changed: 46 additions & 4 deletions
@@ -183,12 +183,15 @@ struct GemmGatedActData
 class GemmGatedActInterface
 {
 public:
+    using ModuleCache = std::unordered_map<std::string, std::tuple<CUmodule, CUfunction>>;
+
     GemmGatedActInterface() {}
 
     // Launch the cubin from the provided config. It calls all necessary memsets for internal buffers.
     // Provided config must be validated with isValidConfig before the call.
     int32_t run(GemmGatedActConfig const& config, void* workspace, GemmGatedActData const& data, void* cudaStream,
-        int32_t multiProcessorCount) const;
+        int32_t multiProcessorCount,
+        std::optional<std::reference_wrapper<ModuleCache>> moduleCache = std::nullopt) const;
 
     // Initializes the buffers before the world sync. Must be called before run.
     int32_t runInitBeforeWorldSync(
@@ -340,7 +343,7 @@ bool GemmGatedActInterface::isValidConfig(GemmGatedActConfig const& config, Gemm
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 int32_t GemmGatedActInterface::run(GemmGatedActConfig const& config, void* workspace, GemmGatedActData const& data,
-    void* cudaStream, int32_t multiProcessorCount) const
+    void* cudaStream, int32_t multiProcessorCount, std::optional<std::reference_wrapper<ModuleCache>> moduleCache) const
 {
     // Get options from config and data.
     auto options = getOptionsFromConfigAndData(config, data);
@@ -392,8 +395,42 @@ int32_t GemmGatedActInterface::run(GemmGatedActConfig const& config, void* works
 #ifdef TLLM_GEN_EXPORT_INTERFACE
     CUmodule cuModule;
     CUfunction cuFunction;
-    cuModuleLoadData(&cuModule, config.mData);
-    cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
+    if (moduleCache.has_value())
+    {
+        ModuleCache& moduleCacheRef = moduleCache.value().get();
+
+        // Modules are associated with a specific context so include the ctxId in the key
+        CUcontext ctx;
+        unsigned long long ctxId;
+        cuCtxGetCurrent(&ctx);
+        cuCtxGetId(ctx, &ctxId);
+
+        // Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a string in decimal
+        // representation.
+        std::string const ctxName
+            = std::string(reinterpret_cast<char*>(&ctxId), sizeof(unsigned long long) / sizeof(char));
+        std::string const funcName = std::string(config.mFunctionName);
+        // As the ctxName is a fixed number of bytes, the two strings can just be appended without risk of a collision
+        auto const moduleKey = ctxName + funcName;
+        auto module = moduleCacheRef.find(moduleKey);
+
+        // Check if module exists in cache. Otherwise, load it
+        if (module != moduleCacheRef.end())
+        {
+            cuFunction = std::get<1>(module->second);
+        }
+        else
+        {
+            cuModuleLoadData(&cuModule, config.mData);
+            cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
+            moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction)));
+        }
+    }
+    else
+    {
+        cuModuleLoadData(&cuModule, config.mData);
+        cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
+    }
 
     // Prepare the grid/block.
     dim3 block3{static_cast<uint32_t>(config.mNumThreadsPerCTA), static_cast<uint32_t>(1), static_cast<uint32_t>(1)};
@@ -413,6 +450,11 @@ int32_t GemmGatedActInterface::run(GemmGatedActConfig const& config, void* works
     {
         return -1;
    }
+    // If a module cache has not been given, unload the module to avoid leaking
+    if (!moduleCache.has_value())
+    {
+        cuModuleUnload(cuModule);
+    }
 #else
     config.mCudaRunner->run((void*) &kernelParams, (void*) cudaStream, grid);
 #endif
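
Because each cache entry owns a loaded CUmodule, cached modules stay resident for the lifetime of the cache; the commit uses file-scope static caches, so nothing is released before process exit. If an owner of a ModuleCache ever needed to unload its modules explicitly, a cleanup helper could look like the following sketch (hypothetical, not part of this change; it assumes the GemmInterface header from this diff is included and that the contexts the modules were loaded in still exist):

// Hypothetical cleanup helper, not part of this commit: unloads every cached module and
// clears the cache. Assumes the CUDA contexts the modules were loaded in are still valid.
inline void clearModuleCache(GemmInterface::ModuleCache& cache)
{
    for (auto& entry : cache)
    {
        cuModuleUnload(std::get<0>(entry.second));
    }
    cache.clear();
}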
