Skip to content
Open
153 changes: 98 additions & 55 deletions csrc/trtllm_fused_moe_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ using tensorrt_llm::kernels::trtllmgen_moe::Routing::RoutingMethodType;
using tvm::ffi::Array;
using tvm::ffi::Optional;

enum class RoutingInputMode {
FromLogits, // Mode 1: Compute routing from logits
PackedPrecomputed, // Mode 2: Pre-computed with packed (score << 16 | id) format
UnpackedPrecomputed // Mode 3: Pre-computed with separate topk_ids and topk_weights
};

enum class Fp8QuantizationType {
NoneFp8,
DeepSeekFp8,
Expand Down Expand Up @@ -371,25 +377,33 @@ class FusedMoeLauncher {
check_routing();
prepare_routing();

cudaStream_t routing_stream = get_stream(hidden_states.device());

// Execute routing
tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim);
cudaStream_t routing_stream = get_stream(hidden_states.device());

routing_runner.run(
args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k,
args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts,
args->routed_scaling_factor, static_cast<int*>(expert_indexes.data_ptr()),
static_cast<int*>(expert_count_histogram.data_ptr()),
static_cast<int*>(total_num_padded_tokens.data_ptr()),
static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()),
nullptr /*permuted_idx_to_expanded_idx.data_ptr()*/,
static_cast<int*>(permuted_idx_to_token_idx.data_ptr()), expert_weights.data_ptr(),
static_cast<int*>(num_tokens_per_expert.data_ptr()),
static_cast<int*>(cta_idx_xy_to_batch_idx.data_ptr()),
static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()),
static_cast<int*>(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype,
use_routing_scales_on_input, use_deep_seek_fp8,
static_cast<RoutingMethodType>(routing_method_type), routing_stream);
// This base class only supports Mode 1 (FromLogits) - compute routing from logits
constexpr RoutingInputMode routing_input_mode = RoutingInputMode::FromLogits;

// Mode 1: expertIds is nullptr, expertWeights is OUTPUT buffer for computed weights
int32_t* expert_ids_param = nullptr;
void* expert_weights_param = expert_weights.data_ptr();

routing_runner.run(args->routing_logits, args->routing_bias, args->num_tokens,
args->num_experts, args->top_k, args->n_group, args->topk_group,
args->local_expert_offset, args->local_num_experts,
args->routed_scaling_factor, static_cast<int*>(expert_indexes.data_ptr()),
static_cast<int*>(expert_count_histogram.data_ptr()),
static_cast<int*>(total_num_padded_tokens.data_ptr()),
static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()),
nullptr /*permuted_idx_to_expanded_idx.data_ptr()*/,
static_cast<int*>(permuted_idx_to_token_idx.data_ptr()), expert_ids_param,
expert_weights_param, static_cast<int*>(num_tokens_per_expert.data_ptr()),
static_cast<int*>(cta_idx_xy_to_batch_idx.data_ptr()),
static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()),
static_cast<int*>(num_non_exiting_ctas.data_ptr()), args->mDtypeElt,
mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8,
static_cast<RoutingMethodType>(routing_method_type), routing_stream);

check_moe();
prepare_moe(moe_tactic);
Expand Down Expand Up @@ -1052,6 +1066,7 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
bool use_precomputed = expert_indices.ndim() == 2 && expert_indices.size(0) > 0;
// When using pre-computed routing, pass nullptr as routing_logits to tell the
// routing runner to use the pre-computed expert indices from workspace.routing_expert_indexes
// FP8 only supports Mode 1 (FromLogits) and Mode 2 (PackedPrecomputed), so expertIds is nullptr
routing_runner.run(
use_precomputed ? nullptr : args->routing_logits, args->routing_bias, args->num_tokens,
args->num_experts, args->top_k, args->n_group, args->topk_group, args->local_expert_offset,
Expand All @@ -1060,8 +1075,9 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
static_cast<int*>(total_num_padded_tokens.data_ptr()),
static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()),
nullptr /*permuted_idx_to_expanded_idx.data_ptr()*/,
static_cast<int*>(permuted_idx_to_token_idx.data_ptr()), workspace.expert_weights,
static_cast<int*>(num_tokens_per_expert.data_ptr()),
static_cast<int*>(permuted_idx_to_token_idx.data_ptr()),
nullptr, // expertIds - FP8 doesn't support UnpackedPrecomputed mode
workspace.expert_weights, static_cast<int*>(num_tokens_per_expert.data_ptr()),
static_cast<int*>(cta_idx_xy_to_batch_idx.data_ptr()),
static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()),
static_cast<int*>(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype,
Expand Down Expand Up @@ -1277,19 +1293,21 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
}

FP4BlockScaleLauncher(
Optional<TensorView> const& routing_logits, Optional<TensorView> const& routing_bias,
TensorView const& hidden_states, Optional<TensorView> const& hidden_states_scale,
TensorView const& gemm1_weights, TensorView const& gemm1_weights_scale,
Optional<TensorView> const& gemm1_bias, Optional<TensorView> const& gemm1_alpha,
Optional<TensorView> const& gemm1_beta, Optional<TensorView> const& gemm1_clamp_limit,
TensorView const& gemm2_weights, TensorView const& gemm2_weights_scale,
Optional<TensorView> const& gemm2_bias, Optional<TensorView> const& output1_scales_scalar,
RoutingInputMode routing_input_mode, Optional<TensorView> const& routing_logits,
Optional<TensorView> const& routing_bias, TensorView const& hidden_states,
Optional<TensorView> const& hidden_states_scale, TensorView const& gemm1_weights,
TensorView const& gemm1_weights_scale, Optional<TensorView> const& gemm1_bias,
Optional<TensorView> const& gemm1_alpha, Optional<TensorView> const& gemm1_beta,
Optional<TensorView> const& gemm1_clamp_limit, TensorView const& gemm2_weights,
TensorView const& gemm2_weights_scale, Optional<TensorView> const& gemm2_bias,
Optional<TensorView> const& output1_scales_scalar,
Optional<TensorView> const& output1_scales_gate_scalar,
Optional<TensorView> const& output2_scales_scalar, TensorView const& expert_indices,
TensorView const& expert_weights)
Optional<TensorView> const& output2_scales_scalar, TensorView const& topk_ids,
TensorView const& topk_weights)
: FusedMoeLauncher(routing_logits, routing_bias, hidden_states, gemm1_weights,
output1_scales_scalar, output1_scales_gate_scalar, gemm2_weights,
output2_scales_scalar),
routing_input_mode_(routing_input_mode),
hidden_states_scale(hidden_states_scale),
gemm1_weights_scale(gemm1_weights_scale),
gemm1_bias(gemm1_bias),
Expand All @@ -1298,8 +1316,8 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
gemm1_clamp_limit(gemm1_clamp_limit),
gemm2_weights_scale(gemm2_weights_scale),
gemm2_bias(gemm2_bias),
expert_indices(expert_indices),
expert_weights(expert_weights) {}
topk_ids(topk_ids),
topk_weights(topk_weights) {}

void init(std::unique_ptr<tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs>&& args,
int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight,
Expand Down Expand Up @@ -1360,9 +1378,8 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
workspace.total_num_padded_tokens = static_cast<int*>(total_num_padded_tokens.data_ptr());
workspace.total_max_padded_tokens = max_num_padded_tokens;
workspace.ProjUpTileN = tile_tokens_dim;
workspace.routing_expert_indexes =
static_cast<int*>(const_cast<void*>(expert_indices.data_ptr()));
workspace.expert_weights = const_cast<void*>(expert_weights.data_ptr());
workspace.routing_expert_indexes = static_cast<int*>(const_cast<void*>(topk_ids.data_ptr()));
workspace.expert_weights = const_cast<void*>(topk_weights.data_ptr());
workspace.permuted_idx_size = static_cast<int*>(total_num_padded_tokens.data_ptr());
workspace.expanded_idx_to_permuted_idx =
static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr());
Expand Down Expand Up @@ -1490,6 +1507,7 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
}

private:
RoutingInputMode routing_input_mode_;
Optional<TensorView> hidden_states_scale;
TensorView gemm1_weights_scale;
Optional<TensorView> gemm1_bias;
Expand All @@ -1501,8 +1519,8 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
int32_t max_num_padded_tokens_gemm1{};
int32_t max_num_padded_tokens_gemm2{};
Optional<Tensor> gemm1_output_scale;
TensorView expert_indices;
TensorView expert_weights;
TensorView topk_ids; // [num_tokens, top_k] - pre-computed or output top-k expert indices
TensorView topk_weights; // [num_tokens, top_k] - pre-computed or output top-k routing weights

public:
Array<Tensor> run(int64_t moe_tactic, bool enable_pdl = true,
Expand All @@ -1515,21 +1533,45 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim);
cudaStream_t routing_stream = get_stream(hidden_states.device());

routing_runner.run(
args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k,
args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts,
args->routed_scaling_factor, static_cast<int*>(expert_indices.data_ptr()),
static_cast<int*>(expert_count_histogram.data_ptr()),
static_cast<int*>(total_num_padded_tokens.data_ptr()),
static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()),
nullptr /*permuted_idx_to_expanded_idx.data_ptr()*/,
static_cast<int*>(permuted_idx_to_token_idx.data_ptr()), expert_weights.data_ptr(),
static_cast<int*>(num_tokens_per_expert.data_ptr()),
static_cast<int*>(cta_idx_xy_to_batch_idx.data_ptr()),
static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()),
static_cast<int*>(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype,
use_routing_scales_on_input, use_deep_seek_fp8,
static_cast<RoutingMethodType>(routing_method_type), routing_stream);
// Set routing kernel parameters based on mode (see RoutingInputMode enum for documentation)
int32_t* expert_ids_param = nullptr; // INPUT: pre-computed expert IDs (Mode 3 only)
void* expert_weights_param = nullptr; // INPUT or OUTPUT depending on mode

switch (routing_input_mode_) {
case RoutingInputMode::FromLogits:
// Mode 1: Kernel computes routing, writes weights to expert_weights_param (OUTPUT)
expert_ids_param = nullptr;
expert_weights_param = topk_weights.data_ptr();
break;

case RoutingInputMode::PackedPrecomputed:
// Mode 2: Kernel unpacks from topk_ids, writes weights to expert_weights_param (OUTPUT)
expert_ids_param = nullptr;
expert_weights_param = topk_weights.data_ptr();
break;

case RoutingInputMode::UnpackedPrecomputed:
// Mode 3: Both are INPUTS, kernel uses them directly
expert_ids_param = static_cast<int32_t*>(topk_ids.data_ptr());
expert_weights_param = topk_weights.data_ptr();
break;
}
Comment on lines +1540 to +1558
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟑 Minor

Add a default case to handle invalid routing modes.

The switch statement lacks a default case. If routing_input_mode_ contains an invalid value (e.g., due to an unchecked cast from the int64_t parameter in trtllm_fp4_block_scale_moe), this results in undefined behavior with uninitialized expert_ids_param and expert_weights_param.

πŸ›‘οΈ Proposed fix
       case RoutingInputMode::UnpackedPrecomputed:
         // Mode 3: Both are INPUTS, kernel uses them directly
         expert_ids_param = static_cast<int32_t*>(topk_ids.data_ptr());
         expert_weights_param = topk_weights.data_ptr();
         break;
+
+      default:
+        TVM_FFI_LOG_AND_THROW(ValueError)
+            << "Invalid routing_input_mode: " << static_cast<int>(routing_input_mode_);
     }
πŸ“ Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
switch (routing_input_mode_) {
case RoutingInputMode::FromLogits:
// Mode 1: Kernel computes routing, writes weights to expert_weights_param (OUTPUT)
expert_ids_param = nullptr;
expert_weights_param = topk_weights.data_ptr();
break;
case RoutingInputMode::PackedPrecomputed:
// Mode 2: Kernel unpacks from topk_ids, writes weights to expert_weights_param (OUTPUT)
expert_ids_param = nullptr;
expert_weights_param = topk_weights.data_ptr();
break;
case RoutingInputMode::UnpackedPrecomputed:
// Mode 3: Both are INPUTS, kernel uses them directly
expert_ids_param = static_cast<int32_t*>(topk_ids.data_ptr());
expert_weights_param = topk_weights.data_ptr();
break;
}
switch (routing_input_mode_) {
case RoutingInputMode::FromLogits:
// Mode 1: Kernel computes routing, writes weights to expert_weights_param (OUTPUT)
expert_ids_param = nullptr;
expert_weights_param = topk_weights.data_ptr();
break;
case RoutingInputMode::PackedPrecomputed:
// Mode 2: Kernel unpacks from topk_ids, writes weights to expert_weights_param (OUTPUT)
expert_ids_param = nullptr;
expert_weights_param = topk_weights.data_ptr();
break;
case RoutingInputMode::UnpackedPrecomputed:
// Mode 3: Both are INPUTS, kernel uses them directly
expert_ids_param = static_cast<int32_t*>(topk_ids.data_ptr());
expert_weights_param = topk_weights.data_ptr();
break;
default:
TVM_FFI_LOG_AND_THROW(ValueError)
<< "Invalid routing_input_mode: " << static_cast<int>(routing_input_mode_);
}
πŸ€– Prompt for AI Agents
In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 1445-1463, the
switch over routing_input_mode_ in the trtllm_fp4_block_scale_moe kernel
launcher can leave expert_ids_param and expert_weights_param uninitialized for
invalid enum values. Add a default case to the switch that sets expert_ids_param
= nullptr and expert_weights_param = nullptr, emits an error (or uses an
assertion) indicating an invalid RoutingInputMode, and returns or aborts early
from trtllm_fp4_block_scale_moe to avoid undefined behavior. Reference the
switch on routing_input_mode_, and the variables expert_ids_param and
expert_weights_param, when making the change.


routing_runner.run(args->routing_logits, args->routing_bias, args->num_tokens,
args->num_experts, args->top_k, args->n_group, args->topk_group,
args->local_expert_offset, args->local_num_experts,
args->routed_scaling_factor, static_cast<int*>(topk_ids.data_ptr()),
static_cast<int*>(expert_count_histogram.data_ptr()),
static_cast<int*>(total_num_padded_tokens.data_ptr()),
static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()),
nullptr /*permuted_idx_to_expanded_idx*/,
static_cast<int*>(permuted_idx_to_token_idx.data_ptr()), expert_ids_param,
expert_weights_param, static_cast<int*>(num_tokens_per_expert.data_ptr()),
static_cast<int*>(cta_idx_xy_to_batch_idx.data_ptr()),
static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()),
static_cast<int*>(num_non_exiting_ctas.data_ptr()), args->mDtypeElt,
mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8,
static_cast<RoutingMethodType>(routing_method_type), routing_stream);

check_moe();
prepare_moe(moe_tactic);
Expand Down Expand Up @@ -1865,8 +1907,8 @@ Array<Tensor> trtllm_fp8_block_scale_moe(
}

Array<Tensor> trtllm_fp4_block_scale_moe(
Optional<TensorView> routing_logits, TensorView expert_indices, TensorView expert_weights,
Optional<TensorView> routing_bias, TensorView hidden_states,
int64_t routing_input_mode, Optional<TensorView> routing_logits, TensorView topk_ids,
TensorView topk_weights, Optional<TensorView> routing_bias, TensorView hidden_states,
Optional<TensorView> hidden_states_scale, TensorView gemm1_weights,
TensorView gemm1_weights_scale, Optional<TensorView> gemm1_bias,
Optional<TensorView> gemm1_alpha, Optional<TensorView> gemm1_beta,
Expand Down Expand Up @@ -1979,10 +2021,11 @@ Array<Tensor> trtllm_fp4_block_scale_moe(

// Create and initialize launcher for this tile size
auto launcher = std::make_unique<FP4BlockScaleLauncher>(
routing_logits, routing_bias, hidden_states, hidden_states_scale, gemm1_weights,
gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights,
gemm2_weights_scale, gemm2_bias, output1_scales_scalar, output1_scales_gate_scalar,
output2_scales_scalar, expert_indices, expert_weights);
static_cast<RoutingInputMode>(routing_input_mode), routing_logits, routing_bias,
hidden_states, hidden_states_scale, gemm1_weights, gemm1_weights_scale, gemm1_bias,
gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights, gemm2_weights_scale, gemm2_bias,
output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar, topk_ids,
topk_weights);
launcher->init(std::move(args), curr_tile_N, routing_method_type, /*use_shuffled_weight=*/true,
/*weight_layout=*/0, static_cast<ActivationType>(act_type), mDtypeAct,
mDtypeWeights);
Expand Down
21 changes: 14 additions & 7 deletions csrc/trtllm_fused_moe_runner.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
int32_t localNumExperts, float routedScalingFactor, int32_t* routingExpertIndexes,
int32_t* expertCountHistogram, int32_t* permutedIdxSize,
int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx,
int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* numTokensPerExpert,
int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit,
int32_t* numNonExitingCtas, btg::Dtype dtypeElt, btg::Dtype dtypeBias,
bool useRoutingScalesOnInput, bool useDeepSeekFp8,
int32_t* permutedIdxToTokenIdx, int32_t* expertIds, void* expertWeights,
int32_t* numTokensPerExpert, int32_t* ctaIdxXyToBatchIdx,
int32_t* ctaIdxXyToMnLimit, int32_t* numNonExitingCtas, btg::Dtype dtypeElt,
btg::Dtype dtypeBias, bool useRoutingScalesOnInput, bool useDeepSeekFp8,
RoutingMethodType routingMethodType, cudaStream_t stream) {
if (routingMethodType == RoutingMethodType::DeepSeekV3) {
FLASHINFER_CHECK(topK <= 22, "For DeepSeek routing method, must have topK <= 22");
Expand All @@ -84,7 +84,10 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3

// input:
routingData.mPtrRoutingBias = routingBias;
routingData.mPtrScores = reinterpret_cast<float*>(routingLogits);
// Pre-computed routing support: when expertIds is provided, use it directly
routingData.mPtrScores =
expertIds == nullptr ? reinterpret_cast<float*>(routingLogits) : nullptr;
routingData.mPtrTopKIds = expertIds;
routingData.mNumTokens = numTokens;
routingData.mNumExperts = numExperts;
routingData.mNumExpertGroups = nGroup;
Expand Down Expand Up @@ -121,7 +124,9 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
routingData.mPtrNumNonExitingCtas = numNonExitingCtas;

// input:
routingData.mPtrScores = routingLogits;
// Pre-computed routing support: when expertIds is provided, use it directly
routingData.mPtrScores = expertIds == nullptr ? routingLogits : nullptr;
routingData.mPtrTopKIds = expertIds;
routingData.mNumTokens = numTokens;
routingData.mNumExperts = numExperts;
routingData.mTopK = topK;
Expand All @@ -147,7 +152,9 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
routingData.mNormTopkProb = routingMethodType == RoutingMethodType::RenormalizeNaive;
routingData.mApplySoftmaxAfterTopK = routingMethodType == RoutingMethodType::Renormalize;

routingData.mPtrScores = routingLogits;
// Pre-computed routing support: when expertIds is provided, use it directly
routingData.mPtrScores = expertIds == nullptr ? routingLogits : nullptr;
routingData.mPtrTopKIds = expertIds;

//
// Outputs
Expand Down
Loading