diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index f1a50f6dba..8241d70045 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -41,6 +41,12 @@ using tensorrt_llm::kernels::trtllmgen_moe::Routing::RoutingMethodType; using tvm::ffi::Array; using tvm::ffi::Optional; +enum class RoutingInputMode { + FromLogits, // Mode 1: Compute routing from logits + PackedPrecomputed, // Mode 2: Pre-computed with packed (id << 16 | score) format + UnpackedPrecomputed // Mode 3: Pre-computed with separate topk_ids and topk_weights +}; + enum class Fp8QuantizationType { NoneFp8, DeepSeekFp8, @@ -371,25 +377,33 @@ class FusedMoeLauncher { check_routing(); prepare_routing(); + cudaStream_t routing_stream = get_stream(hidden_states.device()); + // Execute routing tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim); - cudaStream_t routing_stream = get_stream(hidden_states.device()); - routing_runner.run( - args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k, - args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts, - args->routed_scaling_factor, static_cast(expert_indexes.data_ptr()), - static_cast(expert_count_histogram.data_ptr()), - static_cast(total_num_padded_tokens.data_ptr()), - static_cast(expanded_idx_to_permuted_idx.data_ptr()), - nullptr /*permuted_idx_to_expanded_idx.data_ptr()*/, - static_cast(permuted_idx_to_token_idx.data_ptr()), expert_weights.data_ptr(), - static_cast(num_tokens_per_expert.data_ptr()), - static_cast(cta_idx_xy_to_batch_idx.data_ptr()), - static_cast(cta_idx_xy_to_mn_limit.data_ptr()), - static_cast(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, - use_routing_scales_on_input, use_deep_seek_fp8, - static_cast(routing_method_type), routing_stream); + // This base class only supports Mode 1 (FromLogits) - compute routing from logits + 
constexpr RoutingInputMode routing_input_mode = RoutingInputMode::FromLogits; + + // Mode 1: expertIds is nullptr, expertWeights is OUTPUT buffer for computed weights + int32_t* expert_ids_param = nullptr; + void* expert_weights_param = expert_weights.data_ptr(); + + routing_runner.run(args->routing_logits, args->routing_bias, args->num_tokens, + args->num_experts, args->top_k, args->n_group, args->topk_group, + args->local_expert_offset, args->local_num_experts, + args->routed_scaling_factor, static_cast(expert_indexes.data_ptr()), + static_cast(expert_count_histogram.data_ptr()), + static_cast(total_num_padded_tokens.data_ptr()), + static_cast(expanded_idx_to_permuted_idx.data_ptr()), + nullptr /*permuted_idx_to_expanded_idx.data_ptr()*/, + static_cast(permuted_idx_to_token_idx.data_ptr()), expert_ids_param, + expert_weights_param, static_cast(num_tokens_per_expert.data_ptr()), + static_cast(cta_idx_xy_to_batch_idx.data_ptr()), + static_cast(cta_idx_xy_to_mn_limit.data_ptr()), + static_cast(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, + mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8, + static_cast(routing_method_type), routing_stream); check_moe(); prepare_moe(moe_tactic); @@ -1052,6 +1066,7 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { bool use_precomputed = expert_indices.ndim() == 2 && expert_indices.size(0) > 0; // When using pre-computed routing, pass nullptr as routing_logits to tell the // routing runner to use the pre-computed expert indices from workspace.routing_expert_indexes + // FP8 only supports Mode 1 (FromLogits) and Mode 2 (PackedPrecomputed), so expertIds is nullptr routing_runner.run( use_precomputed ? 
nullptr : args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k, args->n_group, args->topk_group, args->local_expert_offset, @@ -1060,8 +1075,9 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { static_cast(total_num_padded_tokens.data_ptr()), static_cast(expanded_idx_to_permuted_idx.data_ptr()), nullptr /*permuted_idx_to_expanded_idx.data_ptr()*/, - static_cast(permuted_idx_to_token_idx.data_ptr()), workspace.expert_weights, - static_cast(num_tokens_per_expert.data_ptr()), + static_cast(permuted_idx_to_token_idx.data_ptr()), + nullptr, // expertIds - FP8 doesn't support UnpackedPrecomputed mode + workspace.expert_weights, static_cast(num_tokens_per_expert.data_ptr()), static_cast(cta_idx_xy_to_batch_idx.data_ptr()), static_cast(cta_idx_xy_to_mn_limit.data_ptr()), static_cast(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, @@ -1277,19 +1293,21 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { } FP4BlockScaleLauncher( - Optional const& routing_logits, Optional const& routing_bias, - TensorView const& hidden_states, Optional const& hidden_states_scale, - TensorView const& gemm1_weights, TensorView const& gemm1_weights_scale, - Optional const& gemm1_bias, Optional const& gemm1_alpha, - Optional const& gemm1_beta, Optional const& gemm1_clamp_limit, - TensorView const& gemm2_weights, TensorView const& gemm2_weights_scale, - Optional const& gemm2_bias, Optional const& output1_scales_scalar, + RoutingInputMode routing_input_mode, Optional const& routing_logits, + Optional const& routing_bias, TensorView const& hidden_states, + Optional const& hidden_states_scale, TensorView const& gemm1_weights, + TensorView const& gemm1_weights_scale, Optional const& gemm1_bias, + Optional const& gemm1_alpha, Optional const& gemm1_beta, + Optional const& gemm1_clamp_limit, TensorView const& gemm2_weights, + TensorView const& gemm2_weights_scale, Optional const& gemm2_bias, + Optional const& output1_scales_scalar, 
Optional const& output1_scales_gate_scalar, - Optional const& output2_scales_scalar, TensorView const& expert_indices, - TensorView const& expert_weights) + Optional const& output2_scales_scalar, TensorView const& topk_ids, + TensorView const& topk_weights) : FusedMoeLauncher(routing_logits, routing_bias, hidden_states, gemm1_weights, output1_scales_scalar, output1_scales_gate_scalar, gemm2_weights, output2_scales_scalar), + routing_input_mode_(routing_input_mode), hidden_states_scale(hidden_states_scale), gemm1_weights_scale(gemm1_weights_scale), gemm1_bias(gemm1_bias), @@ -1298,8 +1316,8 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { gemm1_clamp_limit(gemm1_clamp_limit), gemm2_weights_scale(gemm2_weights_scale), gemm2_bias(gemm2_bias), - expert_indices(expert_indices), - expert_weights(expert_weights) {} + topk_ids(topk_ids), + topk_weights(topk_weights) {} void init(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, @@ -1360,9 +1378,8 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { workspace.total_num_padded_tokens = static_cast(total_num_padded_tokens.data_ptr()); workspace.total_max_padded_tokens = max_num_padded_tokens; workspace.ProjUpTileN = tile_tokens_dim; - workspace.routing_expert_indexes = - static_cast(const_cast(expert_indices.data_ptr())); - workspace.expert_weights = const_cast(expert_weights.data_ptr()); + workspace.routing_expert_indexes = static_cast(const_cast(topk_ids.data_ptr())); + workspace.expert_weights = const_cast(topk_weights.data_ptr()); workspace.permuted_idx_size = static_cast(total_num_padded_tokens.data_ptr()); workspace.expanded_idx_to_permuted_idx = static_cast(expanded_idx_to_permuted_idx.data_ptr()); @@ -1490,6 +1507,7 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { } private: + RoutingInputMode routing_input_mode_; Optional hidden_states_scale; TensorView gemm1_weights_scale; Optional gemm1_bias; @@ -1501,8 +1519,8 @@ class 
FP4BlockScaleLauncher : public FusedMoeLauncher { int32_t max_num_padded_tokens_gemm1{}; int32_t max_num_padded_tokens_gemm2{}; Optional gemm1_output_scale; - TensorView expert_indices; - TensorView expert_weights; + TensorView topk_ids; // [num_tokens, top_k] - pre-computed or output top-k expert indices + TensorView topk_weights; // [num_tokens, top_k] - pre-computed or output top-k routing weights public: Array run(int64_t moe_tactic, bool enable_pdl = true, @@ -1515,21 +1533,45 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim); cudaStream_t routing_stream = get_stream(hidden_states.device()); - routing_runner.run( - args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k, - args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts, - args->routed_scaling_factor, static_cast(expert_indices.data_ptr()), - static_cast(expert_count_histogram.data_ptr()), - static_cast(total_num_padded_tokens.data_ptr()), - static_cast(expanded_idx_to_permuted_idx.data_ptr()), - nullptr /*permuted_idx_to_expanded_idx.data_ptr()*/, - static_cast(permuted_idx_to_token_idx.data_ptr()), expert_weights.data_ptr(), - static_cast(num_tokens_per_expert.data_ptr()), - static_cast(cta_idx_xy_to_batch_idx.data_ptr()), - static_cast(cta_idx_xy_to_mn_limit.data_ptr()), - static_cast(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, - use_routing_scales_on_input, use_deep_seek_fp8, - static_cast(routing_method_type), routing_stream); + // Set routing kernel parameters based on mode (see RoutingInputMode enum for documentation) + int32_t* expert_ids_param = nullptr; // INPUT: pre-computed expert IDs (Mode 3 only) + void* expert_weights_param = nullptr; // INPUT or OUTPUT depending on mode + + switch (routing_input_mode_) { + case RoutingInputMode::FromLogits: + // Mode 1: Kernel computes routing, writes weights to 
expert_weights_param (OUTPUT) + expert_ids_param = nullptr; + expert_weights_param = topk_weights.data_ptr(); + break; + + case RoutingInputMode::PackedPrecomputed: + // Mode 2: Kernel unpacks from topk_ids, writes weights to expert_weights_param (OUTPUT) + expert_ids_param = nullptr; + expert_weights_param = topk_weights.data_ptr(); + break; + + case RoutingInputMode::UnpackedPrecomputed: + // Mode 3: Both are INPUTS, kernel uses them directly + expert_ids_param = static_cast(topk_ids.data_ptr()); + expert_weights_param = topk_weights.data_ptr(); + break; + } + + routing_runner.run(args->routing_logits, args->routing_bias, args->num_tokens, + args->num_experts, args->top_k, args->n_group, args->topk_group, + args->local_expert_offset, args->local_num_experts, + args->routed_scaling_factor, static_cast(topk_ids.data_ptr()), + static_cast(expert_count_histogram.data_ptr()), + static_cast(total_num_padded_tokens.data_ptr()), + static_cast(expanded_idx_to_permuted_idx.data_ptr()), + nullptr /*permuted_idx_to_expanded_idx*/, + static_cast(permuted_idx_to_token_idx.data_ptr()), expert_ids_param, + expert_weights_param, static_cast(num_tokens_per_expert.data_ptr()), + static_cast(cta_idx_xy_to_batch_idx.data_ptr()), + static_cast(cta_idx_xy_to_mn_limit.data_ptr()), + static_cast(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, + mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8, + static_cast(routing_method_type), routing_stream); check_moe(); prepare_moe(moe_tactic); @@ -1865,8 +1907,8 @@ Array trtllm_fp8_block_scale_moe( } Array trtllm_fp4_block_scale_moe( - Optional routing_logits, TensorView expert_indices, TensorView expert_weights, - Optional routing_bias, TensorView hidden_states, + int64_t routing_input_mode, Optional routing_logits, TensorView topk_ids, + TensorView topk_weights, Optional routing_bias, TensorView hidden_states, Optional hidden_states_scale, TensorView gemm1_weights, TensorView gemm1_weights_scale, Optional gemm1_bias, Optional 
gemm1_alpha, Optional gemm1_beta, @@ -1979,10 +2021,11 @@ Array trtllm_fp4_block_scale_moe( // Create and initialize launcher for this tile size auto launcher = std::make_unique( - routing_logits, routing_bias, hidden_states, hidden_states_scale, gemm1_weights, - gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights, - gemm2_weights_scale, gemm2_bias, output1_scales_scalar, output1_scales_gate_scalar, - output2_scales_scalar, expert_indices, expert_weights); + static_cast(routing_input_mode), routing_logits, routing_bias, + hidden_states, hidden_states_scale, gemm1_weights, gemm1_weights_scale, gemm1_bias, + gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights, gemm2_weights_scale, gemm2_bias, + output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar, topk_ids, + topk_weights); launcher->init(std::move(args), curr_tile_N, routing_method_type, /*use_shuffled_weight=*/true, /*weight_layout=*/0, static_cast(act_type), mDtypeAct, mDtypeWeights); diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu index a1bf8139cc..ec0ae3465d 100644 --- a/csrc/trtllm_fused_moe_runner.cu +++ b/csrc/trtllm_fused_moe_runner.cu @@ -54,10 +54,10 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 int32_t localNumExperts, float routedScalingFactor, int32_t* routingExpertIndexes, int32_t* expertCountHistogram, int32_t* permutedIdxSize, int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx, - int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* numTokensPerExpert, - int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit, - int32_t* numNonExitingCtas, btg::Dtype dtypeElt, btg::Dtype dtypeBias, - bool useRoutingScalesOnInput, bool useDeepSeekFp8, + int32_t* permutedIdxToTokenIdx, int32_t* expertIds, void* expertWeights, + int32_t* numTokensPerExpert, int32_t* ctaIdxXyToBatchIdx, + int32_t* ctaIdxXyToMnLimit, int32_t* numNonExitingCtas, btg::Dtype dtypeElt, 
+ btg::Dtype dtypeBias, bool useRoutingScalesOnInput, bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream) { if (routingMethodType == RoutingMethodType::DeepSeekV3) { FLASHINFER_CHECK(topK <= 22, "For DeepSeek routing method, must have topK <= 22"); @@ -84,7 +84,10 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 // input: routingData.mPtrRoutingBias = routingBias; - routingData.mPtrScores = reinterpret_cast(routingLogits); + // Pre-computed routing support: when expertIds is provided, use it directly + routingData.mPtrScores = + expertIds == nullptr ? reinterpret_cast(routingLogits) : nullptr; + routingData.mPtrTopKIds = expertIds; routingData.mNumTokens = numTokens; routingData.mNumExperts = numExperts; routingData.mNumExpertGroups = nGroup; @@ -121,7 +124,9 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mPtrNumNonExitingCtas = numNonExitingCtas; // input: - routingData.mPtrScores = routingLogits; + // Pre-computed routing support: when expertIds is provided, use it directly + routingData.mPtrScores = expertIds == nullptr ? routingLogits : nullptr; + routingData.mPtrTopKIds = expertIds; routingData.mNumTokens = numTokens; routingData.mNumExperts = numExperts; routingData.mTopK = topK; @@ -147,7 +152,9 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mNormTopkProb = routingMethodType == RoutingMethodType::RenormalizeNaive; routingData.mApplySoftmaxAfterTopK = routingMethodType == RoutingMethodType::Renormalize; - routingData.mPtrScores = routingLogits; + // Pre-computed routing support: when expertIds is provided, use it directly + routingData.mPtrScores = expertIds == nullptr ? 
routingLogits : nullptr; + routingData.mPtrTopKIds = expertIds; // // Outputs diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 4002b716d9..7cda5ce795 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -173,6 +173,26 @@ class WeightLayout(IntEnum): BlockMajorK = 2 +# Routing input modes for FusedMoE launcher +# Please keep this in sync with the counterpart defined in csrc/trtllm_fused_moe_kernel_launcher.cu +class RoutingInputMode(IntEnum): + # Mode 1: Compute routing from logits + # - Input: routing_logits tensor provided + # - topk_ids: OUTPUT buffer for computed expert indices + # - topk_weights: OUTPUT buffer for computed weights + FromLogits = 0 + # Mode 2: Pre-computed routing with packed format + # - Input: topk_ids contains packed (expert_id << 16 | score) + # - topk_ids: INPUT with packed values + # - topk_weights: OUTPUT buffer for extracted weights + PackedPrecomputed = 1 + # Mode 3: Pre-computed routing with separate tensors + # - Input: separate topk_ids (expert indices) and topk_weights (routing weights) + # - topk_ids: INPUT - pre-computed expert indices + # - topk_weights: INPUT - pre-computed routing weights + UnpackedPrecomputed = 2 + + # The type of gated activation function # Please keep this in sync with the counterpart defined in include/flashinfer/trtllm/fused_moe/runner.h class GatedActType(IntEnum): @@ -965,7 +985,7 @@ class MoERunner(TunableRunner): ), # topk_ids buffer. empty since routing_logits is used. [num_tokens, topk] lambda shapes, dtype, device: torch.empty( shapes, device=device, dtype=dtype - ), # expert_weights buffer. empty since routing_logits is used. [num_tokens, topk] + ), # topk_weights buffer. empty since routing_logits is used. 
[num_tokens, topk] lambda shapes, dtype, device: torch.randn(shapes, device=device).to( dtype ), # hidden_states, [num_tokens, hidden_size] @@ -1036,7 +1056,7 @@ def get_valid_tactics( output, routing_logits, topk_ids, - expert_weights, + topk_weights, hidden_states, *extra_inputs, ) = inputs @@ -1077,7 +1097,7 @@ def forward( output, routing_logits, topk_ids, - expert_weights, + topk_weights, hidden_states, *extra_inputs, ) = inputs @@ -1096,8 +1116,8 @@ def forward( assert topk_ids.shape[0] == num_tokens, ( "topk_ids's first dimension must be batch size." ) - assert expert_weights.shape[0] == num_tokens, ( - "expert_weights's first dimension must be batch size." + assert topk_weights.shape[0] == num_tokens, ( + "topk_weights's first dimension must be batch size." ) assert hidden_states.shape[0] == num_tokens, ( "hidden_states's first dimension must be batch size." @@ -1164,7 +1184,7 @@ def forward( moe_op.trtllm_fp8_block_scale_moe( routing_logits, topk_ids, - expert_weights, + topk_weights, kwargs["routing_bias"], hidden_states, current_hidden_states_scale, @@ -1246,10 +1266,12 @@ def forward( [-1, -1] if tactic == -1 else tactic, ) else: + # Tuning always uses Mode 1 (FromLogits) moe_op.trtllm_fp4_block_scale_moe( + RoutingInputMode.FromLogits, routing_logits, topk_ids, - expert_weights, + topk_weights, kwargs["routing_bias"], hidden_states, hidden_states_scale, # hidden_states_scale @@ -1349,7 +1371,7 @@ def trtllm_bf16_moe_op( topk_ids = torch.empty( num_tokens, top_k, dtype=torch.int32, device=hidden_states.device ) - expert_weights = torch.empty( + topk_weights = torch.empty( num_tokens, top_k, dtype=routing_logits.dtype, device=hidden_states.device ) @@ -1369,7 +1391,7 @@ def trtllm_bf16_moe_op( activation_type=ActivationType.Swiglu, # Default for BF16 ) - inputs = [output, routing_logits, topk_ids, expert_weights, hidden_states] + inputs = [output, routing_logits, topk_ids, topk_weights, hidden_states] _, tactic = tuner.choose_one( 
"flashinfer::trtllm_bf16_moe", @@ -1494,7 +1516,7 @@ def trtllm_fp8_per_tensor_scale_moe_op( topk_ids = torch.empty( num_tokens, top_k, dtype=torch.int32, device=hidden_states.device ) - expert_weights = torch.empty( + topk_weights = torch.empty( num_tokens, top_k, dtype=routing_logits.dtype, device=hidden_states.device ) @@ -1514,7 +1536,7 @@ def trtllm_fp8_per_tensor_scale_moe_op( activation_type=activation_type, ) - inputs = [output, routing_logits, topk_ids, expert_weights, hidden_states] + inputs = [output, routing_logits, topk_ids, topk_weights, hidden_states] _, tactic = tuner.choose_one( "flashinfer::trtllm_fp8_per_tensor_scale_moe", @@ -1660,15 +1682,15 @@ def trtllm_fp8_block_scale_moe_op( num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device ) if routing_logits is not None: - # When routing_logits is provided, we must pass topk_ids/expert_weights with no allocation + # When routing_logits is provided, pass zero-element placeholder tensors (size 0, nothing for the kernel to fill) topk_ids = torch.empty(0, dtype=torch.int32, device=hidden_states.device) expert_weights = torch.empty( 0, dtype=routing_dtype, device=hidden_states.device ) else: - # When routing_logits is provided, we either have topk_ids/expert_weights, - # packed into a single tensor as topk_id - # or have them individually as topk_ids and expert_weights respectively + # When routing_logits is None, we have pre-computed routing: + # - packed format: topk_ids contains (expert_id << 16 | score) + # - unpacked format: separate topk_ids and expert_weights topk_ids = topk_ids expert_weights = ( expert_weights @@ -1812,9 +1834,10 @@ def _fake_trtllm_fp8_block_scale_moe( mutates_args=(""), ) def trtllm_fp4_block_scale_moe_op( + routing_input_mode: int, routing_logits: Optional[torch.Tensor], topk_ids: Optional[torch.Tensor], - expert_weights: Optional[torch.Tensor], + topk_weights: Optional[torch.Tensor], routing_bias: Optional[torch.Tensor], hidden_states: torch.Tensor, 
hidden_states_scale: 
Optional[torch.Tensor], @@ -1859,14 +1882,24 @@ def trtllm_fp4_block_scale_moe_op( num_tokens = hidden_states.shape[0] # workspace buffers required by trtllm-gen - if topk_ids is None: - topk_ids = torch.empty( - num_tokens, top_k, dtype=torch.int32, device=hidden_states.device + # For Mode 3 (UnpackedPrecomputed), topk_ids and topk_weights are user-provided INPUTS + if routing_input_mode == RoutingInputMode.UnpackedPrecomputed: + assert topk_ids is not None, ( + "topk_ids must be provided for UnpackedPrecomputed mode" ) - if expert_weights is None: - expert_weights = torch.empty( - num_tokens, top_k, dtype=routing_dtype, device=hidden_states.device + assert topk_weights is not None, ( + "topk_weights must be provided for UnpackedPrecomputed mode" ) + else: + # For Mode 1 (FromLogits) and Mode 2 (PackedPrecomputed), allocate OUTPUT buffers + if topk_ids is None: + topk_ids = torch.empty( + num_tokens, top_k, dtype=torch.int32, device=hidden_states.device + ) + if topk_weights is None: + topk_weights = torch.empty( + num_tokens, top_k, dtype=routing_dtype, device=hidden_states.device + ) if enable_pdl is None: enable_pdl = device_support_pdl(hidden_states.device) if output is None: @@ -1916,7 +1949,7 @@ def trtllm_fp4_block_scale_moe_op( if routing_logits is None else routing_logits, topk_ids, - expert_weights, + topk_weights, hidden_states, ] if hidden_states_scale is not None: @@ -1953,9 +1986,10 @@ def trtllm_fp4_block_scale_moe_op( # Call the C++ function for block scale MoE intermediate_output = moe_op.trtllm_fp4_block_scale_moe( + routing_input_mode, routing_logits, topk_ids, - expert_weights, + topk_weights, routing_bias, hidden_states, hidden_states_scale, @@ -1991,15 +2025,16 @@ def trtllm_fp4_block_scale_moe_op( else: return [ torch.from_dlpack(intermediate_output[0]), - expert_weights, + topk_weights, torch.from_dlpack(intermediate_output[2]), ] @register_fake_op("flashinfer::trtllm_fp4_block_scale_moe") def _fake_trtllm_fp4_block_scale_moe( + 
routing_input_mode: int, routing_logits: torch.Tensor, topk_ids: Optional[torch.Tensor], - expert_weights: Optional[torch.Tensor], + topk_weights: Optional[torch.Tensor], routing_bias: Optional[torch.Tensor], hidden_states: torch.Tensor, hidden_states_scale: torch.Tensor, @@ -2074,7 +2109,7 @@ def trtllm_mxint4_block_scale_moe_op( topk_ids = torch.empty( num_tokens, top_k, dtype=torch.int32, device=hidden_states.device ) - expert_weights = torch.empty( + topk_weights = torch.empty( num_tokens, top_k, dtype=routing_dtype, device=hidden_states.device ) if enable_pdl is None: @@ -2108,7 +2143,7 @@ def trtllm_mxint4_block_scale_moe_op( output, routing_logits, topk_ids, - expert_weights, + topk_weights, hidden_states, ] @@ -2697,9 +2732,10 @@ def trtllm_fp4_block_scale_moe( Optional inplace output tensor. Returns: List[torch.Tensor]: List of output tensors. If do_finalize=True, returns the final MoE output. - Otherwise, returns intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing. + Otherwise, returns intermediate results (gemm2_output, topk_weights, expanded_idx_to_permuted_idx) that need further processing. """ return get_trtllm_moe_sm100_module().trtllm_fp4_block_scale_moe( + RoutingInputMode.FromLogits, routing_logits, None, None, @@ -2737,7 +2773,7 @@ def trtllm_fp4_block_scale_moe( @flashinfer_api def trtllm_fp4_block_scale_routed_moe( - topk_ids: torch.Tensor, + topk_ids: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], routing_bias: Optional[torch.Tensor], hidden_states: torch.Tensor, hidden_states_scale: Optional[torch.Tensor], @@ -2768,13 +2804,20 @@ def trtllm_fp4_block_scale_routed_moe( output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, ) -> List[torch.Tensor]: - """FP4 block scale MoE operation. + """FP4 block scale MoE operation with pre-computed routing. + + This function supports two pre-computed routing formats: + 1. 
Packed format: topk_ids is a single tensor with packed (expert_id << 16 | score) + 2. Unpacked format: topk_ids is a tuple of (topk_ids, topk_weights) tensors Args: - topk_ids (torch.Tensor): shape [seq_len, top_k] - Tensor of top-k indices and expert weights. Dtype must be int32. - It must represent a packed value. The most significant 16/32 bits represent the score and - the least significant 16 bits represent the index of the chosen expert (unsigned). + topk_ids (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]): + Either a single tensor or a tuple of two tensors: + - Single tensor (packed format): shape [seq_len, top_k], dtype int32. + Must be packed value with (expert_id << 16 | score). + - Tuple (unpacked format): (topk_ids, topk_weights) where + topk_ids has shape [seq_len, top_k], dtype int32 (plain expert indices) + topk_weights has shape [seq_len, top_k], dtype bfloat16 (routing weights) routing_bias (Optional[torch.Tensor]): shape [num_experts] Tensor of routing bias. Can be None for some routing methods. Must be the same type as routing logits. hidden_states (torch.Tensor): shape [seq_len, hidden_size // 2 if nvfp4 else hidden_size] @@ -2831,12 +2874,24 @@ def trtllm_fp4_block_scale_routed_moe( Returns: List[torch.Tensor]: List of output tensors. If do_finalize=True, returns the final MoE output. - Otherwise, returns intermediate results (gemm2_output, undefined, expanded_idx_to_permuted_idx) that need further processing. + Otherwise, returns intermediate results (gemm2_output, topk_weights, expanded_idx_to_permuted_idx) that need further processing. 
""" + # Determine routing mode based on input format + if isinstance(topk_ids, tuple): + # Unpacked format: (topk_ids, topk_weights) + topk_ids_tensor, topk_weights = topk_ids + routing_mode = RoutingInputMode.UnpackedPrecomputed + else: + # Packed format: single tensor with (score << 16 | expert_id) + topk_ids_tensor = topk_ids + topk_weights = None + routing_mode = RoutingInputMode.PackedPrecomputed + return get_trtllm_moe_sm100_module().trtllm_fp4_block_scale_moe( + routing_mode, None, - topk_ids, - None, + topk_ids_tensor, + topk_weights, routing_bias, hidden_states, hidden_states_scale, diff --git a/include/flashinfer/trtllm/fused_moe/runner.h b/include/flashinfer/trtllm/fused_moe/runner.h index 46617e5dbd..4d899fe2f6 100644 --- a/include/flashinfer/trtllm/fused_moe/runner.h +++ b/include/flashinfer/trtllm/fused_moe/runner.h @@ -124,11 +124,11 @@ class Runner { int32_t localNumExperts, float routedScalingFactor, int32_t* routingExpertIndexes, int32_t* expertCountHistogram, int32_t* permutedIdxSize, int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx, - int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* numTokensPerExpert, - int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit, int32_t* numNonExitingCtas, - batchedGemm::trtllm::gen::Dtype dtypeElt, batchedGemm::trtllm::gen::Dtype dtypeBias, - bool useRoutingScalesOnInput, bool useDeepSeekFp8, RoutingMethodType routingMethodType, - cudaStream_t stream); + int32_t* permutedIdxToTokenIdx, int32_t* expertIds, void* expertWeights, + int32_t* numTokensPerExpert, int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit, + int32_t* numNonExitingCtas, batchedGemm::trtllm::gen::Dtype dtypeElt, + batchedGemm::trtllm::gen::Dtype dtypeBias, bool useRoutingScalesOnInput, + bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream); private: int32_t mTileTokensDim{8}; diff --git a/tests/moe/test_trtllm_gen_routed_fused_moe.py b/tests/moe/test_trtllm_gen_routed_fused_moe.py 
index a5272ceb36..3abe177fc9 100644 --- a/tests/moe/test_trtllm_gen_routed_fused_moe.py +++ b/tests/moe/test_trtllm_gen_routed_fused_moe.py @@ -55,6 +55,7 @@ ], ) @pytest.mark.parametrize("quant_mode", ["NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16"]) +@pytest.mark.parametrize("routing_format", ["packed", "unpacked"]) def test_trtllm_gen_routed_fused_moe( num_tokens: int, hidden_size: int, @@ -63,6 +64,7 @@ def test_trtllm_gen_routed_fused_moe( num_experts: int, routing_method_type: RoutingMethodType, quant_mode: Literal["NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16"], + routing_format: Literal["packed", "unpacked"], ): compute_capability = get_compute_capability(torch.device(device="cuda")) if compute_capability[0] not in [10]: @@ -202,16 +204,22 @@ def test_trtllm_gen_routed_fused_moe( routing_logits, top_k, num_experts, 8 ) topk_ids = permute_info["topKIndices"].to(torch.int32) - expert_weights = expert_weights.view(num_tokens, num_experts)[ + topk_weights = expert_weights.view(num_tokens, num_experts)[ torch.arange(num_tokens).unsqueeze(1), topk_ids ].to(torch.bfloat16) - packed_tensor = (topk_ids.to(torch.int32) << 16) | expert_weights.to( - torch.bfloat16 - ).view(torch.int16) + # Prepare routing input based on format + if routing_format == "packed": + # Packed format: (expert_id << 16) | bf16 weight bits + routing_input = (topk_ids.to(torch.int32) << 16) | topk_weights.view( + torch.int16 + ) + else: + # Unpacked format: (topk_ids, topk_weights) tuple + routing_input = (topk_ids, topk_weights) output = trtllm_fp4_block_scale_routed_moe( - packed_tensor, + routing_input, None, # routing_bias hidden_states, hidden_states_scale, @@ -341,19 +349,17 @@ def test_trtllm_gen_fp8_routed_fused_moe( ).to(torch.float) # Compute routing using reference implementation - permute_info, expert_weights_ref = routing_reference_renormalize( + permute_info, topk_weights_ref = routing_reference_renormalize( routing_logits, top_k, num_experts, 8 ) topk_ids = 
permute_info["topKIndices"].to(torch.int32) - expert_weights = expert_weights_ref.view(num_tokens, num_experts)[ + topk_weights = topk_weights_ref.view(num_tokens, num_experts)[ torch.arange(num_tokens, device=device).unsqueeze(1), topk_ids ].to(torch.bfloat16) - # Pack topk_ids and expert_weights into single tensor + # Pack topk_ids and topk_weights into single tensor # Format: (expert_id << 16) | (weight_bf16.view(int16)) - packed_topk_ids = (topk_ids << 16) | expert_weights.view(torch.int16).to( - torch.int32 - ) + packed_topk_ids = (topk_ids << 16) | topk_weights.view(torch.int16).to(torch.int32) # Run with pre-computed routing (packed format) output = trtllm_fp8_block_scale_routed_moe(