Skip to content
Open
14 changes: 12 additions & 2 deletions csrc/trtllm_fused_moe_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,13 @@ class FusedMoeLauncher {

check_routing_logits_shape();

// If routing_logits is not provided, we expect pre-computed routing
// This will be validated in the prepare_routing phase where expert_indexes/weights are checked
if (!routing_logits.has_value()) {
// Pre-computed routing mode - expert_indexes and expert_weights will be validated later
// They are allocated/set in prepare_routing() by derived classes
}

if (routing_bias.has_value()) {
check_routing_bias_shape();
}
Expand Down Expand Up @@ -340,10 +347,11 @@ class FusedMoeLauncher {
check_routing();
prepare_routing();

// Execute routing
tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim);
cudaStream_t routing_stream = get_stream(hidden_states.device());

// Execute routing (handles both pre-computed and from-logits paths)
tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim);

routing_runner.run(
args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k,
args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts,
Expand All @@ -353,6 +361,7 @@ class FusedMoeLauncher {
static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()),
nullptr /*permuted_idx_to_expanded_idx.data_ptr()*/,
static_cast<int*>(permuted_idx_to_token_idx.data_ptr()), expert_weights.data_ptr(),
nullptr /*expertIds - not used when computing from logits*/,
static_cast<int*>(num_tokens_per_expert.data_ptr()),
static_cast<int*>(cta_idx_xy_to_batch_idx.data_ptr()),
static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()),
Expand Down Expand Up @@ -1353,6 +1362,7 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()),
nullptr /*permuted_idx_to_expanded_idx.data_ptr()*/,
static_cast<int*>(permuted_idx_to_token_idx.data_ptr()), expert_weights.data_ptr(),
nullptr /*expertIds - not used when computing from logits*/,
static_cast<int*>(num_tokens_per_expert.data_ptr()),
static_cast<int*>(cta_idx_xy_to_batch_idx.data_ptr()),
static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()),
Expand Down
21 changes: 14 additions & 7 deletions csrc/trtllm_fused_moe_runner.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
int32_t localNumExperts, float routedScalingFactor, int32_t* routingExpertIndexes,
int32_t* expertCountHistogram, int32_t* permutedIdxSize,
int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx,
int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* numTokensPerExpert,
int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit,
int32_t* numNonExitingCtas, btg::Dtype dtypeElt, btg::Dtype dtypeBias,
bool useRoutingScalesOnInput, bool useDeepSeekFp8,
int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* expertIds,
int32_t* numTokensPerExpert, int32_t* ctaIdxXyToBatchIdx,
int32_t* ctaIdxXyToMnLimit, int32_t* numNonExitingCtas, btg::Dtype dtypeElt,
btg::Dtype dtypeBias, bool useRoutingScalesOnInput, bool useDeepSeekFp8,
RoutingMethodType routingMethodType, cudaStream_t stream) {
if (routingMethodType == RoutingMethodType::DeepSeekV3) {
FLASHINFER_CHECK(topK <= 8, "For DeepSeek routing method, must have topK <= 8");
Expand All @@ -84,7 +84,10 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3

// input:
routingData.mPtrRoutingBias = routingBias;
routingData.mPtrScores = reinterpret_cast<float*>(routingLogits);
// Pre-computed routing support: when expertIds is provided, use it directly
routingData.mPtrScores =
expertIds == nullptr ? reinterpret_cast<float*>(routingLogits) : nullptr;
routingData.mPtrTopKIds = expertIds;
routingData.mNumTokens = numTokens;
routingData.mNumExperts = numExperts;
routingData.mNumExpertGroups = nGroup;
Expand Down Expand Up @@ -121,7 +124,9 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
routingData.mPtrNumNonExitingCtas = numNonExitingCtas;

// input:
routingData.mPtrScores = routingLogits;
// Pre-computed routing support: when expertIds is provided, use it directly
routingData.mPtrScores = expertIds == nullptr ? routingLogits : nullptr;
routingData.mPtrTopKIds = expertIds;
routingData.mNumTokens = numTokens;
routingData.mNumExperts = numExperts;
routingData.mTopK = topK;
Expand All @@ -147,7 +152,9 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3
routingData.mNormTopkProb = routingMethodType == RoutingMethodType::RenormalizeNaive;
routingData.mApplySoftmaxAfterTopK = routingMethodType == RoutingMethodType::Renormalize;

routingData.mPtrScores = routingLogits;
// Pre-computed routing support: when expertIds is provided, use it directly
routingData.mPtrScores = expertIds == nullptr ? routingLogits : nullptr;
routingData.mPtrTopKIds = expertIds;

//
// Outputs
Expand Down
11 changes: 7 additions & 4 deletions flashinfer/fused_moe/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2450,6 +2450,7 @@ def trtllm_fp4_block_scale_moe(
@flashinfer_api
def trtllm_fp4_block_scale_routed_moe(
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
routing_bias: Optional[torch.Tensor],
hidden_states: torch.Tensor,
hidden_states_scale: Optional[torch.Tensor],
Expand Down Expand Up @@ -2484,9 +2485,11 @@ def trtllm_fp4_block_scale_routed_moe(

Args:
topk_ids (torch.Tensor): shape [seq_len, top_k]
Tensor of top-k indices and expert weights. Dtype must be int32.
It must represent a packed value. The most significant 16/32 bits represent the score and
the least significant 16 bits represent the index of the chosen expert (unsigned).
Tensor of top-k expert indices. Dtype must be int32.
Each element contains the index of the selected expert.
topk_weights (torch.Tensor): shape [seq_len, top_k]
Tensor of top-k routing weights. Dtype must be bfloat16.
Each element contains the routing score/weight for the corresponding expert.
routing_bias (Optional[torch.Tensor]): shape [num_experts]
Tensor of routing bias. Can be None for some routing methods. Must be the same type as routing logits.
hidden_states (torch.Tensor): shape [seq_len, hidden_size // 2 if nvfp4 else hidden_size]
Expand Down Expand Up @@ -2546,7 +2549,7 @@ def trtllm_fp4_block_scale_routed_moe(
return get_trtllm_moe_sm100_module().trtllm_fp4_block_scale_moe(
None,
topk_ids,
None,
topk_weights,
routing_bias,
hidden_states,
hidden_states_scale,
Expand Down
10 changes: 5 additions & 5 deletions include/flashinfer/trtllm/fused_moe/runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,11 @@ class Runner {
int32_t localNumExperts, float routedScalingFactor, int32_t* routingExpertIndexes,
int32_t* expertCountHistogram, int32_t* permutedIdxSize,
int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx,
int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* numTokensPerExpert,
int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit, int32_t* numNonExitingCtas,
batchedGemm::trtllm::gen::Dtype dtypeElt, batchedGemm::trtllm::gen::Dtype dtypeBias,
bool useRoutingScalesOnInput, bool useDeepSeekFp8, RoutingMethodType routingMethodType,
cudaStream_t stream);
int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* expertIds,
int32_t* numTokensPerExpert, int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit,
int32_t* numNonExitingCtas, batchedGemm::trtllm::gen::Dtype dtypeElt,
batchedGemm::trtllm::gen::Dtype dtypeBias, bool useRoutingScalesOnInput,
bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream);

private:
int32_t mTileTokensDim{8};
Expand Down
Loading