-
Notifications
You must be signed in to change notification settings - Fork 759
feat: Expose unpacked topk weights for routed moe (fp4) #2425
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
71e4dc3
675f005
0242a1f
973ad87
7d21bbf
0171f56
45a20c4
5f8a1cd
220b23f
c3ef799
37b6bd9
7ccb196
dc845e7
d664651
6daa827
473fab8
e7a98a9
dafc67d
a06a0fd
05205e1
3c62cf8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -41,6 +41,12 @@ using tensorrt_llm::kernels::trtllmgen_moe::Routing::RoutingMethodType; | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| using tvm::ffi::Array; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| using tvm::ffi::Optional; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| enum class RoutingInputMode { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| FromLogits, // Mode 1: Compute routing from logits | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PackedPrecomputed, // Mode 2: Pre-computed with packed (score << 16 | id) format | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| UnpackedPrecomputed // Mode 3: Pre-computed with separate topk_ids and topk_weights | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| enum class Fp8QuantizationType { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| NoneFp8, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| DeepSeekFp8, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -371,25 +377,33 @@ class FusedMoeLauncher { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| check_routing(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prepare_routing(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cudaStream_t routing_stream = get_stream(hidden_states.device()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Execute routing | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cudaStream_t routing_stream = get_stream(hidden_states.device()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| routing_runner.run( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args->routed_scaling_factor, static_cast<int*>(expert_indexes.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(expert_count_histogram.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(total_num_padded_tokens.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| nullptr /*permuted_idx_to_expanded_idx.data_ptr()*/, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(permuted_idx_to_token_idx.data_ptr()), expert_weights.data_ptr(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(num_tokens_per_expert.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(cta_idx_xy_to_batch_idx.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| use_routing_scales_on_input, use_deep_seek_fp8, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<RoutingMethodType>(routing_method_type), routing_stream); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // This base class only supports Mode 1 (FromLogits) - compute routing from logits | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| constexpr RoutingInputMode routing_input_mode = RoutingInputMode::FromLogits; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Mode 1: expertIds is nullptr, expertWeights is OUTPUT buffer for computed weights | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int32_t* expert_ids_param = nullptr; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| void* expert_weights_param = expert_weights.data_ptr(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| routing_runner.run(args->routing_logits, args->routing_bias, args->num_tokens, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args->num_experts, args->top_k, args->n_group, args->topk_group, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args->local_expert_offset, args->local_num_experts, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args->routed_scaling_factor, static_cast<int*>(expert_indexes.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(expert_count_histogram.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(total_num_padded_tokens.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| nullptr /*permuted_idx_to_expanded_idx.data_ptr()*/, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(permuted_idx_to_token_idx.data_ptr()), expert_ids_param, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expert_weights_param, static_cast<int*>(num_tokens_per_expert.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(cta_idx_xy_to_batch_idx.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<RoutingMethodType>(routing_method_type), routing_stream); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| check_moe(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prepare_moe(moe_tactic); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1052,6 +1066,7 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bool use_precomputed = expert_indices.ndim() == 2 && expert_indices.size(0) > 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // When using pre-computed routing, pass nullptr as routing_logits to tell the | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // routing runner to use the pre-computed expert indices from workspace.routing_expert_indexes | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // FP8 only supports Mode 1 (FromLogits) and Mode 2 (PackedPrecomputed), so expertIds is nullptr | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| routing_runner.run( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| use_precomputed ? nullptr : args->routing_logits, args->routing_bias, args->num_tokens, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args->num_experts, args->top_k, args->n_group, args->topk_group, args->local_expert_offset, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1060,8 +1075,9 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(total_num_padded_tokens.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| nullptr /*permuted_idx_to_expanded_idx.data_ptr()*/, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(permuted_idx_to_token_idx.data_ptr()), workspace.expert_weights, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(num_tokens_per_expert.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(permuted_idx_to_token_idx.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| nullptr, // expertIds - FP8 doesn't support UnpackedPrecomputed mode | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| workspace.expert_weights, static_cast<int*>(num_tokens_per_expert.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(cta_idx_xy_to_batch_idx.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1277,19 +1293,21 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| FP4BlockScaleLauncher( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Optional<TensorView> const& routing_logits, Optional<TensorView> const& routing_bias, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TensorView const& hidden_states, Optional<TensorView> const& hidden_states_scale, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TensorView const& gemm1_weights, TensorView const& gemm1_weights_scale, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Optional<TensorView> const& gemm1_bias, Optional<TensorView> const& gemm1_alpha, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Optional<TensorView> const& gemm1_beta, Optional<TensorView> const& gemm1_clamp_limit, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TensorView const& gemm2_weights, TensorView const& gemm2_weights_scale, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Optional<TensorView> const& gemm2_bias, Optional<TensorView> const& output1_scales_scalar, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| RoutingInputMode routing_input_mode, Optional<TensorView> const& routing_logits, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Optional<TensorView> const& routing_bias, TensorView const& hidden_states, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Optional<TensorView> const& hidden_states_scale, TensorView const& gemm1_weights, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TensorView const& gemm1_weights_scale, Optional<TensorView> const& gemm1_bias, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Optional<TensorView> const& gemm1_alpha, Optional<TensorView> const& gemm1_beta, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Optional<TensorView> const& gemm1_clamp_limit, TensorView const& gemm2_weights, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TensorView const& gemm2_weights_scale, Optional<TensorView> const& gemm2_bias, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Optional<TensorView> const& output1_scales_scalar, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Optional<TensorView> const& output1_scales_gate_scalar, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Optional<TensorView> const& output2_scales_scalar, TensorView const& expert_indices, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TensorView const& expert_weights) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Optional<TensorView> const& output2_scales_scalar, TensorView const& topk_ids, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TensorView const& topk_weights) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| : FusedMoeLauncher(routing_logits, routing_bias, hidden_states, gemm1_weights, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output1_scales_scalar, output1_scales_gate_scalar, gemm2_weights, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output2_scales_scalar), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| routing_input_mode_(routing_input_mode), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states_scale(hidden_states_scale), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| gemm1_weights_scale(gemm1_weights_scale), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| gemm1_bias(gemm1_bias), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1298,8 +1316,8 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| gemm1_clamp_limit(gemm1_clamp_limit), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| gemm2_weights_scale(gemm2_weights_scale), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| gemm2_bias(gemm2_bias), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expert_indices(expert_indices), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expert_weights(expert_weights) {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| topk_ids(topk_ids), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| topk_weights(topk_weights) {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| void init(std::unique_ptr<tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs>&& args, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1360,9 +1378,8 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| workspace.total_num_padded_tokens = static_cast<int*>(total_num_padded_tokens.data_ptr()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| workspace.total_max_padded_tokens = max_num_padded_tokens; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| workspace.ProjUpTileN = tile_tokens_dim; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| workspace.routing_expert_indexes = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(const_cast<void*>(expert_indices.data_ptr())); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| workspace.expert_weights = const_cast<void*>(expert_weights.data_ptr()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| workspace.routing_expert_indexes = static_cast<int*>(const_cast<void*>(topk_ids.data_ptr())); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| workspace.expert_weights = const_cast<void*>(topk_weights.data_ptr()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| workspace.permuted_idx_size = static_cast<int*>(total_num_padded_tokens.data_ptr()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| workspace.expanded_idx_to_permuted_idx = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1490,6 +1507,7 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| private: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| RoutingInputMode routing_input_mode_; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Optional<TensorView> hidden_states_scale; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TensorView gemm1_weights_scale; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Optional<TensorView> gemm1_bias; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1501,8 +1519,8 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int32_t max_num_padded_tokens_gemm1{}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int32_t max_num_padded_tokens_gemm2{}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Optional<Tensor> gemm1_output_scale; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TensorView expert_indices; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TensorView expert_weights; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TensorView topk_ids; // [num_tokens, top_k] - pre-computed or output top-k expert indices | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TensorView topk_weights; // [num_tokens, top_k] - pre-computed or output top-k routing weights | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| public: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Array<Tensor> run(int64_t moe_tactic, bool enable_pdl = true, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1515,21 +1533,45 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cudaStream_t routing_stream = get_stream(hidden_states.device()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| routing_runner.run( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args->routed_scaling_factor, static_cast<int*>(expert_indices.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(expert_count_histogram.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(total_num_padded_tokens.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| nullptr /*permuted_idx_to_expanded_idx.data_ptr()*/, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(permuted_idx_to_token_idx.data_ptr()), expert_weights.data_ptr(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(num_tokens_per_expert.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(cta_idx_xy_to_batch_idx.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| use_routing_scales_on_input, use_deep_seek_fp8, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<RoutingMethodType>(routing_method_type), routing_stream); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Set routing kernel parameters based on mode (see RoutingInputMode enum for documentation) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int32_t* expert_ids_param = nullptr; // INPUT: pre-computed expert IDs (Mode 3 only) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| void* expert_weights_param = nullptr; // INPUT or OUTPUT depending on mode | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| switch (routing_input_mode_) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| case RoutingInputMode::FromLogits: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Mode 1: Kernel computes routing, writes weights to expert_weights_param (OUTPUT) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expert_ids_param = nullptr; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expert_weights_param = topk_weights.data_ptr(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| break; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| case RoutingInputMode::PackedPrecomputed: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Mode 2: Kernel unpacks from topk_ids, writes weights to expert_weights_param (OUTPUT) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expert_ids_param = nullptr; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expert_weights_param = topk_weights.data_ptr(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| break; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| case RoutingInputMode::UnpackedPrecomputed: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Mode 3: Both are INPUTS, kernel uses them directly | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expert_ids_param = static_cast<int32_t*>(topk_ids.data_ptr()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expert_weights_param = topk_weights.data_ptr(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| break; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+1540
to
+1558
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a default case to handle invalid routing modes. The switch statement lacks a π‘οΈ Proposed fix case RoutingInputMode::UnpackedPrecomputed:
// Mode 3: Both are INPUTS, kernel uses them directly
expert_ids_param = static_cast<int32_t*>(topk_ids.data_ptr());
expert_weights_param = topk_weights.data_ptr();
break;
+
+ default:
+ TVM_FFI_LOG_AND_THROW(ValueError)
+ << "Invalid routing_input_mode: " << static_cast<int>(routing_input_mode_);
}π Committable suggestion
Suggested change
π€ Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| routing_runner.run(args->routing_logits, args->routing_bias, args->num_tokens, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args->num_experts, args->top_k, args->n_group, args->topk_group, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args->local_expert_offset, args->local_num_experts, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args->routed_scaling_factor, static_cast<int*>(topk_ids.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(expert_count_histogram.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(total_num_padded_tokens.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| nullptr /*permuted_idx_to_expanded_idx*/, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(permuted_idx_to_token_idx.data_ptr()), expert_ids_param, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expert_weights_param, static_cast<int*>(num_tokens_per_expert.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(cta_idx_xy_to_batch_idx.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(cta_idx_xy_to_mn_limit.data_ptr()), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<int*>(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mRoutingBiasDtype, use_routing_scales_on_input, use_deep_seek_fp8, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<RoutingMethodType>(routing_method_type), routing_stream); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| check_moe(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prepare_moe(moe_tactic); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1865,8 +1907,8 @@ Array<Tensor> trtllm_fp8_block_scale_moe( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Array<Tensor> trtllm_fp4_block_scale_moe( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Optional<TensorView> routing_logits, TensorView expert_indices, TensorView expert_weights, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Optional<TensorView> routing_bias, TensorView hidden_states, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int64_t routing_input_mode, Optional<TensorView> routing_logits, TensorView topk_ids, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TensorView topk_weights, Optional<TensorView> routing_bias, TensorView hidden_states, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Optional<TensorView> hidden_states_scale, TensorView gemm1_weights, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TensorView gemm1_weights_scale, Optional<TensorView> gemm1_bias, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Optional<TensorView> gemm1_alpha, Optional<TensorView> gemm1_beta, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1979,10 +2021,11 @@ Array<Tensor> trtllm_fp4_block_scale_moe( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Create and initialize launcher for this tile size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto launcher = std::make_unique<FP4BlockScaleLauncher>( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| routing_logits, routing_bias, hidden_states, hidden_states_scale, gemm1_weights, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| gemm2_weights_scale, gemm2_bias, output1_scales_scalar, output1_scales_gate_scalar, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output2_scales_scalar, expert_indices, expert_weights); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static_cast<RoutingInputMode>(routing_input_mode), routing_logits, routing_bias, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hidden_states, hidden_states_scale, gemm1_weights, gemm1_weights_scale, gemm1_bias, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights, gemm2_weights_scale, gemm2_bias, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar, topk_ids, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| topk_weights); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| launcher->init(std::move(args), curr_tile_N, routing_method_type, /*use_shuffled_weight=*/true, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /*weight_layout=*/0, static_cast<ActivationType>(act_type), mDtypeAct, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mDtypeWeights); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.