diff --git a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
index 952b479a1d..2a991829dd 100644
--- a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
+++ b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
@@ -11,7 +11,7 @@
 from flashinfer.fused_moe import trtllm_fp4_block_scale_moe
 from flashinfer.autotuner import autotune
 from flashinfer.testing.utils import bench_gpu_time
-from flashinfer.utils import device_support_pdl, calculate_tile_tokens_dim
+from flashinfer.utils import device_support_pdl
 
 
 def bench_trtllm_gen_fused_moe_autotuner(
@@ -99,9 +99,6 @@ def bench_trtllm_gen_fused_moe_autotuner(
     bias13 = torch.randn(num_experts, intermediate_size * 2, device=device) * 10
     bias2 = torch.randn(num_experts, intermediate_size * 2, device=device) * 10
 
-    tile_tokens_dim = calculate_tile_tokens_dim(
-        num_tokens, num_experts, top_k, 64 if quant_mode == "MxFP4xBf16" else 128
-    )
     output1_scale_scalar = torch.tensor(
         [hidden_states_global_scale * w13_global_scale] * num_experts, device=device
     )
@@ -136,7 +133,7 @@ def bench_trtllm_gen_fused_moe_autotuner(
         0,  # local_expert_offset
         num_experts,
         None,  # routed_scaling_factor
-        tile_tokens_dim,
+        None,  # tile_tokens_dim
         RoutingMethodType.Renormalize.value,
         True,
         enable_pdl,
diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu
index 22c1e8e51e..538dc92725 100644
--- a/csrc/trtllm_fused_moe_kernel_launcher.cu
+++ b/csrc/trtllm_fused_moe_kernel_launcher.cu
@@ -18,6 +18,8 @@
 #include
 #include
 #include
+#include
+#include
 #include
 #include "flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h"
@@ -37,6 +39,41 @@
 using tensorrt_llm::kernels::trtllmgen_moe::Routing::RoutingMethodType;
 using tvm::ffi::Array;
 using tvm::ffi::Optional;
+// Utility function to compute the next power of two
+inline int32_t nextPowerOfTwo(float value) {
+  int32_t n = static_cast<int32_t>(std::ceil(value));
+  if (n <= 1) return 1;
+
+  // If n is already a power of 2, return it
+  if ((n & (n - 1)) == 0) return n;
+
+  // Find the next power of 2
+  n--;
+  n |= n >> 1;
+  n |= n >> 2;
+  n |= n >> 4;
+  n |= n >> 8;
+  n |= n >> 16;
+  n++;
+
+  return n;
+}
+
+std::set<int32_t> computeSelectedTileN(std::vector<int32_t> const& supported_tile_nums,
+                                       int64_t const num_tokens, int64_t const top_k,
+                                       int64_t const num_local_experts) {
+  float const avg_tokens_per_expert = static_cast<float>(num_tokens * top_k) / num_local_experts;
+  int32_t tile_tokens_dim = std::clamp(nextPowerOfTwo(avg_tokens_per_expert),
+                                       supported_tile_nums.front(), supported_tile_nums.back());
+
+  std::set<int32_t> selected_tile_nums = {
+      std::max(supported_tile_nums.front(), tile_tokens_dim / 2), tile_tokens_dim,
+      std::min(supported_tile_nums.back(), tile_tokens_dim * 2),
+      std::min(supported_tile_nums.back(), tile_tokens_dim * 4)};
+
+  return selected_tile_nums;
+}
+
 void trtllm_fp8_per_tensor_scale_moe_launcher(
     TensorView routing_logits, Optional<TensorView> routing_bias, TensorView hidden_states,
     TensorView gemm1_weights, TensorView output1_scales_scalar,
@@ -46,7 +83,9 @@ void trtllm_fp8_per_tensor_scale_moe_launcher(
     int64_t const intermediate_size, int64_t const local_expert_offset,
     int64_t const local_num_experts, Optional<double> const routed_scaling_factor,
     bool const use_routing_scales_on_input, int64_t const tile_tokens_dim,
-    int64_t const routing_method_type, bool enable_pdl) {
+    int64_t const routing_method_type,
+    tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner& moe_runner, int64_t moeConfigIndex,
+    bool enable_pdl) {
   static const
std::tuple device_props = [hidden_states] { int major, minor; cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, @@ -124,6 +163,7 @@ void trtllm_fp8_per_tensor_scale_moe_launcher( } else { TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype for MoE."; } + args.mDtypeOut = btg::Dtype::Bfloat16; // Output is always bfloat16 for fp8 per-tensor scale args.routing_logits = routing_logits.data_ptr(); auto const routing_bias_dtype = @@ -158,6 +198,13 @@ void trtllm_fp8_per_tensor_scale_moe_launcher( int32_t max_num_padded_tokens = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount( args.num_tokens, top_k, num_experts, tile_tokens_dim); + int32_t max_num_padded_tokens_gemm1 = + tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount( + max_num_padded_tokens, args.intermediate_size, btg::dtypeGetNumBits(args.mDtypeElt)); + int32_t max_num_padded_tokens_gemm2 = + tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount( + max_num_padded_tokens, args.hidden_size, btg::dtypeGetNumBits(args.mDtypeOut)); + Tensor total_num_padded_tokens = alloc_tensor({1}, dl_int32, routing_logits.device()); Tensor expanded_idx_to_permuted_idx = alloc_tensor({args.num_tokens * args.top_k}, dl_int32, routing_logits.device()); @@ -174,20 +221,17 @@ void trtllm_fp8_per_tensor_scale_moe_launcher( routing_logits.device()); // allocate workspace for activation/gemm/finalize kernels - // Tensor gemm1_output = alloc_tensor({max_num_padded_tokens, 2 * intermediate_size}, - // dl_float8_e4m3fn, hidden_states.device()); - // Tensor activation_output = alloc_tensor({max_num_padded_tokens, intermediate_size}, - // dl_float8_e4m3fn, hidden_states.device()); - Tensor gemm1_output = alloc_tensor({max_num_padded_tokens, 2 * intermediate_size}, dl_uint8, + Tensor gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, 2 * intermediate_size}, dl_uint8, + hidden_states.device()); + Tensor gemm1_output_scale = + alloc_tensor({2 * intermediate_size / 128, max_num_padded_tokens_gemm1}, dl_float32, + hidden_states.device()); + Tensor activation_output = alloc_tensor({max_num_padded_tokens_gemm1, intermediate_size}, + dl_uint8, hidden_states.device()); + Tensor activation_output_scale = alloc_tensor( + {intermediate_size / 128, max_num_padded_tokens_gemm1}, dl_float32, hidden_states.device()); + Tensor gemm2_output = alloc_tensor({max_num_padded_tokens_gemm2, args.hidden_size}, dl_bfloat16, hidden_states.device()); - Tensor gemm1_output_scale = alloc_tensor({2 * intermediate_size / 128, max_num_padded_tokens}, - dl_float32, hidden_states.device()); - Tensor activation_output = - alloc_tensor({max_num_padded_tokens, intermediate_size}, dl_uint8, hidden_states.device()); - Tensor activation_output_scale = alloc_tensor({intermediate_size / 128, max_num_padded_tokens}, - dl_float32, hidden_states.device()); - Tensor gemm2_output = - alloc_tensor({max_num_padded_tokens, args.hidden_size}, dl_bfloat16, hidden_states.device()); int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim( args.num_tokens, args.top_k, args.num_experts, tile_tokens_dim); @@ -257,7 +301,8 @@ void trtllm_fp8_per_tensor_scale_moe_launcher( // setup workspace workspace.total_num_padded_tokens = static_cast(total_num_padded_tokens.data_ptr()); - workspace.total_max_padded_tokens = max_num_padded_tokens; + workspace.total_max_padded_tokens = + std::max(max_num_padded_tokens_gemm1, max_num_padded_tokens_gemm2); workspace.ProjUpTileN = tile_tokens_dim; 
workspace.routing_expert_indexes = static_cast(expert_indexes.data_ptr()); workspace.permuted_idx_size = static_cast(total_num_padded_tokens.data_ptr()); @@ -283,13 +328,6 @@ void trtllm_fp8_per_tensor_scale_moe_launcher( args.output = output.data_ptr(); args.output_scale = nullptr; - tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner( - args.mDtypeElt, args.mUseDeepSeekFp8, tile_tokens_dim, /*useShuffledMatrixA*/ true); - - auto const moeConfigIndex = - moe_runner.getDefaultValidConfigIndex(args.top_k, args.hidden_size, args.intermediate_size, - args.local_num_experts, args.num_tokens); - auto workspace_sizes = moe_runner.getWorkspaceSizeInBytes(args, moeConfigIndex); Tensor workspace_fc1 = alloc_tensor({std::get<0>(workspace_sizes)}, dl_int8, hidden_states.device()); @@ -309,16 +347,56 @@ void trtllm_fp8_per_tensor_scale_moe( TensorView output2_scales_scalar, TensorView output, int64_t num_experts, int64_t top_k, Optional n_group, Optional topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, - bool use_routing_scales_on_input, int64_t tile_tokens_dim, int64_t routing_method_type, - bool enable_pdl) { + bool use_routing_scales_on_input, int64_t routing_method_type, bool enable_pdl, + Array config_index) { auto dtype = hidden_states.dtype(); if (dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) { + using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner; + + // Convert PyTorch dtype to TensorRT-LLM dtype + btg::Dtype mDtypeElt; + if (dtype == dl_float16) { + mDtypeElt = btg::Dtype::Fp16; + } else if (dtype == dl_bfloat16) { + mDtypeElt = btg::Dtype::Bfloat16; + } else if (dtype == dl_float8_e4m3fn) { + mDtypeElt = btg::Dtype::E4m3; + } else { + TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype for MoE."; + } + + auto const num_tokens = hidden_states.size(0); + auto const hidden_size = hidden_states.size(1); + bool mUseDeepSeekFp8{false}; // FP8 per-tensor doesn't use DeepSeek FP8 + + std::vector mSupportedTileN = {8, 16, 32, 64, 128}; + std::set selected_tile_nums = + computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); + + // Build runners for all supported tile sizes + std::unordered_map> mRunners; + for (int32_t tile_N : selected_tile_nums) { + // Always use the two-parameter constructor for consistency + mRunners.emplace(tile_N, std::make_unique(mDtypeElt, mUseDeepSeekFp8, tile_N, + /*useShuffledMatrixA*/ true)); + } + + // moeConfigIndex corresponds to pair (tile_N, config) + int64_t tile_N = config_index[0]; + int64_t config = config_index[1]; + // Autotuner has requested a default or 'fallback' config index + if (tile_N == -1 || config == -1) { + tile_N = *selected_tile_nums.begin(); + config = mRunners[tile_N]->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size, + local_num_experts, num_tokens); + } + trtllm_fp8_per_tensor_scale_moe_launcher( routing_logits, routing_bias, hidden_states, gemm1_weights, output1_scales_scalar, output1_scales_gate_scalar, gemm2_weights, output2_scales_scalar, output, num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, - routed_scaling_factor, use_routing_scales_on_input, tile_tokens_dim, routing_method_type, - enable_pdl); + routed_scaling_factor, use_routing_scales_on_input, tile_N, routing_method_type, + *mRunners[tile_N], config, enable_pdl); } else { TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype."; } 
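(Aside, not part of the patch: a minimal Python sketch of the tile-size selection that `computeSelectedTileN` above performs; the function and variable names here are illustrative only. The average tokens per expert is rounded up to a power of two, clamped to the supported tile range, and the neighbouring tile sizes are offered to the autotuner as candidates. Each candidate tile_N is paired with that runner's valid kernel configs, and a tactic of (-1, -1) falls back to the smallest selected tile plus `getDefaultValidConfigIndex`.)

```python
import math

def compute_selected_tile_n(supported_tile_nums, num_tokens, top_k, num_local_experts):
    """Illustrative Python mirror of the C++ computeSelectedTileN above."""
    lo, hi = supported_tile_nums[0], supported_tile_nums[-1]
    avg_tokens_per_expert = num_tokens * top_k / num_local_experts
    n = math.ceil(avg_tokens_per_expert)
    tile = 1 if n <= 1 else 1 << (n - 1).bit_length()  # next power of two
    tile = min(max(tile, lo), hi)                       # clamp to the supported range
    # Offer the neighbouring tile sizes as autotuner candidates.
    return sorted({max(lo, tile // 2), tile, min(hi, tile * 2), min(hi, tile * 4)})

# e.g. 1024 tokens, top_k=8, 256 local experts -> 32 tokens/expert on average,
# so with supported tiles [8, 16, 32, 64] the candidates are [16, 32, 64].
print(compute_selected_tile_n([8, 16, 32, 64], 1024, 8, 256))
```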
@@ -468,10 +546,6 @@ void trtllm_fp8_block_scale_moe_launcher( routing_logits.device()); // allocate workspace for activation/gemm/finalize kernels - // Tensor gemm1_output = alloc_tensor({max_num_padded_tokens, 2 * intermediate_size}, - // dl_float8_e4m3fn, hidden_states.device()); - // Tensor activation_output = alloc_tensor({max_num_padded_tokens, intermediate_size}, - // dl_float8_e4m3fn, hidden_states.device()); Tensor gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, 2 * intermediate_size}, dl_uint8, hidden_states.device()); Tensor gemm1_output_scale = alloc_tensor({2 * intermediate_size / 128, max_num_padded_tokens}, @@ -623,16 +697,14 @@ void trtllm_fp8_block_scale_moe_launcher( enable_pdl); } -void trtllm_fp8_block_scale_moe(TensorView routing_logits, Optional routing_bias, - TensorView hidden_states, TensorView hidden_states_scale, - TensorView gemm1_weights, TensorView gemm1_weights_scale, - TensorView gemm2_weights, TensorView gemm2_weights_scale, - TensorView output, int64_t num_experts, int64_t top_k, - Optional n_group, Optional topk_group, - int64_t intermediate_size, int64_t local_expert_offset, - int64_t local_num_experts, Optional routed_scaling_factor, - int64_t tile_tokens_dim, int64_t routing_method_type, - bool use_shuffled_weight, int64_t weight_layout, bool enable_pdl) { +void trtllm_fp8_block_scale_moe( + TensorView routing_logits, Optional routing_bias, TensorView hidden_states, + TensorView hidden_states_scale, TensorView gemm1_weights, TensorView gemm1_weights_scale, + TensorView gemm2_weights, TensorView gemm2_weights_scale, TensorView output, + int64_t num_experts, int64_t top_k, Optional n_group, Optional topk_group, + int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, + Optional routed_scaling_factor, int64_t routing_method_type, bool use_shuffled_weight, + int64_t weight_layout, bool enable_pdl, Array config_index) { auto dtype = hidden_states.dtype(); if (dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) { using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner; @@ -643,24 +715,36 @@ void trtllm_fp8_block_scale_moe(TensorView routing_logits, Optional TVM_FFI_ICHECK(0 <= weight_layout && weight_layout <= 2) << "the value of weight_layout is not recognized"; - // Properly initialize the runner using make_unique like in the original code - auto mRunner = std::make_unique( - mDtypeElt, mUseDeepSeekFp8, tile_tokens_dim, use_shuffled_weight, - static_cast(weight_layout)); - - // Always use fallback config (equivalent to moeConfigIndex == -1 case from original code) auto const num_tokens = hidden_states.size(0); auto const hidden_size = hidden_states.size(1); - int64_t moeConfigIndex = mRunner->getDefaultValidConfigIndex( - top_k, hidden_size, intermediate_size, local_num_experts, num_tokens); + std::vector mSupportedTileN = {8, 16, 32, 64}; + std::set selected_tile_nums = + computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); + + // Build runners for all supported tile sizes + std::unordered_map> mRunners; + for (int32_t tile_N : selected_tile_nums) { + mRunners.emplace(tile_N, std::make_unique( + mDtypeElt, mUseDeepSeekFp8, tile_N, use_shuffled_weight, + static_cast(weight_layout))); + } + + // moeConfigIndex corresponds to pair (tile_N, config) + int64_t tile_N = config_index[0]; + int64_t config = config_index[1]; + // Autotuner has requested a default or 'fallback' config index + if (tile_N == -1 || config == -1) { + tile_N = *selected_tile_nums.begin(); + 
config = mRunners[tile_N]->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size, + local_num_experts, num_tokens); + } trtllm_fp8_block_scale_moe_launcher( routing_logits, routing_bias, hidden_states, hidden_states_scale, gemm1_weights, gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, output, num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, - routed_scaling_factor, tile_tokens_dim, routing_method_type, *mRunner, moeConfigIndex, - enable_pdl); + routed_scaling_factor, tile_N, routing_method_type, *mRunners[tile_N], config, enable_pdl); } else { TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported hidden state dtype."; } @@ -845,10 +929,6 @@ Array trtllm_fp4_block_scale_moe_launcher( Tensor permuted_idx_to_token_idx = alloc_tensor({max_num_padded_tokens}, dl_int32, hidden_states.device()); - // Tensor expert_weights = alloc_tensor( - // {args.num_tokens, args.top_k}, dl_bfloat16, hidden_states.device()); - // Tensor expert_indexes = alloc_tensor( - // {args.num_tokens, args.top_k}, dl_int32, hidden_states.device(); int64_t const size_of_expert_count_histogram = std::max(num_experts * 2, int64_t(256 * 2)); Tensor expert_count_histogram = alloc_tensor({size_of_expert_count_histogram}, dl_int32, hidden_states.device()); @@ -858,10 +938,6 @@ Array trtllm_fp4_block_scale_moe_launcher( // allocate workspace for activation/gemm/finalize kernels auto const gemm1_output_hidden = dtype_act == btg::Dtype::E2m1 ? intermediate_size / 2 : intermediate_size; - // Tensor gemm1_output = alloc_tensor( - // {max_num_padded_tokens, gemm1_output_hidden}, - // dtype_act == btg::Dtype::Bfloat16 ? dl_bfloat16 : dl_float8_e4m3fn, - // hidden_states.device()); Tensor gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, gemm1_output_hidden}, dtype_act == btg::Dtype::Bfloat16 ? 
dl_bfloat16 : dl_uint8, hidden_states.device()); @@ -1101,8 +1177,8 @@ Array trtllm_fp4_block_scale_moe( Optional output2_scales_scalar, int64_t num_experts, int64_t top_k, Optional n_group, Optional topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, - int64_t tile_tokens_dim, int64_t routing_method_type, bool do_finalize, bool enable_pdl, - int64_t gated_act_type, TensorView output, int64_t config_index) { + int64_t routing_method_type, bool do_finalize, bool enable_pdl, int64_t gated_act_type, + TensorView output, Array config_index) { using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner; int const num_tokens = hidden_states.size(0); @@ -1148,55 +1224,115 @@ Array trtllm_fp4_block_scale_moe( } bool mUseDeepSeekFp8{false}; // FP4 doesn't use DeepSeek FP8 - // Properly initialize the runner using make_unique like in the original code - auto mRunner = std::make_unique( - mDtypeAct, mDtypeWeights, mUseDeepSeekFp8, (int32_t)tile_tokens_dim, - static_cast(gated_act_type), /*useShuffledMatrixA*/ true); - - if (config_index == -1) { - config_index = mRunner->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size, - local_num_experts, num_tokens); + std::vector mSupportedTileN = {8, 16, 32, 64}; + if (mDtypeAct != btg::Dtype::Bfloat16) { + mSupportedTileN.push_back(128); + } + std::set selected_tile_nums = + computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); + // Build runners for all supported tile sizes + std::unordered_map> mRunners; + for (int32_t tile_N : selected_tile_nums) { + mRunners.emplace(tile_N, + std::make_unique(mDtypeAct, mDtypeWeights, mUseDeepSeekFp8, tile_N, + static_cast(gated_act_type), + /*useShuffledMatrixA*/ true)); } + // moeConfigIndex corresponds to pair (tile_N, config) + int64_t tile_N = config_index[0]; + int64_t config = config_index[1]; + // Autotuner has requested a default or 'fallback' config index + if (tile_N == -1 || config == -1) { + tile_N = *selected_tile_nums.begin(); + config = mRunners[tile_N]->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size, + local_num_experts, num_tokens); + } return trtllm_fp4_block_scale_moe_launcher( routing_logits, topk_ids, expert_weights, routing_bias, hidden_states, hidden_states_scale, gemm1_weights, gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights, gemm2_weights_scale, gemm2_bias, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group, - intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, - tile_tokens_dim, routing_method_type, do_finalize, *mRunner, mDtypeAct, mDtypeWeights, - config_index, enable_pdl, output); + intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tile_N, + routing_method_type, do_finalize, *mRunners[tile_N], mDtypeAct, mDtypeWeights, config, + enable_pdl, output); } -int64_t trtllm_get_default_moe_configs(int64_t const tile_tokens_dim, int64_t const dtype_act_, - int64_t const dtype_weights_, bool const useDeepSeekFp8, - int64_t const top_k, int64_t const hidden_size, - int64_t const intermediate_size, +int64_t trtllm_get_default_moe_configs(int64_t const dtype_act_, int64_t const dtype_weights_, + bool const useDeepSeekFp8, int64_t const top_k, + int64_t const hidden_size, int64_t const intermediate_size, int64_t const num_local_experts, int64_t const gated_act_type, int64_t const num_tokens) { auto 
dtype_act = static_cast(dtype_act_); auto dtype_weights = static_cast(dtype_weights_); - tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner( - dtype_act, dtype_weights, useDeepSeekFp8, (int32_t)tile_tokens_dim, - static_cast(gated_act_type), /*useShuffledMatrixA*/ true); - return moe_runner.getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size, - num_local_experts, num_tokens); + std::vector supported_tile_nums = {8, 16, 32, 64}; + // Check if we should add tile size 128 + bool is_fp4_without_bf16_act = + (dtype_weights == btg::Dtype::E2m1 || dtype_weights == btg::Dtype::MxE2m1) && + dtype_act != btg::Dtype::Bfloat16; + bool is_fp8_per_tensor = + dtype_weights == btg::Dtype::E4m3 && dtype_act == btg::Dtype::E4m3 && !useDeepSeekFp8; + + if (is_fp4_without_bf16_act || is_fp8_per_tensor) { + supported_tile_nums.push_back(128); + } + std::set selected_tile_nums = + computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts); + + std::unique_ptr moe_runner = + std::make_unique( + dtype_act, dtype_weights, useDeepSeekFp8, *selected_tile_nums.begin(), + static_cast(gated_act_type), /*useShuffledMatrixA*/ true); + + return moe_runner->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size, + num_local_experts, num_tokens); } -Array trtllm_get_valid_moe_configs(int64_t const tile_tokens_dim, int64_t const dtype_act_, - int64_t const dtype_weights_, bool const useDeepSeekFp8, - int64_t const top_k, int64_t const hidden_size, - int64_t const intermediate_size, - int64_t const num_local_experts, - int64_t const gated_act_type, - int64_t const num_tokens) { +Array> trtllm_get_valid_moe_configs( + int64_t const dtype_act_, int64_t const dtype_weights_, bool const useDeepSeekFp8, + int64_t const top_k, int64_t const hidden_size, int64_t const intermediate_size, + int64_t const num_local_experts, int64_t const gated_act_type, bool const use_shuffled_weight, + int64_t const weight_layout, int64_t const num_tokens) { + // returns (tile_N, config) + Array> valid_configs; auto dtype_act = static_cast(dtype_act_); auto dtype_weights = static_cast(dtype_weights_); - tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner( - dtype_act, dtype_weights, useDeepSeekFp8, (int32_t)tile_tokens_dim, - static_cast(gated_act_type), /*useShuffledMatrixA*/ true); - return moe_runner.getValidConfigIndices(top_k, hidden_size, intermediate_size, num_local_experts, - num_tokens); + std::vector supported_tile_nums = {8, 16, 32, 64}; + // Check if we should add tile size 128 + bool is_fp4_without_bf16_act = + (dtype_weights == btg::Dtype::E2m1 || dtype_weights == btg::Dtype::MxE2m1) && + dtype_act != btg::Dtype::Bfloat16; + bool is_fp8_per_tensor = + dtype_weights == btg::Dtype::E4m3 && dtype_act == btg::Dtype::E4m3 && !useDeepSeekFp8; + + if (is_fp4_without_bf16_act || is_fp8_per_tensor) { + supported_tile_nums.push_back(128); + } + std::set selected_tile_nums = + computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts); + + for (int32_t tile_N : selected_tile_nums) { + std::unique_ptr moe_runner; + + if (dtype_weights == btg::Dtype::E4m3 && dtype_act == btg::Dtype::E4m3) { + // FP8 block scale MOE runner + moe_runner = std::make_unique( + dtype_weights, useDeepSeekFp8, tile_N, use_shuffled_weight, + static_cast(weight_layout)); + } else { + // FP4 block scale MOE runner + moe_runner = std::make_unique( + dtype_act, dtype_weights, useDeepSeekFp8, tile_N, + static_cast(gated_act_type), + /*useShuffledMatrixA*/ true); + } + auto cfgs = 
moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size, + num_local_experts, num_tokens); + for (auto cfg : cfgs) { + valid_configs.push_back({tile_N, cfg}); + } + } + return valid_configs; } namespace trtllm_cubin_loader { diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 5f0e33ccf9..c91878ca0e 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -18,7 +18,6 @@ from enum import IntEnum from types import SimpleNamespace from typing import Any, Dict, List, Optional, Tuple, Union - import torch from ..autotuner import ( @@ -45,7 +44,6 @@ device_support_pdl, get_shuffle_matrix_a_row_indices, get_shuffle_matrix_sf_a_row_indices, - calculate_tile_tokens_dim, register_custom_op, register_fake_op, ) @@ -915,8 +913,9 @@ def __init__( use_deepseek_fp8: bool, hidden_size: int, intermediate_size: int, - gated_act_type: int, - tile_tokens_dim: Optional[int] = None, + gated_act_type: int = GatedActType.SwiGlu, + use_shuffled_weight: bool = False, + weight_layout: int = WeightLayout.MajorK, ): self.num_local_experts = num_local_experts self.top_k = top_k @@ -926,8 +925,18 @@ def __init__( self.top_k = top_k self.hidden_size = hidden_size self.intermediate_size = intermediate_size - self.gated_act_type = gated_act_type - self.tile_tokens_dim = tile_tokens_dim + self.gated_act_type = GatedActType(gated_act_type) + self.use_shuffled_weight = use_shuffled_weight + self.weight_layout = WeightLayout(weight_layout) + if ( + not self.use_shuffled_weight + or self.weight_layout != WeightLayout.MajorK + ): + assert ( + self.use_deepseek_fp8 and self.dtype_weights == DtypeTrtllmGen.E4m3 + ), ( + "use_shuffled_weight is False or weight_layout is not MajorK is only supported for FP8 block scale" + ) def get_valid_tactics( self, @@ -943,18 +952,8 @@ def get_valid_tactics( *extra_inputs, ) = inputs num_tokens = routing_logits.shape[0] - tile_tokens_dim = ( - calculate_tile_tokens_dim( - num_tokens, - self.num_local_experts, - self.top_k, - 64 if self.dtype_act == DtypeTrtllmGen.Bfloat16 else 128, - ) - if self.tile_tokens_dim is None - else self.tile_tokens_dim - ) + instance_key = ( - tile_tokens_dim, self.dtype_act, self.dtype_weights, self.use_deepseek_fp8, @@ -963,6 +962,8 @@ def get_valid_tactics( self.intermediate_size, self.num_local_experts, self.gated_act_type, + self.use_shuffled_weight, + self.weight_layout, num_tokens, ) if instance_key not in MoERunner.valid_tactics_dict: @@ -992,16 +993,6 @@ def forward( *extra_inputs, ) = inputs num_tokens = routing_logits.shape[0] - tile_tokens_dim = ( - calculate_tile_tokens_dim( - num_tokens, - self.num_local_experts, - self.top_k, - 64 if self.dtype_act == DtypeTrtllmGen.Bfloat16 else 128, - ) - if self.tile_tokens_dim is None - else self.tile_tokens_dim - ) extra_input_idx = 0 if trtllm_gen_dtype_has_scale(self.dtype_act): @@ -1026,42 +1017,106 @@ def forward( hidden_states_scale.dim() == 2 and hidden_states_scale.shape[0] == num_tokens ), "hidden_states_scale's first dimension must be batch size" - # TODO(siyuan): support fp8 - moe_op.trtllm_fp4_block_scale_moe( - routing_logits, - topk_ids, - expert_weights, - kwargs["routing_bias"], - hidden_states, - hidden_states_scale, # hidden_states_scale - kwargs["gemm1_weights"], - kwargs["gemm1_weights_scale"], - kwargs["gemm1_bias"], - kwargs["gemm1_alpha"], - kwargs["gemm1_beta"], - kwargs["gemm1_clamp_limit"], - kwargs["gemm2_weights"], - kwargs["gemm2_weights_scale"], - kwargs["gemm2_bias"], - kwargs["output1_scale_scalar"], - 
kwargs["output1_scale_gate_scalar"], - kwargs["output2_scale_scalar"], - kwargs["num_experts"], - self.top_k, - kwargs["n_group"], - kwargs["topk_group"], - self.intermediate_size, - kwargs["local_expert_offset"], - self.num_local_experts, - kwargs["routed_scaling_factor"], - tile_tokens_dim, - kwargs["routing_method_type"], - kwargs["enable_pdl"], - kwargs["do_finalize"], - self.gated_act_type, - output, - tactic, - ) + # Choose the appropriate operation based on data types + if ( + self.dtype_act == DtypeTrtllmGen.E4m3 + and self.dtype_weights == DtypeTrtllmGen.E4m3 + ): + # FP8 operations + if self.use_deepseek_fp8: + # FP8 block scale + current_num_tokens = hidden_states.shape[0] + current_hidden_size = hidden_states.shape[1] + current_hidden_states_scale = torch.full( + (current_hidden_size // 128, current_num_tokens), + 2.0, + dtype=torch.float, + device=hidden_states.device, + ) + moe_op.trtllm_fp8_block_scale_moe( + routing_logits, + kwargs["routing_bias"], + hidden_states, + current_hidden_states_scale, + kwargs["gemm1_weights"], + kwargs["gemm1_weights_scale"], + kwargs["gemm2_weights"], + kwargs["gemm2_weights_scale"], + output, + kwargs["num_experts"], + self.top_k, + kwargs["n_group"], + kwargs["topk_group"], + self.intermediate_size, + kwargs["local_expert_offset"], + self.num_local_experts, + kwargs["routed_scaling_factor"], + kwargs["routing_method_type"], + kwargs["use_shuffled_weight"], + kwargs["weight_layout"], + kwargs["enable_pdl"], + [-1, -1] if tactic == -1 else tactic, + ) + else: + # FP8 per tensor scale + moe_op.trtllm_fp8_per_tensor_scale_moe( + routing_logits, + kwargs["routing_bias"], + hidden_states, + kwargs["gemm1_weights"], + kwargs["output1_scales_scalar"], + kwargs["output1_scales_gate_scalar"], + kwargs["gemm2_weights"], + kwargs["output2_scales_scalar"], + output, + kwargs["num_experts"], + self.top_k, + kwargs["n_group"], + kwargs["topk_group"], + self.intermediate_size, + kwargs["local_expert_offset"], + self.num_local_experts, + kwargs["routed_scaling_factor"], + kwargs["use_routing_scales_on_input"], + kwargs["routing_method_type"], + kwargs["enable_pdl"], + [-1, -1] if tactic == -1 else tactic, + ) + else: + moe_op.trtllm_fp4_block_scale_moe( + routing_logits, + topk_ids, + expert_weights, + kwargs["routing_bias"], + hidden_states, + hidden_states_scale, # hidden_states_scale + kwargs["gemm1_weights"], + kwargs["gemm1_weights_scale"], + kwargs["gemm1_bias"], + kwargs["gemm1_alpha"], + kwargs["gemm1_beta"], + kwargs["gemm1_clamp_limit"], + kwargs["gemm2_weights"], + kwargs["gemm2_weights_scale"], + kwargs["gemm2_bias"], + kwargs["output1_scale_scalar"], + kwargs["output1_scale_gate_scalar"], + kwargs["output2_scale_scalar"], + kwargs["num_experts"], + self.top_k, + kwargs["n_group"], + kwargs["topk_group"], + self.intermediate_size, + kwargs["local_expert_offset"], + self.num_local_experts, + kwargs["routed_scaling_factor"], + kwargs["routing_method_type"], + kwargs["enable_pdl"], + kwargs["do_finalize"], + self.gated_act_type, + output, + [-1, -1] if tactic == -1 else tactic, + ) @classmethod @functools.lru_cache(maxsize=None) @@ -1111,14 +1166,67 @@ def trtllm_fp8_per_tensor_scale_moe_op( local_num_experts: int, routed_scaling_factor: Optional[float], use_routing_scales_on_input: bool, - tile_tokens_dim: int = 8, routing_method_type: int = 0, enable_pdl: Optional[bool] = None, + tune_max_num_tokens: int = 8192, ) -> torch.Tensor: if enable_pdl is None: enable_pdl = device_support_pdl(hidden_states.device) + # Use AutoTuner to select the best 
tactic + tuner = AutoTuner.get() + MoERunner.refine_tuning_config(tune_max_num_tokens) + + num_tokens = hidden_states.shape[0] + hidden_size = hidden_states.shape[-1] + + # Create workspace buffers output = torch.empty( - hidden_states.shape, dtype=torch.bfloat16, device=hidden_states.device + num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device + ) + topk_ids = torch.empty( + num_tokens, top_k, dtype=torch.int32, device=hidden_states.device + ) + expert_weights = torch.empty( + num_tokens, top_k, dtype=routing_logits.dtype, device=hidden_states.device + ) + + dtype_act = DtypeTrtllmGen.E4m3 # FP8 activation + dtype_weights = DtypeTrtllmGen.E4m3 # FP8 weights + + moe_runner = MoERunner( + top_k=top_k, + num_local_experts=local_num_experts, + dtype_act=dtype_act, + dtype_weights=dtype_weights, + use_deepseek_fp8=False, # per_tensor mode + hidden_size=hidden_size, + intermediate_size=intermediate_size, + weight_layout=WeightLayout.MajorK, + use_shuffled_weight=True, + ) + + inputs = [output, routing_logits, topk_ids, expert_weights, hidden_states] + + _, tactic = tuner.choose_one( + "flashinfer::trtllm_fp8_per_tensor_scale_moe", + [moe_runner], + MoERunner.tuning_config_no_hidden_states_scales, # FP8 per-tensor doesn't use hidden_states_scale + inputs, + routing_bias=routing_bias, + gemm1_weights=gemm1_weights, + output1_scales_scalar=output1_scales_scalar, + output1_scales_gate_scalar=output1_scales_gate_scalar, + gemm2_weights=gemm2_weights, + output2_scales_scalar=output2_scales_scalar, + num_experts=num_experts, + n_group=n_group, + topk_group=topk_group, + local_expert_offset=local_expert_offset, + local_num_experts=local_num_experts, + routed_scaling_factor=routed_scaling_factor, + use_routing_scales_on_input=use_routing_scales_on_input, + routing_method_type=routing_method_type, + enable_pdl=enable_pdl, ) # Call the C++ function moe_op.trtllm_fp8_per_tensor_scale_moe( @@ -1140,9 +1248,9 @@ def trtllm_fp8_per_tensor_scale_moe_op( local_num_experts, routed_scaling_factor, use_routing_scales_on_input, - tile_tokens_dim, routing_method_type, enable_pdl, + [-1, -1] if tactic == -1 else tactic, ) return output @@ -1165,7 +1273,6 @@ def _fake_trtllm_fp8_per_tensor_scale_moe( local_num_experts: int, routed_scaling_factor: Optional[float], use_routing_scales_on_input: bool, - tile_tokens_dim: int = 8, routing_method_type: int = 0, enable_pdl: Optional[bool] = None, ): @@ -1196,15 +1303,78 @@ def trtllm_fp8_block_scale_moe_op( local_expert_offset: int, local_num_experts: int, routed_scaling_factor: Optional[float], - tile_tokens_dim: int, routing_method_type: int, use_shuffled_weight: bool = False, weight_layout: int = 0, enable_pdl: Optional[bool] = None, + tune_max_num_tokens: int = 8192, ) -> torch.Tensor: if enable_pdl is None: enable_pdl = device_support_pdl(hidden_states.device) + # Use AutoTuner to select the best tactic - follow FP4 pattern exactly + tuner = AutoTuner.get() + MoERunner.refine_tuning_config(tune_max_num_tokens) + + num_tokens = hidden_states.shape[0] + hidden_size = hidden_states.shape[-1] + + # Create workspace buffers + output = torch.empty( + num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device + ) + topk_ids = torch.empty( + num_tokens, top_k, dtype=torch.int32, device=hidden_states.device + ) + expert_weights = torch.empty( + num_tokens, top_k, dtype=routing_logits.dtype, device=hidden_states.device + ) + + dtype_act = DtypeTrtllmGen.E4m3 # FP8 activation + dtype_weights = DtypeTrtllmGen.E4m3 # FP8 weights + + moe_runner = 
MoERunner( + top_k=top_k, + num_local_experts=local_num_experts, + dtype_act=dtype_act, + dtype_weights=dtype_weights, + use_deepseek_fp8=True, # block_scale mode + hidden_size=hidden_size, + intermediate_size=intermediate_size, + weight_layout=weight_layout, + use_shuffled_weight=use_shuffled_weight, + ) + + inputs = [ + output, + routing_logits, + topk_ids, + expert_weights, + hidden_states, + hidden_states_scale, + ] + + _, tactic = tuner.choose_one( + "flashinfer::trtllm_fp8_block_scale_moe", + [moe_runner], + MoERunner.tuning_config_with_hidden_states_scales, # FP8 block-scale uses hidden_states_scale + inputs, + routing_bias=routing_bias, + gemm1_weights=gemm1_weights, + gemm1_weights_scale=gemm1_weights_scale, + gemm2_weights=gemm2_weights, + gemm2_weights_scale=gemm2_weights_scale, + num_experts=num_experts, + n_group=n_group, + topk_group=topk_group, + local_expert_offset=local_expert_offset, + local_num_experts=local_num_experts, + routed_scaling_factor=routed_scaling_factor, + routing_method_type=routing_method_type, + use_shuffled_weight=use_shuffled_weight, + weight_layout=weight_layout, + enable_pdl=enable_pdl, + ) # Call the C++ function for block scale MoE moe_op.trtllm_fp8_block_scale_moe( routing_logits, @@ -1224,11 +1394,11 @@ def trtllm_fp8_block_scale_moe_op( local_expert_offset, local_num_experts, routed_scaling_factor, - tile_tokens_dim, routing_method_type, use_shuffled_weight, weight_layout, enable_pdl, + [-1, -1] if tactic == -1 else tactic, ) return output @@ -1252,7 +1422,6 @@ def _fake_trtllm_fp8_block_scale_moe( local_expert_offset: int, local_num_experts: int, routed_scaling_factor: Optional[float], - tile_tokens_dim: int = 8, routing_method_type: int = 0, use_shuffled_weight: bool = False, weight_layout: int = 0, @@ -1294,7 +1463,6 @@ def trtllm_fp4_block_scale_moe_op( local_expert_offset: int, num_local_experts: int, routed_scaling_factor: Optional[float], - tile_tokens_dim: Optional[int], routing_method_type: int, do_finalize: bool, enable_pdl: Optional[bool] = None, @@ -1340,13 +1508,6 @@ def trtllm_fp4_block_scale_moe_op( dtype_weights = deduce_trtllm_gen_tensor_dtype( gemm1_weights, gemm1_weights_scale ) - if tile_tokens_dim is None: - tile_tokens_dim = calculate_tile_tokens_dim( - num_tokens, - num_experts, - top_k, - max_tile_tokens_dim=64 if dtype_act == DtypeTrtllmGen.Bfloat16 else 128, - ) moe_runner = MoERunner( top_k=top_k, num_local_experts=num_local_experts, @@ -1356,9 +1517,8 @@ def trtllm_fp4_block_scale_moe_op( hidden_size=hidden_size, intermediate_size=intermediate_size, gated_act_type=gated_act_type, - # NOTE(siyuan): do not fix the tile_tokens_dim to let tunnable runner decide the tile_tokens_dim itself. - # however, when the user chooses a different heuristic for tile_tokens_dim, the autotuner will fail to find the correct cached tactics. 
- # tile_tokens_dim=tile_tokens_dim, + weight_layout=WeightLayout.MajorK, + use_shuffled_weight=True, ) tunning_config = ( MoERunner.tuning_config_no_hidden_states_scales @@ -1434,13 +1594,12 @@ def trtllm_fp4_block_scale_moe_op( local_expert_offset, num_local_experts, routed_scaling_factor, - tile_tokens_dim, routing_method_type, do_finalize, enable_pdl, gated_act_type, output, - tactic, + [-1, -1] if tactic == -1 else tactic, ) if do_finalize: return [output] @@ -1480,7 +1639,6 @@ def _fake_trtllm_fp4_block_scale_moe( local_expert_offset: int, local_num_experts: int, routed_scaling_factor: Optional[float], - tile_tokens_dim: Optional[int], routing_method_type: int, do_finalize: bool, enable_pdl: bool, @@ -1549,6 +1707,12 @@ def trtllm_fp8_per_tensor_scale_moe( Returns: torch.Tensor: Output tensor of shape [seq_len, hidden_size] """ + if tile_tokens_dim is not None: + logger.warning_once( + "tile_tokens_dim in trtllm_fp8_per_tensor_scale_moe is planned for deprecation " + "in a future release. Please remove it from your code as tile_tokens_dim will no " + "longer be supported after v0.5.0." + ) return get_trtllm_moe_sm100_module().trtllm_fp8_per_tensor_scale_moe( routing_logits, routing_bias, @@ -1567,7 +1731,6 @@ def trtllm_fp8_per_tensor_scale_moe( local_num_experts, routed_scaling_factor, use_routing_scales_on_input, - tile_tokens_dim, routing_method_type, enable_pdl, ) @@ -1590,7 +1753,7 @@ def trtllm_fp8_block_scale_moe( local_expert_offset: int, local_num_experts: int, routed_scaling_factor: Optional[float], - tile_tokens_dim: int = 8, + tile_tokens_dim: Optional[int] = None, routing_method_type: int = 0, use_shuffled_weight: bool = False, weight_layout: int = 0, @@ -1621,6 +1784,12 @@ def trtllm_fp8_block_scale_moe( Returns: torch.Tensor: Output tensor of shape [seq_len, hidden_size] """ + if tile_tokens_dim is not None: + logger.warning_once( + "tile_tokens_dim in trtllm_fp8_block_scale_moe is planned for deprecation " + "in a future release. Please remove it from your code as tile_tokens_dim will no " + "longer be supported after v0.5.0." + ) output = torch.empty( hidden_states.shape, dtype=torch.bfloat16, device=hidden_states.device ) @@ -1642,7 +1811,6 @@ def trtllm_fp8_block_scale_moe( local_expert_offset, local_num_experts, routed_scaling_factor, - tile_tokens_dim, routing_method_type, use_shuffled_weight, weight_layout, @@ -1675,7 +1843,7 @@ def trtllm_fp4_block_scale_moe( local_expert_offset: int, local_num_experts: int, routed_scaling_factor: Optional[float], - tile_tokens_dim: Optional[int] = None, + tile_tokens_dim: Optional[int], routing_method_type: int = 0, do_finalize: bool = True, enable_pdl: Optional[bool] = None, @@ -1726,7 +1894,7 @@ def trtllm_fp4_block_scale_moe( local_expert_offset (int): Offset of local experts in global expert space local_num_experts (int): Number of experts handled by this device routed_scaling_factor (Optional[float]): Scaling factor for routing (can be None for some routing methods) - tile_tokens_dim (int): Tile dimension for tokens (default: 8) + tile_tokens_dim (Optional[int]): Tile dimension for tokens (default: None, will be deprecated in the future) routing_method_type (int): Type of routing method to use (default: 0) - 0: Default (Softmax -> TopK) - 1: Renormalize (TopK -> Softmax) @@ -1745,6 +1913,12 @@ def trtllm_fp4_block_scale_moe( List[torch.Tensor]: List of output tensors. If do_finalize=True, returns the final MoE output. 
Otherwise, returns intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing. """ + if tile_tokens_dim is not None: + logger.warning_once( + "tile_tokens_dim in trtllm_fp4_block_scale_moe is planned for deprecation " + "in a future release. Please remove it from your code as tile_tokens_dim will no " + "longer be supported after v0.5.0." + ) return get_trtllm_moe_sm100_module().trtllm_fp4_block_scale_moe( routing_logits, None, @@ -1772,7 +1946,6 @@ def trtllm_fp4_block_scale_moe( local_expert_offset, local_num_experts, routed_scaling_factor, - tile_tokens_dim, routing_method_type, do_finalize, enable_pdl, @@ -1807,7 +1980,7 @@ def trtllm_fp4_block_scale_routed_moe( local_expert_offset: int, local_num_experts: int, routed_scaling_factor: Optional[float], - tile_tokens_dim: Optional[int] = None, + tile_tokens_dim: Optional[int], routing_method_type: int = 0, do_finalize: bool = True, enable_pdl: Optional[bool] = None, @@ -1860,7 +2033,7 @@ def trtllm_fp4_block_scale_routed_moe( local_expert_offset (int): Offset of local experts in global expert space local_num_experts (int): Number of experts handled by this device routed_scaling_factor (Optional[float]): Scaling factor for routing (can be None for some routing methods) - tile_tokens_dim (int): Tile dimension for tokens (default: 8) + tile_tokens_dim (Optional[int]): Tile dimension for tokens (default: None, will be deprecated in the future) routing_method_type (int): Type of routing method to use (default: 0) - 0: Default (Softmax -> TopK) - 1: Renormalize (TopK -> Softmax) @@ -1879,6 +2052,12 @@ def trtllm_fp4_block_scale_routed_moe( List[torch.Tensor]: List of output tensors. If do_finalize=True, returns the final MoE output. Otherwise, returns intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing. """ + if tile_tokens_dim is not None: + logger.warning_once( + "tile_tokens_dim in trtllm_fp4_block_scale_routed_moe is planned for deprecation " + "in a future release. Please remove it from your code as tile_tokens_dim will no " + "longer be supported after v0.5.0." + ) return get_trtllm_moe_sm100_module().trtllm_fp4_block_scale_moe( None, topk_ids, @@ -1906,7 +2085,6 @@ def trtllm_fp4_block_scale_routed_moe( local_expert_offset, local_num_experts, routed_scaling_factor, - tile_tokens_dim, routing_method_type, do_finalize, enable_pdl, diff --git a/flashinfer/jit/core.py b/flashinfer/jit/core.py index e7dec73723..27034a4054 100644 --- a/flashinfer/jit/core.py +++ b/flashinfer/jit/core.py @@ -1,10 +1,11 @@ import dataclasses +import functools import logging import os from contextlib import nullcontext from datetime import datetime from pathlib import Path -from typing import Dict, List, Optional, Sequence, Union +from typing import Dict, List, Optional, Sequence, Union, Hashable import tvm_ffi from filelock import FileLock @@ -60,6 +61,33 @@ def __init__(self, name): ) ) + def debug_once(self, msg: str, *args: Hashable) -> None: + """ + As [`debug`][logging.Logger.debug], but subsequent calls with + the same message are silently dropped. + """ + self._print_once(self.debug, msg, *args) + + def info_once(self, msg: str, *args: Hashable) -> None: + """ + As [`info`][logging.Logger.info], but subsequent calls with + the same message are silently dropped. 
+ """ + self._print_once(self.info, msg, *args) + + def warning_once(self, msg: str, *args: Hashable) -> None: + """ + As [`warning`][logging.Logger.warning], but subsequent calls with + the same message are silently dropped. + """ + self._print_once(self.warning, msg, *args) + + @functools.lru_cache(maxsize=None) + def _print_once(self, log_method, msg: str, *args: Hashable) -> None: + """Helper method to log messages only once per unique (msg, args) combination.""" + # Note: stacklevel=3 to show the caller's location, not this helper method + log_method(msg, *args, stacklevel=3) + logger = FlashInferJITLogger("flashinfer.jit") diff --git a/tests/conftest.py b/tests/conftest.py index dc81dc0db2..768eec8fa3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -137,11 +137,11 @@ def is_cuda_oom_error_str(e: str) -> bool: return "CUDA" in e and "out of memory" in e -@pytest.hookimpl(tryfirst=True) +@pytest.hookimpl(wrapper=True) def pytest_runtest_call(item): # skip OOM error and missing JIT cache errors try: - item.runtest() + yield except (torch.cuda.OutOfMemoryError, RuntimeError) as e: if isinstance(e, torch.cuda.OutOfMemoryError) or is_cuda_oom_error_str(str(e)): pytest.skip("Skipping due to OOM") diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index a093d4c0aa..df19e00310 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -14,10 +14,10 @@ limitations under the License. """ +import pytest from abc import ABC, abstractmethod from enum import IntEnum from typing import Dict -import pytest import torch from cuda.bindings import runtime from torch.nn import functional as F @@ -45,7 +45,7 @@ get_w2_permute_indices_with_cache, _maybe_get_cached_w3_w1_permute_indices, ) -from flashinfer.utils import calculate_tile_tokens_dim, get_compute_capability +from flashinfer.utils import get_compute_capability def check_cuda(err): @@ -202,7 +202,7 @@ def _run_moe_computation(self, runtime_args): local_expert_offset=0, local_num_experts=self.config["num_experts"], routed_scaling_factor=self.config["routed_scaling"], - tile_tokens_dim=self.config["tile_tokens_dim"], + tile_tokens_dim=None, routing_method_type=self.config["routing_method_type"], gated_act_type=self.config["gated_act_type"], do_finalize=True, @@ -549,7 +549,6 @@ def call_moe( routed_scaling = kwargs["routed_scaling"] gated_act_type = kwargs["gated_act_type"] routing_method_type = kwargs["routing_method_type"] - tile_tokens_dim = kwargs["tile_tokens_dim"] # Create CUDA graph configuration config = { @@ -560,7 +559,6 @@ def call_moe( "top_k_groups": top_k_groups, "intermediate_size": intermediate_size, "routed_scaling": routed_scaling, - "tile_tokens_dim": tile_tokens_dim, "gated_act_type": gated_act_type, "routing_method_type": routing_method_type, } @@ -727,8 +725,8 @@ def prepare_static_weights_for_kernel( tmp_weights2 = convert_to_block_layout(tmp_weights2, block_k) gemm1_weights_fp8_shuffled.append(tmp_weights1) - gemm2_weights_fp8_shuffled.append(tmp_weights2) + kernel_gemm1_weights = torch.stack(gemm1_weights_fp8_shuffled).view( torch.float8_e4m3fn ) @@ -761,7 +759,6 @@ def call_moe( intermediate_size = kwargs["intermediate_size"] routed_scaling = kwargs["routed_scaling"] routing_method_type = kwargs["routing_method_type"] - tile_tokens_dim = kwargs["tile_tokens_dim"] enable_pdl = kwargs.get("enable_pdl") hidden_states_scale = kwargs["hidden_states_scale"] hidden_states_quant = kwargs["hidden_states_quant"] @@ -772,29 +769,31 @@ def 
call_moe( "NaN detected in hidden_states_fp8" ) - output = trtllm_fp8_block_scale_moe( - expert_logits, - routing_bias, - hidden_states_fp8, - hidden_states_scale, - static_data["gemm1_weights"], - static_data["gemm1_scales"], - static_data["gemm2_weights"], - static_data["gemm2_scales"], - num_experts, - top_k, - n_groups, - top_k_groups, - intermediate_size, - 0, - num_experts, - routed_scaling, - tile_tokens_dim, - routing_method_type, - use_shuffled_weight=static_data["use_shuffled_weight"], - weight_layout=static_data["weight_layout"], - enable_pdl=enable_pdl, - ) + # Use autotuner for optimal kernel selection + with autotune(True): + output = trtllm_fp8_block_scale_moe( + expert_logits, + routing_bias, + hidden_states_fp8, + hidden_states_scale, + static_data["gemm1_weights"], + static_data["gemm1_scales"], + static_data["gemm2_weights"], + static_data["gemm2_scales"], + num_experts, + top_k, + n_groups, + top_k_groups, + intermediate_size, + 0, + num_experts, + routed_scaling, + None, + routing_method_type, + use_shuffled_weight=static_data["use_shuffled_weight"], + weight_layout=static_data["weight_layout"], + enable_pdl=enable_pdl, + ) return output.to(torch.float) @@ -937,39 +936,40 @@ def call_moe( intermediate_size = kwargs["intermediate_size"] routed_scaling = kwargs["routed_scaling"] routing_method_type = kwargs["routing_method_type"] - tile_tokens_dim = kwargs["tile_tokens_dim"] # Quantize to FP8 per-tensor using pre-computed global scale factor hidden_states_fp8, _ = quant_fp8_per_tensor( hidden_states_orig, hidden_states_scale_global ) - output = trtllm_fp8_per_tensor_scale_moe( - ( - expert_logits.to(torch.bfloat16) - if routing_method_type == RoutingMethodType.Llama4 - else expert_logits - ), - routing_bias, - hidden_states_fp8, - static_data["gemm1_weights"], - static_data["scale_c_fc1"], - static_data["scale_gate_fc1"], - static_data["gemm2_weights"], - static_data["scale_c_fc2"], - num_experts, - top_k, - n_groups, - top_k_groups, - intermediate_size, - 0, - num_experts, - routed_scaling, - routing_method_type - == RoutingMethodType.Llama4, # Use_routing_scales_on_input - tile_tokens_dim, - routing_method_type, - ) + # Use autotuner for optimal kernel selection + with autotune(True): + output = trtllm_fp8_per_tensor_scale_moe( + ( + expert_logits.to(torch.bfloat16) + if routing_method_type == RoutingMethodType.Llama4 + else expert_logits + ), + routing_bias, + hidden_states_fp8, + static_data["gemm1_weights"], + static_data["scale_c_fc1"], + static_data["scale_gate_fc1"], + static_data["gemm2_weights"], + static_data["scale_c_fc2"], + num_experts, + top_k, + n_groups, + top_k_groups, + intermediate_size, + 0, + num_experts, + routed_scaling, + routing_method_type + == RoutingMethodType.Llama4, # Use_routing_scales_on_input + None, + routing_method_type, + ) return output.to(torch.float) @@ -985,8 +985,6 @@ def get_tolerances(self): # ==================================================================================== # Quantizer Factory # ==================================================================================== - - def get_moe_impl(quant_mode: QuantMode): """Factory function to get the appropriate MoE implementation.""" if quant_mode == QuantMode.FP8_BLOCK_SCALE: @@ -1815,7 +1813,6 @@ def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs): "intermediate_size": args.intermediate_size, "routed_scaling": kwargs["routed_scaling"], "routing_method_type": kwargs["routing_method_type"], - "tile_tokens_dim": kwargs["tile_tokens_dim"], 
"do_finalize": True, "gated_act_type": args.gated_act_type, "hidden_states_scale": args.hidden_states_scale, @@ -1837,203 +1834,16 @@ def cache_permute_indices(): return _cache_permute_indices -@pytest.mark.parametrize("num_tokens", [1, 8, 1024]) -@pytest.mark.parametrize("hidden_size", [1024, 8192]) -@pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 512, 384]) -@pytest.mark.parametrize( - "moe_impl", - [ - pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"), - pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"), - pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"), - pytest.param(FP8BlockScaleMoe(), id="FP8_Block"), - pytest.param(FP8PerTensorMoe(), id="FP8_Tensor"), - ], -) -@pytest.mark.parametrize( - "routing_config", - [ - pytest.param( - { - "num_experts": 384, - "top_k": 8, - "padding": 8, - "n_groups": 1, - "top_k_groups": 1, - "routed_scaling": 2.5, - "has_routing_bias": True, - "routing_method_type": RoutingMethodType.DeepSeekV3, - "compatible_moe_impls": [ - FP4Moe, - FP8BlockScaleMoe, - ], - }, - id="kimi_k2", - ), - pytest.param( - { - "num_experts": 256, - "top_k": 8, - "padding": 8, - "n_groups": 8, - "top_k_groups": 4, - "routed_scaling": 2.5, - "has_routing_bias": True, - "routing_method_type": RoutingMethodType.DeepSeekV3, - "compatible_moe_impls": [ - FP4Moe, - FP8BlockScaleMoe, - ], - }, - id="DSv3", - ), - pytest.param( - { - "num_experts": 72, - "top_k": 6, - "padding": 8, - "n_groups": 1, - "top_k_groups": 1, - "routed_scaling": 2.5, - "has_routing_bias": True, - "routing_method_type": RoutingMethodType.DeepSeekV3, - "compatible_moe_impls": [ - FP4Moe, - FP8BlockScaleMoe, - ], - }, - id="DSLite", - ), - pytest.param( - { - "num_experts": 256, - "top_k": 8, - "padding": 8, - "n_groups": None, - "top_k_groups": None, - "routed_scaling": None, - "has_routing_bias": False, - "routing_method_type": RoutingMethodType.Renormalize, - "compatible_moe_impls": [FP8BlockScaleMoe, FP8PerTensorMoe, FP4Moe], - }, - id="Renorm", - marks=pytest.mark.skip( - reason="Disabled for testing speed - similar to RenormalizeNaive" - ), - ), - pytest.param( - { - "num_experts": 128, - "top_k": 10, - "padding": 8, - "n_groups": None, - "top_k_groups": None, - "routed_scaling": None, - "has_routing_bias": False, - "routing_method_type": RoutingMethodType.Renormalize, - "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe], - }, - id="Qwen3_next", - ), - pytest.param( - { - "num_experts": 128, - "top_k": 8, - "padding": 8, - "n_groups": None, - "top_k_groups": None, - "routed_scaling": None, - "has_routing_bias": False, - "routing_method_type": RoutingMethodType.RenormalizeNaive, - "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], - }, - id="RenormNaive", - ), - pytest.param( - { - "num_experts": 16, - "top_k": 2, - "padding": 8, - "n_groups": None, - "top_k_groups": None, - "routed_scaling": None, - "has_routing_bias": False, - "routing_method_type": RoutingMethodType.TopK, - "compatible_moe_impls": [FP4Moe], - }, - id="TopK", - ), - pytest.param( - { - "num_experts": 128, - "top_k": 1, - "padding": 8, - "n_groups": 0, - "top_k_groups": 0, - "routed_scaling": 2.5, - "has_routing_bias": True, - "routing_method_type": RoutingMethodType.Llama4, - "compatible_moe_impls": [FP8PerTensorMoe], - }, - id="Llama4", - ), - ], -) -@pytest.mark.parametrize( - "weight_processing", - [ - pytest.param( - { - "use_shuffled_weight": False, - "layout": WeightLayout.MajorK, - "compatible_moe_impls": [FP8BlockScaleMoe], - }, - 
id="NoShuffle_MajorK", - ), - pytest.param( - { - "use_shuffled_weight": True, - "layout": WeightLayout.MajorK, - "compatible_moe_impls": [FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe], - }, - id="Shuffled_MajorK", - ), - pytest.param( - { - "use_shuffled_weight": True, - "layout": WeightLayout.BlockMajorK, - "compatible_moe_impls": [FP8BlockScaleMoe], - }, - id="Shuffled_BlockMajorK", - ), - ], -) -@pytest.mark.parametrize( - "gated_act_type", - [ - pytest.param(GatedActType.SwiGlu, id="SwiGlu"), - pytest.param(GatedActType.GeGlu, id="GeGlu"), - ], -) -def test_moe_quantization_classes( - num_tokens, - hidden_size, - intermediate_size, +def skip_checks( moe_impl, routing_config, weight_processing, gated_act_type, - cache_permute_indices, + num_tokens, + hidden_size, + intermediate_size, ): - """ - Test MoE implementations using separated quantization workflow. - - This test demonstrates the clean separation between: - - Static weight quantization (done offline) - - Dynamic input quantization (done at runtime) - - Each quantization class clearly shows which precision is being used. - """ + """Common skip logic for all tests.""" compute_capability = get_compute_capability(torch.device(device="cuda")) if compute_capability[0] not in [10]: pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.") @@ -2044,14 +1854,12 @@ def test_moe_quantization_classes( or routing_config["routing_method_type"] != RoutingMethodType.TopK or num_tokens > 128 ): - # GeGlu is only supported for FP4Moe FP4_NVFP4_NVFP4 and TopK routing pytest.skip( f"Incompatible: {moe_impl.name} + {gated_act_type} + {routing_config['routing_method_type']} + {num_tokens}" ) elif gated_act_type == GatedActType.SwiGlu and ( hidden_size > 1024 or intermediate_size > 1024 ): - # Skip some tests for SwiGlu for testing speed pytest.skip( f"Skip for testing speed: {gated_act_type} + {hidden_size} + {intermediate_size}" ) @@ -2070,6 +1878,10 @@ def test_moe_quantization_classes( pytest.skip( f"Incompatible: {moe_impl.name} + {weight_processing['use_shuffled_weight']} + {weight_processing['layout']}" ) + if intermediate_size not in routing_config["compatible_intermediate_size"]: + pytest.skip( + f"Incompatible: intermediate_size={intermediate_size} with {routing_config['routing_method_type'].name} routing ({routing_config['num_experts']} experts)" + ) # TODO(jimmzhou): enable MxFP4xBf16 on SM103 if ( @@ -2082,6 +1894,30 @@ def test_moe_quantization_classes( "Note(jimmzhou): Make MxFP4xBf16 nonfunctional on SM103 to avoid B200 regression" ) + +def run_moe_test( + num_tokens, + hidden_size, + intermediate_size, + moe_impl, + routing_config, + weight_processing, + gated_act_type, + cache_permute_indices, +): + """Common test logic for all routing methods.""" + skip_checks( + moe_impl, + routing_config, + weight_processing, + gated_act_type, + num_tokens, + hidden_size, + intermediate_size, + ) + + torch.cuda.synchronize() + moe_impl._cache_permute_indices = cache_permute_indices seed = 0 @@ -2096,17 +1932,6 @@ def test_moe_quantization_classes( num_experts = routing_config["num_experts"] routing_method_type = routing_config["routing_method_type"] - tile_tokens_dim = calculate_tile_tokens_dim( - num_tokens, - num_experts, - top_k, - max_tile_tokens_dim=128 - if ( - type(moe_impl) is FP4Moe and moe_impl.quant_mode != QuantMode.FP4_MXFP4_Bf16 - ) - else 64, - ) - # Validation checks assert top_k <= num_experts assert top_k <= 10 @@ -2117,15 +1942,12 @@ def test_moe_quantization_classes( assert num_experts % 4 == 0 assert top_k 
< (top_k_groups * num_experts / n_groups) - # Create test data based on routing method and quantization mode - # Different kernels have different dtype requirements for routing logits + # Create test data based on routing method if routing_method_type == RoutingMethodType.DeepSeekV3: - # DeepSeekV3 uses float for routing logits expert_logits = torch.randn((num_tokens, num_experts), device="cuda").to( torch.float ) else: - # Other routing methods (Renormalize, RenormalizeNaive, Llama4) use bfloat16 expert_logits = torch.randn((num_tokens, num_experts), device="cuda").to( torch.bfloat16 ) @@ -2191,12 +2013,12 @@ def test_moe_quantization_classes( f"Routing method {routing_method_type} not implemented" ) - # 1. Quantize weights offline (static, done once) + compute global scale factors + # 1. Quantize weights offline weights_data = moe_impl.quantize_weights( gemm1_weights, gemm2_weights, hidden_states ) - # 2. Quantize inputs at runtime (dynamic, done per inference) using pre-computed scales + # 2. Quantize inputs at runtime inputs_data = moe_impl.quantize_inputs( hidden_states, weights_data["hidden_states_scale_global"] ) @@ -2227,14 +2049,13 @@ def test_moe_quantization_classes( gated_act_type, ) - # Compute reference output using the moe_impl + # Compute reference output output_dequant_reference, args_dequant = moe_impl.compute_reference(args) - # Validate that reference computation succeeded if output_dequant_reference is None: pytest.fail("Reference computation failed to produce output") - # Compute actual output using the moe_impl + # Compute actual output output_dequant_actual = moe_impl.compute_production( args_dequant, args, @@ -2247,15 +2068,12 @@ def test_moe_quantization_classes( top_k_groups=top_k_groups, routed_scaling=routed_scaling, routing_method_type=routing_method_type, - tile_tokens_dim=tile_tokens_dim, weight_processing=weight_processing, enable_pdl=True, - hidden_states_quant=inputs_data[ - "hidden_states" - ], # NOTE(yingyi): only for fp8 block scale for now, refactor later + hidden_states_quant=inputs_data["hidden_states"], ) - # Compare outputs using moe_impl-specific tolerances + # Compare outputs tolerances = moe_impl.get_tolerances() check_accuracy( output_dequant_reference, @@ -2264,3 +2082,363 @@ def test_moe_quantization_classes( rtol=tolerances["rtol"], percent=tolerances["percent"], ) + + +# Test: DeepSeekV3 routing +@pytest.mark.parametrize("num_tokens", [1, 8, 1024]) +@pytest.mark.parametrize("hidden_size", [1024]) +@pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 512, 384]) +@pytest.mark.parametrize( + "moe_impl", + [ + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"), + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"), + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"), + pytest.param(FP8BlockScaleMoe(), id="FP8_Block"), + ], +) +@pytest.mark.parametrize( + "routing_config", + [ + pytest.param( + { + "num_experts": 384, + "top_k": 8, + "padding": 8, + "n_groups": 1, + "top_k_groups": 1, + "routed_scaling": 2.5, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.DeepSeekV3, + "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], + "compatible_intermediate_size": [512, 1024, 2048], + }, + id="kimi_k2", + ), + pytest.param( + { + "num_experts": 256, + "top_k": 8, + "padding": 8, + "n_groups": 8, + "top_k_groups": 4, + "routed_scaling": 2.5, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.DeepSeekV3, + 
"compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], + "compatible_intermediate_size": [512, 1024, 2048], + }, + id="DSv3", + ), + pytest.param( + { + "num_experts": 72, + "top_k": 6, + "padding": 8, + "n_groups": 1, + "top_k_groups": 1, + "routed_scaling": 2.5, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.DeepSeekV3, + "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], + "compatible_intermediate_size": [384, 768], + }, + id="DSLite", + ), + ], +) +@pytest.mark.parametrize( + "weight_processing", + [ + pytest.param( + { + "use_shuffled_weight": False, + "layout": WeightLayout.MajorK, + "compatible_moe_impls": [FP8BlockScaleMoe], + }, + id="NoShuffle_MajorK", + ), + pytest.param( + { + "use_shuffled_weight": True, + "layout": WeightLayout.MajorK, + "compatible_moe_impls": [FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe], + }, + id="Shuffled_MajorK", + ), + pytest.param( + { + "use_shuffled_weight": True, + "layout": WeightLayout.BlockMajorK, + "compatible_moe_impls": [FP8BlockScaleMoe], + }, + id="Shuffled_BlockMajorK", + ), + ], +) +@pytest.mark.parametrize( + "gated_act_type", + [ + pytest.param(GatedActType.SwiGlu, id="SwiGlu"), + pytest.param(GatedActType.GeGlu, id="GeGlu"), + ], +) +def test_deepseekv3_routing( + num_tokens, + hidden_size, + intermediate_size, + moe_impl, + routing_config, + weight_processing, + gated_act_type, + cache_permute_indices, +): + """Test DeepSeekV3 routing configurations.""" + run_moe_test( + num_tokens, + hidden_size, + intermediate_size, + moe_impl, + routing_config, + weight_processing, + gated_act_type, + cache_permute_indices, + ) + + +# Test: Renormalize routing +@pytest.mark.parametrize("num_tokens", [1, 8, 1024]) +@pytest.mark.parametrize("hidden_size", [1024]) +@pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 512, 384]) +@pytest.mark.parametrize( + "moe_impl", + [ + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"), + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"), + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"), + pytest.param(FP8BlockScaleMoe(), id="FP8_Block"), + ], +) +@pytest.mark.parametrize( + "routing_config", + [ + pytest.param( + { + "num_experts": 256, + "top_k": 8, + "padding": 8, + "n_groups": None, + "top_k_groups": None, + "routed_scaling": None, + "has_routing_bias": False, + "routing_method_type": RoutingMethodType.Renormalize, + "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe], + "compatible_intermediate_size": [384, 768, 1024, 2048], + }, + id="Renorm", + marks=pytest.mark.skip(reason="Skip temporary"), + ), + pytest.param( + { + "num_experts": 512, + "top_k": 10, + "padding": 8, + "n_groups": None, + "top_k_groups": None, + "routed_scaling": None, + "has_routing_bias": False, + "routing_method_type": RoutingMethodType.Renormalize, + "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe], + "compatible_intermediate_size": [512], + }, + id="Qwen3_next", + ), + ], +) +@pytest.mark.parametrize( + "weight_processing", + [ + pytest.param( + { + "use_shuffled_weight": True, + "layout": WeightLayout.MajorK, + "compatible_moe_impls": [FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe], + }, + id="Shuffled_MajorK", + ), + ], +) +@pytest.mark.parametrize( + "gated_act_type", + [ + pytest.param(GatedActType.SwiGlu, id="SwiGlu"), + pytest.param(GatedActType.GeGlu, id="GeGlu"), + ], +) +def test_renormalize_routing( + num_tokens, + hidden_size, + intermediate_size, + moe_impl, + routing_config, + weight_processing, + gated_act_type, 
+ cache_permute_indices, +): + """Test Renormalize routing configurations.""" + run_moe_test( + num_tokens, + hidden_size, + intermediate_size, + moe_impl, + routing_config, + weight_processing, + gated_act_type, + cache_permute_indices, + ) + + +# Test: TopK routing +@pytest.mark.parametrize("num_tokens", [1, 8, 128]) # Limited for GeGlu +@pytest.mark.parametrize("hidden_size", [1024]) +@pytest.mark.parametrize("intermediate_size", [384, 512, 768, 1024]) +@pytest.mark.parametrize( + "moe_impl", + [ + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"), + pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"), + ], +) +@pytest.mark.parametrize( + "routing_config", + [ + pytest.param( + { + "num_experts": 16, + "top_k": 2, + "padding": 8, + "n_groups": None, + "top_k_groups": None, + "routed_scaling": None, + "has_routing_bias": False, + "routing_method_type": RoutingMethodType.TopK, + "compatible_moe_impls": [FP4Moe], + "compatible_intermediate_size": [384, 512, 768, 1024], + }, + id="TopK", + ), + ], +) +@pytest.mark.parametrize( + "weight_processing", + [ + pytest.param( + { + "use_shuffled_weight": True, + "layout": WeightLayout.MajorK, + "compatible_moe_impls": [FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe], + }, + id="Shuffled_MajorK", + ), + ], +) +@pytest.mark.parametrize( + "gated_act_type", + [ + pytest.param(GatedActType.SwiGlu, id="SwiGlu"), + pytest.param(GatedActType.GeGlu, id="GeGlu"), + ], +) +def test_topk_routing( + num_tokens, + hidden_size, + intermediate_size, + moe_impl, + routing_config, + weight_processing, + gated_act_type, + cache_permute_indices, +): + """Test TopK routing configuration.""" + run_moe_test( + num_tokens, + hidden_size, + intermediate_size, + moe_impl, + routing_config, + weight_processing, + gated_act_type, + cache_permute_indices, + ) + + +# Test: Llama4 routing +@pytest.mark.parametrize("num_tokens", [1, 8, 1024]) +@pytest.mark.parametrize("hidden_size", [1024]) +@pytest.mark.parametrize("intermediate_size", [1024, 2048]) +@pytest.mark.parametrize( + "moe_impl", + [ + pytest.param(FP8PerTensorMoe(), id="FP8_Tensor"), + ], +) +@pytest.mark.parametrize( + "routing_config", + [ + pytest.param( + { + "num_experts": 128, + "top_k": 1, + "padding": 8, + "n_groups": 0, + "top_k_groups": 0, + "routed_scaling": 2.5, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.Llama4, + "compatible_moe_impls": [FP8PerTensorMoe], + "compatible_intermediate_size": [1024, 2048], + }, + id="Llama4", + ), + ], +) +@pytest.mark.parametrize( + "weight_processing", + [ + pytest.param( + { + "use_shuffled_weight": True, + "layout": WeightLayout.MajorK, + "compatible_moe_impls": [FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe], + }, + id="Shuffled_MajorK", + ), + ], +) +@pytest.mark.parametrize( + "gated_act_type", + [ + pytest.param(GatedActType.SwiGlu, id="SwiGlu"), + ], +) +def test_llama4_routing( + num_tokens, + hidden_size, + intermediate_size, + moe_impl, + routing_config, + weight_processing, + gated_act_type, + cache_permute_indices, +): + """Test Llama4 routing configuration with FP8 per-tensor.""" + run_moe_test( + num_tokens, + hidden_size, + intermediate_size, + moe_impl, + routing_config, + weight_processing, + gated_act_type, + cache_permute_indices, + )
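
For reference, below is a minimal sketch (not part of the patch) of how the shared run_moe_test helper introduced above can be driven directly for a single configuration, which can be handy when reproducing one failing case outside the pytest parametrize machinery. It reuses the TopK routing and Shuffled_MajorK weight-processing dicts exactly as parametrized above and assumes the test module's imports are available; the plain dict passed for cache_permute_indices is an assumption standing in for the pytest fixture of the same name, and the chosen sizes are just one combination that skip_checks accepts on a supported GPU.

# Hedged sketch: drive run_moe_test for one configuration, assuming the
# surrounding test module's imports (FP4Moe, QuantMode, RoutingMethodType,
# GatedActType, WeightLayout, FP8PerTensorMoe, FP8BlockScaleMoe, run_moe_test).
topk_routing_config = {
    "num_experts": 16,
    "top_k": 2,
    "padding": 8,
    "n_groups": None,
    "top_k_groups": None,
    "routed_scaling": None,
    "has_routing_bias": False,
    "routing_method_type": RoutingMethodType.TopK,
    "compatible_moe_impls": [FP4Moe],
    "compatible_intermediate_size": [384, 512, 768, 1024],
}

shuffled_majork_weights = {
    "use_shuffled_weight": True,
    "layout": WeightLayout.MajorK,
    "compatible_moe_impls": [FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe],
}

if __name__ == "__main__":
    # Run a single NvFP4xNvFP4 + TopK case. skip_checks() inside run_moe_test
    # still applies, so on unsupported hardware this raises pytest's Skipped
    # exception instead of running the kernels.
    run_moe_test(
        num_tokens=8,
        hidden_size=1024,
        intermediate_size=512,
        moe_impl=FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4),
        routing_config=topk_routing_config,
        weight_processing=shuffled_majork_weights,
        gated_act_type=GatedActType.SwiGlu,
        cache_permute_indices={},  # assumption: stand-in for the pytest fixture
    )

Because the refactored tests no longer pass tile_tokens_dim, this sketch exercises the same code path as the parametrized tests: the kernel launcher selects the tile size internally when the argument is omitted.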