diff --git a/benchmarks/README.md b/benchmarks/README.md index 15f561ca29..b66882d38b 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -29,6 +29,8 @@ Currently supports testing attention, gemm, fused MOE, normalization, and quanti - `trtllm_fp8_block_scale_moe` - MOE with FP8 quantized weights and block-wise scaling. - `trtllm_fp8_per_tensor_scale_moe` - MOE with FP8 quantized weights and per-tensor scaling. - `cutlass_fused_moe` - CUTLASS fused MoE (base/fp8/nvfp4 variants with optional TP/EP) +- MOE Communication: + - `moe_a2a_dispatch_combine` - MoE All-to-All dispatch + combine benchmark for multi-GPU expert-parallel inference. Requires `mpirun` for multi-GPU execution. Supports optional quantization (FP8, NVFP4, FP8 block-scale) and real MoE kernel computation. - Norm: - `rmsnorm` - Root Mean Square Layer Normalization. - `rmsnorm_quant` - RMSNorm with FP8 quantized output. @@ -238,6 +240,50 @@ Notes: - FP8 MOE kernels require integer values for group parameters, while FP4 MOE kernels accept optional values. - CUTLASS fused MoE (`cutlass_fused_moe`) ignores `--routing_method`, `--n_group`, and `--topk_group`; it computes routing via softmax+top-k internally from the provided logits. +### MoE Communication Flags (moe_a2a_dispatch_combine) +The `moe_a2a_dispatch_combine` routine benchmarks MoE All-to-All communication for multi-GPU expert-parallel inference. It must be launched with `mpirun`. 
+ +| Flag | Description | +|--------------------------|-------------------------------------------------------------------------------------------------------------| +| `--num_tokens` | Number of tokens per rank (local batch size) | +| `--hidden_size` | Hidden dimension size | +| `--num_experts` | Total number of experts across all ranks | +| `--top_k` | Number of experts to route each token to | +| `--input_dtype` | Data type for hidden states payload: `bfloat16` (default) or `float16` | +| `--quant_dtype` | Quantization format: `fp8` (per-tensor), `nvfp4` (block-scale FP4), `fp8_block_scale` (block-scale FP8) | +| `--real_math` | Run actual MoE kernels instead of fake computation. Requires `--intermediate_size` to be set and `--quant_dtype` to be `nvfp4` or `fp8_block_scale` | +| `--intermediate_size` | Intermediate FFN size. Required if `--real_math` is set | +| `--max_num_tokens` | Max tokens per rank for workspace allocation. Defaults to `--num_tokens` | +| `--validate` | Run correctness validation before benchmarking using deterministic fake MoE | +| `--per_phase_timing` | Enable per-phase timing (dispatch/combine/moe_kernel). 
Adds slight overhead from CUDA events | +| `--nvtx` | Enable NVTX markers for Nsight Systems profiling | + +**Launch Examples:** +```bash +# Basic (no quantization) +mpirun -np 8 python benchmarks/flashinfer_benchmark.py \ + --routine moe_a2a_dispatch_combine \ + --num_tokens 1024 --hidden_size 7168 --num_experts 256 --top_k 8 + +# With FP8 quantization +mpirun -np 8 python benchmarks/flashinfer_benchmark.py \ + --routine moe_a2a_dispatch_combine \ + --num_tokens 1024 --hidden_size 7168 --num_experts 256 --top_k 8 \ + --quant_dtype fp8 + +# With NVFP4 quantization and real MoE kernel +mpirun -np 8 python benchmarks/flashinfer_benchmark.py \ + --routine moe_a2a_dispatch_combine \ + --num_tokens 1024 --hidden_size 7168 --num_experts 256 --top_k 8 \ + --quant_dtype nvfp4 --real_math --intermediate_size 18432 + +# With validation and per-phase timing +mpirun -np 8 python benchmarks/flashinfer_benchmark.py \ + --routine moe_a2a_dispatch_combine \ + --num_tokens 1024 --hidden_size 7168 --num_experts 256 --top_k 8 \ + --validate --per_phase_timing +``` + ### Norm Flags | Flag | Description | |--------------------------|-------------------------------------------------------------------------------------------------------------| @@ -301,6 +347,7 @@ Legend: | **trtllm_fp8_block_scale_moe** | | | | | | trtllm | trtllm | | | **trtllm_fp8_per_tensor_scale_moe** | | | | | | trtllm | trtllm | | | **cutlass_fused_moe** | | | | | | cutlass | cutlass | | +| **moe_a2a_dispatch_combine** | | | | | | moe_a2a | moe_a2a | | | **rmsnorm** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | | **rmsnorm_quant** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | | **fused_add_rmsnorm_quant** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | @@ -324,3 +371,4 @@ Backend Legend: - trtllm-native: TensorRT-LLM (out-of-wrapper) - cuda: FlashInfer CUDA kernels - cute-dsl: FlashInfer CuTe-DSL kernels (Blackwell SM10.0+) +- moe_a2a: MoE All-to-All communication (requires mpirun, 
Blackwell SM10.0+ with MNNVL) diff --git a/benchmarks/routines/moe.py b/benchmarks/routines/moe.py index ca9214511a..2e4dd7bf06 100644 --- a/benchmarks/routines/moe.py +++ b/benchmarks/routines/moe.py @@ -7,16 +7,14 @@ import flashinfer from flashinfer.autotuner import autotune from flashinfer.fused_moe import ( - WeightLayout, trtllm_fp4_block_scale_moe, trtllm_fp8_block_scale_moe, trtllm_fp8_per_tensor_scale_moe, cutlass_fused_moe, - convert_to_block_layout, fused_topk_deepseek, ) from flashinfer.fused_moe.core import RoutingMethodType -from flashinfer import fp4_quantize, shuffle_matrix_a +from flashinfer import fp4_quantize from flashinfer.testing.utils import ( bench_gpu_time, ) @@ -27,6 +25,21 @@ print_perf_metrics, filter_backends_by_compute_capability, ) +from .moe_utils import ( + calculate_fp4_global_scale, + quantize_fp4, + quantize_fp4_batched, + quantize_fp8, + calculate_moe_tflops, + calculate_moe_kernel_bandwidth, + compute_routing, + generate_moe_weights, + add_common_moe_args, + process_fp8_weight_layout, + create_moe_output_scale_scalars, + FLOAT8_E4M3_MAX, + FLOAT4_E2M1_MAX, +) def run_moe_test(args): @@ -62,27 +75,15 @@ def parse_moe_args(line, parser): Returns: Parsed argument namespace """ - parser.add_argument( - "--num_tokens", type=int, required=True, help="Number of input tokens." - ) - parser.add_argument( - "--hidden_size", type=int, required=True, help="Hidden dimension size." - ) + add_common_moe_args(parser) + # Note: num_tokens/hidden_size is added by add_common_moe_args parser.add_argument( "--intermediate_size", type=int, required=True, help="Intermediate dimension size.", ) - parser.add_argument( - "--num_experts", type=int, required=True, help="Total number of experts." 
- ) - parser.add_argument( - "--top_k", - type=int, - required=True, - help="Number of experts to route to per token.", - ) + # Note: num_experts/top_k is added by add_common_moe_args parser.add_argument( "--n_group", type=int, @@ -160,13 +161,7 @@ def parse_moe_args(line, parser): default=False, help="Whether to use routing scales on input (for Llama4 routing).", ) - parser.add_argument( - "--input_dtype", - type=str, - required=False, - default="bfloat16", - help="Data type of the input hidden states.", - ) + # Note: input_dtype is added by add_common_moe_args parser.add_argument( "--weight_dtype", type=str, @@ -329,178 +324,13 @@ def create_trtllm_moe_test_data( ) # Create weights - always start with bfloat16 for proper quantization - gemm1_weights = torch.randn( - (num_experts, 2 * intermediate_size, hidden_size), - device=device, - dtype=torch.bfloat16, - ) - gemm2_weights = torch.randn( - (num_experts, hidden_size, intermediate_size), - device=device, - dtype=torch.bfloat16, + gemm1_weights, gemm2_weights = generate_moe_weights( + num_experts, hidden_size, intermediate_size, device, dtype=torch.bfloat16 ) return routing_logits, routing_bias, hidden_states, gemm1_weights, gemm2_weights -def calculate_fp4_global_scale_factor(tensor): - """Calculate global scale factor for FP4 quantization.""" - # Calculate as a tensor on the same device - # Using the same formula as in test files: FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax - tensor_amax = tensor.abs().max().to(torch.float32) - # FLOAT8_E4M3_MAX = 448, FLOAT4_E2M1_MAX = 6 - global_scale = (448.0 * 6.0) / tensor_amax - return global_scale - - -def quant_fp4_simple(a, a_global_sf, use_ue8m0=False, is_sf_swizzled_layout=True): - """ - Simplified FP4 quantization for benchmarking. - In production, use the actual fp4_quantize function. 
- """ - sf_vec_size = 16 - - # Use the actual fp4_quantize function from flashinfer - a_fp4, a_sf = fp4_quantize( - a, a_global_sf, sf_vec_size, use_ue8m0, is_sf_swizzled_layout - ) - - return a_fp4, a_sf, a_global_sf - - -def quant_fp4_batches_simple( - a, num_experts, use_ue8m0=False, is_sf_swizzled_layout=True -): - """Simplified FP4 batch quantization for benchmarking.""" - quant_a = [] - sfs = [] - global_sfs = [] - for i in range(num_experts): - # Calculate global scale factor (returns tensor) - a_global_sf = calculate_fp4_global_scale_factor(a[i]) - a_fp4, a_sf, _ = quant_fp4_simple( - a[i], a_global_sf, use_ue8m0, is_sf_swizzled_layout - ) - quant_a.append(a_fp4) - sfs.append(a_sf) - global_sfs.append(a_global_sf) - - result_quant_a = torch.stack(quant_a) - result_sfs = torch.stack(sfs) - result_global_sfs = torch.stack(global_sfs) - - return result_quant_a, result_sfs, result_global_sfs - - -def calculate_moe_tflops( - num_tokens: int, - hidden_size: int, - intermediate_size: int, - num_experts: int, - top_k: int, - time_ms: float, -) -> float: - """ - Calculate TFLOPS for MOE operation. - - MOE computation involves: - 1. First GEMM: [num_tokens, hidden_size] x [num_experts, hidden_size, 2*intermediate_size] - 2. Activation function (SwiGLU gate) - 3. Second GEMM: [num_tokens, intermediate_size] x [num_experts, intermediate_size, hidden_size] - - For each token, we only compute for top_k experts. 
- - """ - # FLOPS per token per expert (base calculation) - flops_per_token_per_expert = ( - 2 * hidden_size * 2 * intermediate_size # First GEMM - + 2 * intermediate_size * hidden_size # Second GEMM - ) - - total_flops = num_tokens * top_k * flops_per_token_per_expert - tflops = total_flops / (time_ms * 1e-3) / 1e12 # Convert to TFLOPS - return tflops - - -def calculate_moe_bandwidth( - num_tokens: int, - hidden_size: int, - intermediate_size: int, - num_experts: int, - top_k: int, - time_ms: float, - input_dtype: torch.dtype, - weight_dtype: torch.dtype, - input_format: Optional[str] = None, - weight_format: Optional[str] = None, - routing_logits_dtype: Optional[torch.dtype] = torch.float32, - active_experts: Optional[int] = None, - verbose: int = 0, -) -> float: - """ - Calculate memory bandwidth for MOE operation in TB/sec. - - Args: - input_format: Override for input representation; None uses dtype.itemsize - weight_format: Override for weight representation; None uses dtype.itemsize - routing_logits_dtype: Dtype for routing logits memory accounting (default float32) - """ - - # Get effective byte sizes - def get_effective_bytes(dtype: torch.dtype, fmt: Optional[str]) -> float: - if fmt == "nvfp4": - return 0.5 + 1 / 16 - elif fmt == "mxfp4": - return 0.5 + 1 / 32 - elif fmt == "fp8": - return 1.0 - return dtype.itemsize - - input_bytes_per_element = get_effective_bytes(input_dtype, input_format) - weight_bytes_per_element = get_effective_bytes(weight_dtype, weight_format) - - # Input memory: hidden states + routing logits - # Note: routing logits dtype depends on kernel; pass in when known, default float32; None means excluded - routing_logits_bytes = ( - 0 if routing_logits_dtype is None else routing_logits_dtype.itemsize - ) - input_bytes = ( - # Count hidden states once; kernels typically reuse inputs for multiple experts - num_tokens * hidden_size * input_bytes_per_element - + num_tokens * num_experts * routing_logits_bytes - ) - - # Weight memory (reuse 
weights across tokens by grouping tokens per expert) - # Assume each active expert's weights are read once per run. - weight_bytes_per_expert = ( - 2 * intermediate_size * hidden_size * weight_bytes_per_element # gemm1 - + hidden_size * intermediate_size * weight_bytes_per_element # gemm2 - ) - if active_experts is not None: - num_active_experts = active_experts - else: - num_active_experts = min(num_experts, top_k * num_tokens) - if verbose >= 2: - print(f"[VVERBOSE] num_active_experts = {num_active_experts}") - - weight_bytes = num_active_experts * weight_bytes_per_expert - - # Output memory (typically full precision) - output_bytes = num_tokens * hidden_size * input_dtype.itemsize - - total_bytes = input_bytes + weight_bytes + output_bytes - tb_per_sec = total_bytes / (time_ms * 1e-3) / 1e12 # Convert to TB/sec - return tb_per_sec - - -def _compute_routing(router_logits: torch.Tensor, top_k: int): - routing_weights = torch.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - routing_weights = routing_weights.float() - return routing_weights, selected_experts - - def _compute_routing_for_method( routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], @@ -559,19 +389,10 @@ def _compute_routing_for_method( # For other routing methods, use simple top-k as approximation # This is accurate for Default, Renormalize, RenormalizeNaive, TopK # and approximate for Llama4 - _, selected_experts = _compute_routing(routing_logits.float(), top_k) + _, selected_experts = compute_routing(routing_logits.float(), top_k) return selected_experts -def _dynamic_per_tensor_fp8_quant(x: torch.Tensor): - fp8_max = torch.finfo(torch.float8_e4m3fn).max - x_max = x.abs().max().float().clamp(min=1e-6) - scale = x_max / fp8_max - inv_scale = 1.0 / scale - out = (x.float() * inv_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn) - return 
out, scale.view((1,)) - - def testTrtllmFp4BlockScaleMoe(args): """ Test trtllm_fp4_block_scale_moe API (TensorRT-LLM fused MoE). @@ -677,18 +498,18 @@ def testTrtllmFp4BlockScaleMoe(args): use_ue8m0 = False # Calculate global scale factor for hidden states - hidden_states_scale_global = calculate_fp4_global_scale_factor(hidden_states) + hidden_states_scale_global = calculate_fp4_global_scale(hidden_states) # Quantize weights using proper FP4 quantization gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes, gemm1_scales_global = ( - quant_fp4_batches_simple(gemm1_weights, num_experts, use_ue8m0, True) + quantize_fp4_batched(gemm1_weights, num_experts, use_ue8m0, True) ) gemm2_weights_fp4_bytes, gemm2_scales_fp4_bytes, gemm2_scales_global = ( - quant_fp4_batches_simple(gemm2_weights, num_experts, use_ue8m0, True) + quantize_fp4_batched(gemm2_weights, num_experts, use_ue8m0, True) ) # Quantize hidden states - hidden_states_fp4_bytes, hidden_states_scale_fp4_bytes, _ = quant_fp4_simple( + hidden_states_fp4_bytes, hidden_states_scale_fp4_bytes, _ = quantize_fp4( hidden_states, hidden_states_scale_global, use_ue8m0, True ) @@ -739,15 +560,9 @@ def testTrtllmFp4BlockScaleMoe(args): gemm1_clamp_limit = None gemm2_bias = None - # Create scale scalars (simplified - in practice these would be computed) - output1_scale_scalar = torch.ones( - local_num_experts, device=device, dtype=torch.float32 - ) - output1_scale_gate_scalar = torch.ones( - local_num_experts, device=device, dtype=torch.float32 - ) - output2_scale_scalar = torch.ones( - local_num_experts, device=device, dtype=torch.float32 + # Create scale scalars using shared utility + output1_scale_scalar, output1_scale_gate_scalar, output2_scale_scalar = ( + create_moe_output_scale_scalars(local_num_experts, device) ) if args.verbose >= 2: @@ -857,7 +672,7 @@ def run_fp4_moe( tflops = calculate_moe_tflops( num_tokens, hidden_size, intermediate_size, num_experts, top_k, median_time ) - tb_per_sec = calculate_moe_bandwidth( + 
tb_per_sec = calculate_moe_kernel_bandwidth( num_tokens, hidden_size, intermediate_size, @@ -973,7 +788,7 @@ def testCutlassFusedMoe(args): router_logits = torch.randn( num_tokens, num_experts, dtype=input_dtype, device=device ) - routing_weights, selected_experts = _compute_routing(router_logits, top_k) + routing_weights, selected_experts = compute_routing(router_logits, top_k) if args.verbose >= 2: print(f"[VVERBOSE] x.shape = {x.shape}") @@ -1046,8 +861,8 @@ def run_cutlass(x, selected_experts, routing_weights, w31_local, w2_local, out): for expert_id in range(local_num_experts): w31_expert = w31_local[expert_id] w2_expert = w2_local[expert_id] - w31_q, s31 = _dynamic_per_tensor_fp8_quant(w31_expert) - w2_q, s2 = _dynamic_per_tensor_fp8_quant(w2_expert) + w31_q, s31 = quantize_fp8(w31_expert) + w2_q, s2 = quantize_fp8(w2_expert) w31_weight_fp8[expert_id].copy_(w31_q) w2_weight_fp8[expert_id].copy_(w2_q) # Store the same scalar twice to mimic test layout (avoid torch.tensor()) @@ -1055,7 +870,7 @@ def run_cutlass(x, selected_experts, routing_weights, w31_local, w2_local, out): w31_scales[expert_id, 1] = s31.to(dtype=input_dtype, device=device) w2_scales[expert_id, 0] = s2.to(dtype=input_dtype, device=device) - x_quant, hidden_states_scale = _dynamic_per_tensor_fp8_quant(x) + x_quant, hidden_states_scale = quantize_fp8(x) hidden_states_scale_scalar = hidden_states_scale[0].to(device) # Note: follow tests quant_scales format @@ -1103,8 +918,6 @@ def run_cutlass( elif variant == "nvfp4": # NVFP4: FP4 block-scale weights, optional quantized input - FLOAT4_E2M1_MAX = 6.0 - FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max def round_up(x_val, y): return (x_val + y - 1) // y * y @@ -1220,7 +1033,7 @@ def run_cutlass( tflops = calculate_moe_tflops( num_tokens, hidden_size, intermediate_size, num_experts, top_k, median_time ) - tb_per_sec = calculate_moe_bandwidth( + tb_per_sec = calculate_moe_kernel_bandwidth( num_tokens, hidden_size, intermediate_size, @@ -1375,24 
+1188,17 @@ def testTrtllmFp8BlockScaleMoe(args): # Optionally shuffle weights and convert to BlockMajorK layout to match kernel expectation if use_shuffled_weight: - # This tile size follows test implementations - epilogue_tile_m = 64 - gemm1_weights_fp8_shuffled = [] gemm2_weights_fp8_shuffled = [] for i in range(num_experts): - tmp_w1 = shuffle_matrix_a( - gemm1_weights_fp8[i].view(torch.uint8), epilogue_tile_m + tmp_w1 = process_fp8_weight_layout( + gemm1_weights_fp8[i], use_shuffled_weight, weight_layout ) - tmp_w2 = shuffle_matrix_a( - gemm2_weights_fp8[i].view(torch.uint8), epilogue_tile_m + tmp_w2 = process_fp8_weight_layout( + gemm2_weights_fp8[i], use_shuffled_weight, weight_layout ) - if weight_layout == WeightLayout.BlockMajorK: - block_k = 128 - tmp_w1 = convert_to_block_layout(tmp_w1, block_k) - tmp_w2 = convert_to_block_layout(tmp_w2, block_k) - gemm1_weights_fp8_shuffled.append(tmp_w1) - gemm2_weights_fp8_shuffled.append(tmp_w2) + gemm1_weights_fp8_shuffled.append(tmp_w1.view(torch.uint8)) + gemm2_weights_fp8_shuffled.append(tmp_w2.view(torch.uint8)) kernel_gemm1_weights = torch.stack(gemm1_weights_fp8_shuffled).view( torch.float8_e4m3fn @@ -1490,7 +1296,7 @@ def run_fp8_block_moe( tflops = calculate_moe_tflops( num_tokens, hidden_size, intermediate_size, num_experts, top_k, median_time ) - tb_per_sec = calculate_moe_bandwidth( + tb_per_sec = calculate_moe_kernel_bandwidth( num_tokens, hidden_size, intermediate_size, @@ -1645,15 +1451,9 @@ def testTrtllmFp8PerTensorScaleMoe(args): # Quantize hidden states to FP8 for per-tensor scale hidden_states_fp8 = hidden_states.to(torch.float8_e4m3fn) - # Create per-tensor scale scalars - output1_scales_scalar = torch.ones( - local_num_experts, device=device, dtype=torch.float32 - ) - output1_scales_gate_scalar = torch.ones( - local_num_experts, device=device, dtype=torch.float32 - ) - output2_scales_scalar = torch.ones( - local_num_experts, device=device, dtype=torch.float32 + # Create per-tensor scale scalars 
using shared utility + output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar = ( + create_moe_output_scale_scalars(local_num_experts, device) ) if args.verbose >= 2: @@ -1722,7 +1522,7 @@ def run_fp8_per_tensor_moe( tflops = calculate_moe_tflops( num_tokens, hidden_size, intermediate_size, num_experts, top_k, median_time ) - tb_per_sec = calculate_moe_bandwidth( + tb_per_sec = calculate_moe_kernel_bandwidth( num_tokens, hidden_size, intermediate_size, diff --git a/benchmarks/routines/moe_comm.py b/benchmarks/routines/moe_comm.py index cb4686e74f..038634a9e9 100644 --- a/benchmarks/routines/moe_comm.py +++ b/benchmarks/routines/moe_comm.py @@ -51,18 +51,22 @@ --per_phase_timing Options: - --quant_dtype fp8 : FP8 (float8_e4m3fn) with float32 per-tensor scale - --quant_dtype nvfp4 : NVFP4 (4-bit) with float8_e4m3fn block scales - --validate : Run correctness validation before benchmarking. - Uses a deterministic fake MoE to verify round-trip - communication. For non-quantized mode, performs exact - comparison. For quantized mode, validates output - shape and numerical validity. - --per_phase_timing : Enable per-phase timing (dispatch/combine). Adds slight - overhead from CUDA events. - This is less accurate than the total timing but shows - dispatch and combine times separately. - --nvtx : Enable NVTX markers for Nsight Systems profiling. + --quant_dtype fp8 : FP8 (float8_e4m3fn) with float32 per-tensor scale + --quant_dtype nvfp4 : NVFP4 (4-bit) with float8_e4m3fn block scales + --quant_dtype fp8_block_scale : FP8 with block scales (128 elements per block) + --real_math : Run actual MoE kernels (trtllm_fp4/fp8_block_scale_moe) + Supported quant_dtype: nvfp4 and fp8_block_scale only. + --intermediate_size N : FFN intermediate size; must be specified if real_math=True + --validate : Run correctness validation for A2A before benchmarking. + Uses a deterministic fake MoE to verify round-trip + communication. 
For non-quantized mode, performs exact + comparison. For quantized mode, validates output + shape and numerical validity. + --per_phase_timing : Enable per-phase timing (dispatch/combine). Adds slight + overhead from CUDA events. + This is less accurate than the total timing but shows + dispatch and combine times separately. + --nvtx : Enable NVTX markers for Nsight Systems profiling. """ from collections import defaultdict @@ -77,17 +81,33 @@ from flashinfer.comm import MoeAlltoAll from flashinfer.comm.mapping import Mapping from flashinfer.comm.mnnvl import MnnvlMemory -from flashinfer import fp4_quantize +from flashinfer.fused_moe import ( + trtllm_fp4_block_scale_routed_moe, + trtllm_fp8_block_scale_routed_moe, + WeightLayout, +) from flashinfer.testing.utils import bench_gpu_time from .flashinfer_benchmark_utils import ( dtype_str_to_torch_dtype, print_perf_metrics, ) - -# Constants for FP4 quantization -FLOAT8_E4M3_MAX = 448.0 -FLOAT4_E2M1_MAX = 6.0 +from .moe_utils import ( + add_common_moe_args, + calculate_fp4_global_scale, + quantize_fp4, + dequantize_nvfp4, + quantize_fp8, + quantize_fp8_block_scale, + dequantize_fp8, + dequantize_fp8_block_scale, + pack_topk_ids_triton, + calculate_moe_tflops, + calculate_moe_kernel_bandwidth, + generate_moe_weights, + create_moe_output_scale_scalars, + quantize_and_pack_nvfp4, +) @contextmanager @@ -164,7 +184,6 @@ def run_moe_comm_test(args): """ if args.routine == "moe_a2a_dispatch_combine": return test_moe_a2a_dispatch_combine(args) - # TODO: add a2a_dispatch + moe + a2a_combine else: raise ValueError(f"Unsupported routine: {args.routine}") @@ -180,45 +199,27 @@ def parse_moe_comm_args(line, parser): Returns: Parsed argument namespace """ + # Parse num_tokens/hidden_size/num_experts/top_k/input_dtype in add_common_moe_args + add_common_moe_args(parser) parser.add_argument( - "--num_tokens", - type=int, - required=True, - help="Number of tokens per rank (local batch size).", - ) - parser.add_argument( - 
"--hidden_size", - type=int, - required=True, - help="Hidden dimension size.", + "--quant_dtype", + type=str, + required=False, + default=None, + choices=["fp8", "nvfp4", "fp8_block_scale"], + help="Quantization format for hidden states. If set, hidden states are quantized and scale factors are communicated. fp8_block_scale: FP8 with block scales (128 elements per block).", ) parser.add_argument( - "--num_experts", - type=int, - required=True, - help="Total number of experts across all ranks.", + "--real_math", + action="store_true", + help="Runs actual MoE kernels (trtllm_(fp4|fp8)_block_scale_moe).", ) parser.add_argument( - "--top_k", + "--intermediate_size", type=int, - required=True, - help="Number of experts to route each token to.", - ) - parser.add_argument( - "--input_dtype", - type=str, - required=False, - default="bfloat16", - choices=["bfloat16", "float16"], - help="Data type for hidden states payload (before quantization if quant_dtype is set).", - ) - parser.add_argument( - "--quant_dtype", - type=str, required=False, default=None, - choices=["fp8", "nvfp4"], - help="Quantization format for hidden states. If set, hidden states are quantized and block-scale scale factors are communicated.", + help="Intermediate size for each expert in MoE. Must be specified if real_math=True.", ) parser.add_argument( "--max_num_tokens", @@ -235,7 +236,7 @@ def parse_moe_comm_args(line, parser): parser.add_argument( "--nvtx", action="store_true", - help="Enable NVTX markers for Nsight Systems profiling.", + help="Enable NVTX markers for Nsight Systems profiling. 
This also turns on --use_cuda_events.", ) parser.add_argument( "--per_phase_timing", @@ -249,11 +250,27 @@ def parse_moe_comm_args(line, parser): if args.max_num_tokens is None: args.max_num_tokens = args.num_tokens + if args.real_math: + # Must specify intermediate_size if real_math=True + assert args.intermediate_size is not None, ( + "intermediate_size must be specified if real_math=True" + ) + # Must specify quant_dtype as one of the following: nvfp4, fp8_block_scale + # Other quant_dtype support is TBD + assert args.quant_dtype in [ + "nvfp4", + "fp8_block_scale", + ], ( + f"real_math=True requires quant_dtype 'nvfp4' or 'fp8_block_scale', got '{args.quant_dtype}'" + ) + # Derive scale_dtype from quant_dtype if args.quant_dtype == "nvfp4": args.scale_dtype = torch.float8_e4m3fn elif args.quant_dtype == "fp8": args.scale_dtype = torch.float32 + elif args.quant_dtype == "fp8_block_scale": + args.scale_dtype = torch.float32 # Block scales are float32 else: args.scale_dtype = None @@ -282,147 +299,150 @@ def _setup_mpi_and_device() -> Tuple[MPI.Comm, int, int, int]: return comm, rank, world_size, local_rank -def _calculate_fp4_global_scale(tensor: torch.Tensor) -> torch.Tensor: - """Calculate global scale factor for FP4 quantization.""" - tensor_amax = tensor.abs().max().to(torch.float32) - if tensor_amax == 0.0: - global_scale = torch.tensor(0.0, dtype=torch.float32, device=tensor.device) - else: - global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / tensor_amax - return global_scale - - -def _quantize_to_fp8( - hidden_states: torch.Tensor, - scale: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantize hidden states to FP8 (per-tensor scale). - - Args: - hidden_states: Input tensor to quantize - scale: Optional pre-computed scale. If None, computed from hidden_states. 
- - Returns: - Tuple of (quantized_hidden_states, scale_factor) - """ - fp8_max = torch.finfo(torch.float8_e4m3fn).max - if scale is None: - amax = hidden_states.abs().max().float().clamp(min=1e-6) - scale = amax / fp8_max - inv_scale = 1.0 / scale if scale != 0.0 else 0.0 - quantized = ( - (hidden_states.float() * inv_scale) - .clamp(-fp8_max, fp8_max) - .to(torch.float8_e4m3fn) - ) - return quantized, scale.view(1) - - -def _dequantize_fp8_to_dtype( - tensor_fp8: torch.Tensor, - scale: torch.Tensor, - dtype: torch.dtype = torch.bfloat16, -) -> torch.Tensor: - """ - Dequantize FP8 tensor back to high precision. - - Args: - tensor_fp8: FP8 quantized tensor (float8_e4m3fn) - scale: Per-tensor scale factor - dtype: Output dtype - - Returns: - Dequantized tensor in specified dtype - """ - return (tensor_fp8.float() * scale.float()).to(dtype) - - -def _quantize_to_nvfp4( - hidden_states: torch.Tensor, - global_scale: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +def _init_moe_weights( + num_experts_local: int, + hidden_size: int, + intermediate_size: int, + quant_dtype: str, + device: torch.device, +) -> dict: """ - Quantize hidden states to NVFP4 (block scale). + Initialize MoE weights for MoE. + The function does not do weight shuffling required for TRTLLM-Gen MoE as it serves benchmark purposes only. + However, A2A with fake MoE kernel can be validated through flags (--validate). Args: - hidden_states: Input tensor to quantize - global_scale: Optional pre-computed global scale. If None, computed from hidden_states. 
+ num_experts_local: Number of local experts on this rank + hidden_size: Hidden dimension size + intermediate_size: Intermediate FFN size + quant_dtype: "nvfp4" or "fp8_block_scale" + device: CUDA device Returns: - Tuple of (quantized_hidden_states, block_scale_factors, global_scale_factor) - - quantized_hidden_states: uint8 tensor, packed (2 FP4 values per byte) - - block_scale_factors: float8_e4m3fn tensor, shape [num_tokens, hidden_size // 16] - - global_scale_factor: float32 scalar + Dictionary containing weights and scales for MoE computation """ - if global_scale is None: - global_scale = _calculate_fp4_global_scale(hidden_states) - sf_vec_size = 16 - use_ue8m0 = False + weights = {} - # Activation always uses linear (i.e., non-swizzled) layout - is_sf_swizzled_layout = False + # Create quantized weights + # Note: Generate and quantize one expert at a time to avoid OOM with large expert counts + # gemm1: [num_experts, 2*intermediate_size, hidden_size] + # gemm2: [num_experts, hidden_size, intermediate_size] + if quant_dtype == "nvfp4": + # Create FP4 quantized weights + + # Quantize to FP4 using swizzled layout for weights + sf_vec_size = 16 + use_ue8m0 = False + is_sf_swizzled_layout = True + + gemm1_fp4_list = [] + gemm1_sf_list = [] + gemm1_global_sf_list = [] + gemm2_fp4_list = [] + gemm2_sf_list = [] + gemm2_global_sf_list = [] + for _ in range(num_experts_local): + # Generate bf16 weights for this expert using shared utility + w1_batch, w2_batch = generate_moe_weights( + 1, hidden_size, intermediate_size, device, dtype=torch.bfloat16 + ) + expert_w1_bf16 = w1_batch.squeeze(0) + expert_w2_bf16 = w2_batch.squeeze(0) + del w1_batch, w2_batch + + # Quantize gemm1 weights using moe_utils.quantize_fp4 + quantized, sf, global_sf = quantize_fp4( + expert_w1_bf16, + global_scale=None, + use_ue8m0=use_ue8m0, + is_sf_swizzled_layout=is_sf_swizzled_layout, + ) + gemm1_fp4_list.append(quantized.view(torch.uint8)) + gemm1_sf_list.append(sf.view(torch.float8_e4m3fn)) 
+ gemm1_global_sf_list.append(global_sf) + del expert_w1_bf16 + + # Quantize gemm2 weights using moe_utils.quantize_fp4 + quantized, sf, global_sf = quantize_fp4( + expert_w2_bf16, + global_scale=None, + use_ue8m0=use_ue8m0, + is_sf_swizzled_layout=is_sf_swizzled_layout, + ) - # Returns (quantized_data, block_scales) - quantized, block_scales = fp4_quantize( - hidden_states, global_scale, sf_vec_size, use_ue8m0, is_sf_swizzled_layout - ) + # NOTE: the script chooses not to do weight shuffling as it is intended for benchmarks; + # only A2A with fake MoE kernel is validated - # Reshape quantized data: pack 2 FP4 values into 1 byte - num_tokens, hidden_size = hidden_states.shape - quantized_packed = quantized.view(torch.uint8).reshape(num_tokens, hidden_size // 2) + gemm2_fp4_list.append(quantized.view(torch.uint8)) + gemm2_sf_list.append(sf.view(torch.float8_e4m3fn)) + gemm2_global_sf_list.append(global_sf) + del expert_w2_bf16 - # Block scales are float8_e4m3fn - block_scales_reshaped = block_scales.view(torch.float8_e4m3fn).reshape( - num_tokens, hidden_size // sf_vec_size - ) + # Stack and reshape + weights["gemm1_weights"] = torch.stack(gemm1_fp4_list).reshape( + num_experts_local, 2 * intermediate_size, hidden_size // 2 + ) + weights["gemm1_weights_scale"] = torch.stack(gemm1_sf_list).reshape( + num_experts_local, 2 * intermediate_size, hidden_size // sf_vec_size + ) + weights["gemm2_weights"] = torch.stack(gemm2_fp4_list).reshape( + num_experts_local, hidden_size, intermediate_size // 2 + ) + weights["gemm2_weights_scale"] = torch.stack(gemm2_sf_list).reshape( + num_experts_local, hidden_size, intermediate_size // sf_vec_size + ) - return quantized_packed, block_scales_reshaped, global_scale - - -# Copied/adapted from tests/moe/test_trtllm_cutlass_fused_moe.py -def _dequantize_nvfp4_to_dtype( - tensor_fp4: torch.Tensor, - tensor_sf: torch.Tensor, - global_scale: torch.Tensor, - block_size: int = 16, - dtype: torch.dtype = torch.float32, -): - """Dequantize the 
fp4 tensor back to high precision.""" - - def break_fp4_bytes(a, dtype): - assert a.dtype == torch.uint8 - m, n = a.shape - # Vectorized nibble processing - a_flat = a.flatten() - high = (a_flat & 0xF0) >> 4 # Upper nibbles - low = a_flat & 0x0F # Lower nibbles - # Combine nibbles for batch processing - combined = torch.stack((low, high), dim=1).flatten() - # Vectorized sign and magnitude extraction - signs = (combined & 0x08).to(torch.bool) # Sign bits - abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices - # Device-aware lookup and sign application - kE2M1ToFloat = torch.tensor( - [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32 + # Scale scalars for output using shared utility + ( + weights["output1_scale_scalar"], + weights["output1_scale_gate_scalar"], + weights["output2_scale_scalar"], + ) = create_moe_output_scale_scalars(num_experts_local, device) + + elif quant_dtype == "fp8_block_scale": + # Create FP8 block-scaled weights + + # Build per-expert FP8 weights one expert at a time; no weight shuffling is applied (see NOTE below) + gemm1_weights = [] + gemm2_weights = [] + for _ in range(num_experts_local): + # Generate bf16 weights for this expert using shared utility + w1_batch, w2_batch = generate_moe_weights( + 1, hidden_size, intermediate_size, device, dtype=torch.bfloat16 + ) + expert_w1_bf16 = w1_batch.squeeze(0) + expert_w2_bf16 = w2_batch.squeeze(0) + del w1_batch, w2_batch + + expert_w1_fp8 = expert_w1_bf16.to(torch.float8_e4m3fn) + expert_w2_fp8 = expert_w2_bf16.to(torch.float8_e4m3fn) + del expert_w1_bf16, expert_w2_bf16 # Free memory immediately + + # NOTE: the script chooses not to do weight shuffling as it is intended for benchmarks; + # only A2A with fake MoE kernel is validated + + gemm1_weights.append(expert_w1_fp8) + gemm2_weights.append(expert_w2_fp8) + del expert_w1_fp8, expert_w2_fp8 # Free memory immediately + weights["gemm1_weights"] = torch.stack(gemm1_weights) + weights["gemm2_weights"] = torch.stack(gemm2_weights) + + # Block scales: [num_experts, out_dim //
128, in_dim // 128] + weights["gemm1_weights_scale"] = 2.0 * torch.ones( + (num_experts_local, 2 * intermediate_size // 128, hidden_size // 128), + device=device, + dtype=torch.float32, + ) + weights["gemm2_weights_scale"] = 2.0 * torch.ones( + (num_experts_local, hidden_size // 128, intermediate_size // 128), + device=device, + dtype=torch.float32, ) - kE2M1 = kE2M1ToFloat.to(device=a.device) - values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0) - # Reshape to final form - return values.reshape(m, n * 2).to(dtype=dtype) - # Two fp4 values are packed into one uint8. - assert tensor_fp4.dtype == torch.uint8 - m, packed_k = tensor_fp4.shape - k = packed_k * 2 - tensor_f32 = break_fp4_bytes(tensor_fp4, dtype) - tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) - tensor_sf = tensor_sf.view(torch.float8_e4m3fn) - tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale + else: + raise ValueError(f"Unsupported quant_dtype for real computation: {quant_dtype}") - # scale the tensor - out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) - return out.to(dtype=dtype) + return weights def _create_moe_inputs( @@ -452,7 +472,7 @@ def _create_moe_inputs( num_experts: Total number of experts top_k: Number of experts per token input_dtype: Data type for hidden states (before quantization) - quant_dtype: None, "fp8", or "nvfp4" + quant_dtype: None, "fp8", "nvfp4", or "fp8_block_scale" device: CUDA device comm: MPI communicator for syncing global scale @@ -491,15 +511,18 @@ def _create_moe_inputs( global_scale = None if quant_dtype == "nvfp4": # Compute local global scale, sync via max, then quantize - local_global_scale = _calculate_fp4_global_scale(hidden_states_original) + local_global_scale = calculate_fp4_global_scale(hidden_states_original) synced_global_scale = comm.allreduce( local_global_scale.cpu().item(), op=MPI.MAX ) global_scale = torch.tensor( synced_global_scale, dtype=torch.float32, device=device ) - hidden_states, scale_factor, 
global_scale = _quantize_to_nvfp4( - hidden_states_original, global_scale + hidden_states, scale_factor, global_scale = quantize_and_pack_nvfp4( + hidden_states_original, + global_scale, + use_ue8m0=False, + is_sf_swizzled_layout=False, ) elif quant_dtype == "fp8": # Compute local amax, sync via max, then quantize with synced scale @@ -509,9 +532,16 @@ def _create_moe_inputs( synced_scale = torch.tensor( synced_amax / fp8_max, dtype=torch.float32, device=device ) - hidden_states, global_scale = _quantize_to_fp8( - hidden_states_original, synced_scale + hidden_states, global_scale = quantize_fp8(hidden_states_original, synced_scale) + elif quant_dtype == "fp8_block_scale": + # FP8 with block scales (128 elements per block) + # Block scales shape: [hidden_size // 128, num_tokens] + hidden_states, scale_factor = quantize_fp8_block_scale( + hidden_states_original, block_size=128 ) + # Transpose scale_factor to [num_tokens, hidden_size // 128] for A2A payload + # A2A expects [num_tokens, *] shape for payloads + scale_factor = scale_factor.transpose(0, 1).contiguous() else: # No quantization hidden_states = hidden_states_original @@ -646,7 +676,7 @@ def _calculate_comm_bandwidth( ep_size: Expert parallel size (number of ranks) time_ms: Time in milliseconds input_dtype: Data type of hidden states (before quantization) - quant_dtype: None, "fp8", or "nvfp4" + quant_dtype: None, "fp8", "nvfp4", or "fp8_block_scale" phase: "dispatch", "combine", or "dispatch_combine" actual_traffic: Optional tuple of (dispatch_bytes, combine_bytes) from actual routing. If provided, uses exact traffic instead of uniform distribution estimate. 
@@ -674,6 +704,11 @@ def _calculate_comm_bandwidth( # FP8: 1 byte per element, no scale payload in A2A hidden_states_bytes = num_tokens * hidden_size * 1 # float8_e4m3fn = 1 byte scale_bytes = 0 + elif quant_dtype == "fp8_block_scale": + # FP8 with block scales: 1 byte per element + block scales (128 elements per block) + hidden_states_bytes = num_tokens * hidden_size * 1 # float8_e4m3fn = 1 byte + # Block scales: float32, one per 128 elements + scale_bytes = num_tokens * (hidden_size // 128) * 4 # float32 = 4 bytes else: # No quantization element_size = torch.tensor([], dtype=input_dtype).element_size() @@ -760,7 +795,7 @@ def fake_moe( # Deterministic scale based on expert_id scale = (expert_id + 1.0) / num_experts + 0.5 - results.append(hidden_states[token_idx] * scale) + results.append(hidden_states[token_idx].to(torch.float32) * scale) # Sum results with higher precision to match actual implementation if results: @@ -848,7 +883,7 @@ def _invoke_print_ordered(msg, condition=True): ) if quant_dtype == "nvfp4": for i in range(recv_hidden.shape[0]): - recv_hidden_dequant[i] = _dequantize_nvfp4_to_dtype( + recv_hidden_dequant[i] = dequantize_nvfp4( recv_hidden[i], recv_scale_factor[i], global_scale, @@ -857,11 +892,22 @@ def _invoke_print_ordered(msg, condition=True): ) elif quant_dtype == "fp8": for i in range(recv_hidden.shape[0]): - recv_hidden_dequant[i] = _dequantize_fp8_to_dtype( + recv_hidden_dequant[i] = dequantize_fp8( recv_hidden[i], global_scale, dtype=input_dtype, ) + elif quant_dtype == "fp8_block_scale": + for i in range(recv_hidden.shape[0]): + # Transpose scales from [max_tokens, hidden_size // block_size] to + # [hidden_size // block_size, max_tokens] + scales_transposed = recv_scale_factor[i].transpose(0, 1).contiguous() + recv_hidden_dequant[i] = dequantize_fp8_block_scale( + recv_hidden[i], + scales_transposed, + block_size=128, + dtype=input_dtype, + ) else: recv_hidden_dequant = recv_hidden @@ -942,6 +988,8 @@ def _invoke_print_ordered(msg, 
condition=True): atol, rtol = 2.0, 0.5 # FP4: very loose tolerance due to 4-bit precision elif quant_dtype == "fp8": atol, rtol = 0.1, 0.1 # FP8: moderate tolerance + elif quant_dtype == "fp8_block_scale": + atol, rtol = 0.5, 0.1 # FP8 block scale: slightly worse than per-tensor FP8 else: atol, rtol = 1e-2, 1e-2 # Non-quantized: tight tolerance @@ -1046,6 +1094,7 @@ def test_moe_a2a_dispatch_combine(args): # Synchronize all_num_tokens across ranks all_num_tokens = comm.allgather(num_tokens) runtime_max_tokens_per_rank = max(all_num_tokens) + sum_all_num_tokens = sum(all_num_tokens) # Create input data torch.manual_seed(args.random_seed + rank) @@ -1106,17 +1155,50 @@ def test_moe_a2a_dispatch_combine(args): input_dtype, quant_dtype, ) + # Compute total active experts across all ranks + total_active_experts = int( + np.unique(np.concatenate(all_token_selected_experts).flatten()).size + ) if rank == 0 and args.verbose >= 1: print( - f"[INFO] Actual inter-rank traffic: dispatch={dispatch_bytes / 1024**2:.3f} MiB, combine={combine_bytes / 1024**2:.3f} MiB" + f"[INFO] Inter-rank traffic: dispatch={dispatch_bytes / 1024**2:.3f} MiB, combine={combine_bytes / 1024**2:.3f} MiB" ) # Storage for per-phase CUDA events to be populated later during benchmark # Deferred timing: collect events during iterations, compute times after single sync dispatch_events = [] combine_events = [] + moe_events = [] # For MoE kernel timing (excluding packing) enable_nvtx = getattr(args, "nvtx", False) + if enable_nvtx: + # CUPTI reports a subscriber conflict when CUPTI-based timing and nsys profiling run at the same time, so force CUDA-event timing + args.use_cuda_events = True enable_per_phase_timing = getattr(args, "per_phase_timing", False) + enable_real_math = getattr(args, "real_math", False) + intermediate_size = getattr(args, "intermediate_size", None) + num_experts_local = num_experts // ep_size + + if enable_real_math: + assert intermediate_size is not None, ( + "intermediate_size must be specified if -real_math=True" + )
+ + # Initialize MoE weights for real computation mode + moe_weights = None + if enable_real_math: + if quant_dtype not in ["nvfp4", "fp8_block_scale"]: + if rank == 0: + print( + f"[ERROR] Real MoE math requires quant_dtype 'nvfp4' or 'fp8_block_scale', got '{quant_dtype}'" + ) + return res + moe_weights = _init_moe_weights( + num_experts_local=num_experts_local, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + quant_dtype=quant_dtype, + device=device, + ) # Define benchmark function that accepts tensors as arguments # This enables automatic buffer rotation by bench_gpu_time @@ -1126,22 +1208,140 @@ def run_dispatch_combine(sel_experts, *payloads): nvtx_range("moe_a2a_dispatch", enable_nvtx), cuda_event_timer(dispatch_events, enable_per_phase_timing), ): - _ = moe_a2a.dispatch( + recv_tensors = moe_a2a.dispatch( sel_experts, list(payloads), runtime_max_tokens_per_rank, ) - # Simulate expert processing output - # Dispatch sends quantized data (e.g., fp8/nvfp4), - # Combine receives data in activation dtype (e.g., bfloat16/float16) - with nvtx_range("moe_a2a_fake_math", enable_nvtx): + # Expert processing in benchmark runs either no-op or real MoE kernel depending on --real_math flag + with nvtx_range("moe_compute", enable_nvtx): combine_payload = moe_a2a.get_combine_payload_tensor_in_workspace( runtime_max_tokens_per_rank, hidden_size, input_dtype, ) - # TODO: add real math here if user prefers + + if enable_real_math and moe_weights is not None: + # Real computation using actual MoE kernels + # recv_tensors[0]: the received hidden states [ep_size, max_tokens, hidden_size] + # recv_tensors[1]: the received expert IDs [ep_size, max_tokens, top_k] + # recv_tensors[2]: the received token final scales [ep_size, max_tokens, top_k] + # recv_tensors[3]: the received scale factor [ep_size, max_tokens, hidden_size // block_size] + recv_hidden = recv_tensors[0] + recv_experts = recv_tensors[1] + recv_token_final_scales = recv_tensors[2] + recv_scale_factor =
recv_tensors[3] if len(recv_tensors) > 3 else None + + # Flatten for MoE kernel: [ep_size * max_tokens, hidden_size] + total_tokens = ep_size * runtime_max_tokens_per_rank + + if quant_dtype == "nvfp4": + # Reshape hidden states for FP4 kernel + hidden_flat = recv_hidden.reshape(total_tokens, -1) + # Reshape scale factors + scale_flat = ( + recv_scale_factor.reshape(total_tokens, -1) + if recv_scale_factor is not None + else None + ) + + # Pack expert IDs with actual routing weights using fused Triton kernel + local_expert_offset = rank * num_experts_local + recv_experts_flat = recv_experts.reshape(total_tokens, top_k) + recv_weights_flat = recv_token_final_scales.reshape( + total_tokens, top_k + ) + packed_topk_ids = pack_topk_ids_triton( + recv_experts_flat, + recv_weights_flat, + local_expert_offset, + ) + + # Run block scale routed MoE + with cuda_event_timer(moe_events, enable_per_phase_timing): + trtllm_fp4_block_scale_routed_moe( + topk_ids=packed_topk_ids, + routing_bias=None, + hidden_states=hidden_flat, + hidden_states_scale=scale_flat, + gemm1_weights=moe_weights["gemm1_weights"], + gemm1_weights_scale=moe_weights["gemm1_weights_scale"], + gemm1_bias=None, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_weights=moe_weights["gemm2_weights"], + gemm2_weights_scale=moe_weights["gemm2_weights_scale"], + gemm2_bias=None, + output1_scale_scalar=moe_weights["output1_scale_scalar"], + output1_scale_gate_scalar=moe_weights[ + "output1_scale_gate_scalar" + ], + output2_scale_scalar=moe_weights["output2_scale_scalar"], + num_experts=num_experts_local, + top_k=top_k, + n_group=None, + topk_group=None, + intermediate_size=intermediate_size, + local_expert_offset=0, + local_num_experts=num_experts_local, + routed_scaling_factor=None, + routing_method_type=1, # Renormalize: TopK -> Softmax + output=combine_payload.view(total_tokens, hidden_size), + ) + + elif quant_dtype == "fp8_block_scale": + # Reshape for FP8 block scale kernel + hidden_flat 
= recv_hidden.reshape(total_tokens, hidden_size) + # Transpose scale for kernel: [hidden_size // 128, total_tokens] + if recv_scale_factor is not None: + scale_flat = ( + recv_scale_factor.reshape(total_tokens, -1) + .transpose(0, 1) + .contiguous() + ) + + # Pack expert IDs with actual routing weights using fused Triton kernel + local_expert_offset = rank * num_experts_local + recv_experts_flat = recv_experts.reshape(total_tokens, top_k) + recv_weights_flat = recv_token_final_scales.reshape( + total_tokens, top_k + ) + packed_topk_ids = pack_topk_ids_triton( + recv_experts_flat, + recv_weights_flat, + local_expert_offset, + ) + + # Convert hidden states to FP8 + hidden_fp8 = hidden_flat.to(torch.float8_e4m3fn) + + # Run block scale routed MoE + with cuda_event_timer(moe_events, enable_per_phase_timing): + trtllm_fp8_block_scale_routed_moe( + topk_ids=packed_topk_ids, + routing_bias=None, + hidden_states=hidden_fp8, + hidden_states_scale=scale_flat, + gemm1_weights=moe_weights["gemm1_weights"], + gemm1_weights_scale=moe_weights["gemm1_weights_scale"], + gemm2_weights=moe_weights["gemm2_weights"], + gemm2_weights_scale=moe_weights["gemm2_weights_scale"], + num_experts=num_experts_local, + top_k=top_k, + n_group=0, + topk_group=0, + intermediate_size=intermediate_size, + local_expert_offset=0, + local_num_experts=num_experts_local, + routed_scaling_factor=None, + routing_method_type=1, # Renormalize: TopK -> Softmax + use_shuffled_weight=False, + weight_layout=int(WeightLayout.MajorK), + enable_pdl=True, + output=combine_payload.view(total_tokens, hidden_size), + ) # Combine phase: gather processed outputs from all ranks with ( @@ -1180,18 +1380,26 @@ def run_dispatch_combine(sel_experts, *payloads): # Per-phase events include dry runs; only use the last num_measure_iters entries dispatch_events_measure = dispatch_events[-num_measure_iters:] combine_events_measure = combine_events[-num_measure_iters:] + moe_events_measure = moe_events[-num_measure_iters:] if 
moe_events else [] # Convert events to times (no additional sync needed - bench_gpu_time already synced) dispatch_times = [s.elapsed_time(e) for s, e in dispatch_events_measure] combine_times = [s.elapsed_time(e) for s, e in combine_events_measure] + moe_times = ( + [s.elapsed_time(e) for s, e in moe_events_measure] + if moe_events_measure + else [] + ) else: dispatch_times = [] combine_times = [] + moe_times = [] # Gather times from all ranks all_total_times = comm.allgather(total_times) all_dispatch_times = comm.allgather(dispatch_times) all_combine_times = comm.allgather(combine_times) + all_moe_times = comm.allgather(moe_times) # Compute statistics from rank 0 if rank == 0: @@ -1219,7 +1427,9 @@ def run_dispatch_combine(sel_experts, *payloads): # Per-phase statistics if enabled --per_phase_timing flag median_time_dispatch, std_time_dispatch = np.nan, np.nan median_time_combine, std_time_combine = np.nan, np.nan + median_time_moe, std_time_moe = np.nan, np.nan tb_per_sec_dispatch, tb_per_sec_combine = np.nan, np.nan + tflops_moe, tb_per_sec_moe = np.nan, np.nan if enable_per_phase_timing: dispatch_per_iter_max = [ max(t[i] for t in all_dispatch_times) for i in range(num_measure_iters) @@ -1232,6 +1442,14 @@ def run_dispatch_combine(sel_experts, *payloads): median_time_combine = np.median(combine_per_iter_max) std_time_combine = np.std(combine_per_iter_max) + # MoE timing is only available when real_math is enabled + if all_moe_times and all_moe_times[0]: + moe_per_iter_max = [ + max(t[i] for t in all_moe_times) for i in range(num_measure_iters) + ] + median_time_moe = np.median(moe_per_iter_max) + std_time_moe = np.std(moe_per_iter_max) + tb_per_sec_dispatch = _calculate_comm_bandwidth( num_tokens, hidden_size, @@ -1263,6 +1481,39 @@ def run_dispatch_combine(sel_experts, *payloads): torch.nan, tb_per_sec_dispatch, ) + # Only print MoE timing when real_math is enabled + if args.real_math: + # This is the total FLOPS of all ranks, not per rank + tflops_moe = 
calculate_moe_tflops( + sum_all_num_tokens, + hidden_size, + intermediate_size, + num_experts, # Actually not used + top_k, + median_time_moe, + ) + # This is the total bandwidth of all ranks, not per rank + tb_per_sec_moe = calculate_moe_kernel_bandwidth( + sum_all_num_tokens, + hidden_size, + intermediate_size, + num_experts, + top_k, + median_time_moe, + input_dtype, + input_dtype, + input_format=quant_dtype, + weight_format=quant_dtype, + routing_logits_dtype=None, # No routing logits in routed MoE + active_experts=total_active_experts, + ) + print_perf_metrics( + "moe_kernel", + median_time_moe, + std_time_moe, + tflops_moe, + tb_per_sec_moe, + ) print_perf_metrics( "a2a_combine", median_time_combine, @@ -1276,7 +1527,7 @@ def run_dispatch_combine(sel_experts, *payloads): "a2a_total", median_time, std_time, torch.nan, tb_per_sec_total ) print( - "[INFO] The reported achieved tb_per_sec is the aggregate bandwidth of all participating ranks." + "[INFO] The reported achieved tflops/tb_per_sec is the aggregate FLOPS/bandwidth of all participating ranks based on timing results of rank 0. Could observe rank-to-rank variations." ) if args.output_path is not None: @@ -1286,6 +1537,10 @@ def run_dispatch_combine(sel_experts, *payloads): cur_res["std_time"] = std_time cur_res["dispatch_time"] = median_time_dispatch cur_res["dispatch_std"] = std_time_dispatch + cur_res["moe_time"] = median_time_moe + cur_res["moe_std"] = std_time_moe + cur_res["moe_tflops"] = tflops_moe + cur_res["moe_tb_per_sec"] = tb_per_sec_moe cur_res["combine_time"] = median_time_combine cur_res["combine_std"] = std_time_combine cur_res["tflops"] = "N/A" diff --git a/benchmarks/routines/moe_utils.py b/benchmarks/routines/moe_utils.py new file mode 100644 index 0000000000..64d9511145 --- /dev/null +++ b/benchmarks/routines/moe_utils.py @@ -0,0 +1,766 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +MoE Benchmark Utilities + +Shared helper functions for MoE benchmarks including: +- FP4/FP8 quantization and dequantization +- Performance metrics calculation +- Routing utilities +- Triton kernels for expert ID packing +- Common argument parsing +- Weight layout processing +""" + +import argparse +from typing import Optional, Tuple + +import torch + +import triton +import triton.language as tl + +from flashinfer import fp4_quantize, shuffle_matrix_a +from flashinfer.fused_moe import WeightLayout, convert_to_block_layout + +FLOAT8_E4M3_MAX = 448.0 +FLOAT4_E2M1_MAX = 6.0 + + +def generate_moe_weights( + num_experts: int, + hidden_size: int, + intermediate_size: int, + device: torch.device, + dtype: torch.dtype = torch.bfloat16, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Generate random weights for MoE experts. 
+ + Args: + num_experts: Number of experts to generate weights for + hidden_size: Hidden dimension size + intermediate_size: Intermediate FFN dimension + device: Device to create tensors on + dtype: Data type for the weights (default: bfloat16) + + Returns: + gemm1_weights: [num_experts, 2 * intermediate_size, hidden_size] + gemm2_weights: [num_experts, hidden_size, intermediate_size] + """ + gemm1_weights = torch.randn( + (num_experts, 2 * intermediate_size, hidden_size), + device=device, + dtype=dtype, + ) + gemm2_weights = torch.randn( + (num_experts, hidden_size, intermediate_size), + device=device, + dtype=dtype, + ) + return gemm1_weights, gemm2_weights + + +def calculate_fp4_global_scale(tensor: torch.Tensor) -> torch.Tensor: + """ + Calculate global scale factor for FP4 quantization. + + Args: + tensor: Input tensor to compute scale for + + Returns: + Global scale factor as a scalar tensor + """ + tensor_amax = tensor.abs().max().to(torch.float32) + if tensor_amax == 0.0: + global_scale = torch.tensor(0.0, dtype=torch.float32, device=tensor.device) + else: + global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / tensor_amax + return global_scale + + +def quantize_fp4( + tensor: torch.Tensor, + global_scale: Optional[torch.Tensor] = None, + use_ue8m0: bool = False, + is_sf_swizzled_layout: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Quantize tensor to FP4 format. + + Args: + tensor: Input tensor to quantize + global_scale: Optional pre-computed global scale. If None, computed from tensor. 
+ use_ue8m0: Whether to use UE8M0 format + is_sf_swizzled_layout: Whether to use swizzled layout for scale factors + + Returns: + Tuple of (quantized_data, block_scale_factors, global_scale_factor) + - quantized_data: uint8 tensor with packed FP4 values + - block_scale_factors: float8_e4m3fn tensor + - global_scale_factor: float32 scalar tensor + """ + sf_vec_size = 16 + + if global_scale is None: + global_scale = calculate_fp4_global_scale(tensor) + + quantized, block_scales = fp4_quantize( + tensor, global_scale, sf_vec_size, use_ue8m0, is_sf_swizzled_layout + ) + + return quantized, block_scales, global_scale + + +def quantize_fp4_batched( + tensor: torch.Tensor, + num_experts: int, + use_ue8m0: bool = False, + is_sf_swizzled_layout: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Quantize batched tensor to FP4 format, computing per-expert global scales. + + Args: + tensor: Input tensor of shape [num_experts, ...] + num_experts: Number of experts in the batch + use_ue8m0: Whether to use UE8M0 format + is_sf_swizzled_layout: Whether to use swizzled layout for scale factors + + Returns: + Tuple of (quantized_data, block_scale_factors, global_scale_factors) + """ + quant_list = [] + sf_list = [] + global_sf_list = [] + + for i in range(num_experts): + global_sf = calculate_fp4_global_scale(tensor[i]) + quantized, block_sf, _ = quantize_fp4( + tensor[i], global_sf, use_ue8m0, is_sf_swizzled_layout + ) + quant_list.append(quantized) + sf_list.append(block_sf) + global_sf_list.append(global_sf) + + return ( + torch.stack(quant_list), + torch.stack(sf_list), + torch.stack(global_sf_list), + ) + + +# Adapted from tests/moe/test_trtllm_cutlass_fused_moe.py +def dequantize_nvfp4( + tensor_fp4: torch.Tensor, + tensor_sf: torch.Tensor, + global_scale: torch.Tensor, + block_size: int = 16, + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + Dequantize FP4 tensor back to high precision. 
+ + Args: + tensor_fp4: FP4 quantized tensor (uint8, packed) + tensor_sf: Block scale factors + global_scale: Global scale factor + block_size: Number of elements per scale block + dtype: Output dtype + + Returns: + Dequantized tensor in specified dtype + """ + + def break_fp4_bytes(a, out_dtype): + assert a.dtype == torch.uint8 + m, n = a.shape + # Vectorized nibble processing + a_flat = a.flatten() + high = (a_flat & 0xF0) >> 4 # Upper nibbles + low = a_flat & 0x0F # Lower nibbles + # Combine nibbles for batch processing + combined = torch.stack((low, high), dim=1).flatten() + # Vectorized sign and magnitude extraction + signs = (combined & 0x08).to(torch.bool) # Sign bits + abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices + # Device-aware lookup and sign application + kE2M1ToFloat = torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32 + ) + kE2M1 = kE2M1ToFloat.to(device=a.device) + values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0) + # Reshape to final form + return values.reshape(m, n * 2).to(dtype=out_dtype) + + # Two fp4 values are packed into one uint8 + assert tensor_fp4.dtype == torch.uint8 + m, packed_k = tensor_fp4.shape + k = packed_k * 2 + tensor_f32 = break_fp4_bytes(tensor_fp4, dtype) + tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) + tensor_sf = tensor_sf.view(torch.float8_e4m3fn) + tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale + + # Scale the tensor + out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) + return out.to(dtype=dtype) + + +def quantize_fp8( + tensor: torch.Tensor, + scale: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize tensor to FP8 (per-tensor scale). + + Args: + tensor: Input tensor to quantize + scale: Optional pre-computed scale. If None, computed from tensor. 
+ + Returns: + Tuple of (quantized_tensor, scale_factor) + """ + fp8_max = torch.finfo(torch.float8_e4m3fn).max + if scale is None: + amax = tensor.abs().max().float().clamp(min=1e-6) + scale = amax / fp8_max + inv_scale = 1.0 / scale if scale != 0.0 else 0.0 + quantized = ( + (tensor.float() * inv_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn) + ) + return quantized, scale.view(1) if scale.dim() == 0 else scale + + +def quantize_fp8_block_scale( + tensor: torch.Tensor, + block_size: int = 128, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize tensor to FP8 with block scales. + + For trtllm_fp8_block_scale_moe, hidden_states_scale shape is [hidden_size // 128, num_tokens]. + + Args: + tensor: Input tensor [num_tokens, hidden_size] + block_size: Number of elements per scale block (default 128) + + Returns: + Tuple of (quantized_tensor, block_scales) + - quantized_tensor: float8_e4m3fn tensor [num_tokens, hidden_size] + - block_scales: float32 tensor [hidden_size // block_size, num_tokens] + """ + num_tokens, hidden_size = tensor.shape + fp8_max = torch.finfo(torch.float8_e4m3fn).max + + # Compute per-block amax and scales + # Reshape to [num_tokens, num_blocks, block_size] + num_blocks = hidden_size // block_size + reshaped = tensor.float().reshape(num_tokens, num_blocks, block_size) + + # Compute amax per block: [num_tokens, num_blocks] + block_amax = reshaped.abs().amax(dim=-1).clamp(min=1e-6) + + # Compute scales: [num_tokens, num_blocks] + block_scales = block_amax / fp8_max + + # Quantize each block + inv_scales = 1.0 / block_scales # [num_tokens, num_blocks] + # Expand for broadcasting: [num_tokens, num_blocks, 1] + inv_scales_expanded = inv_scales.unsqueeze(-1) + quantized_reshaped = (reshaped * inv_scales_expanded).clamp(-fp8_max, fp8_max) + quantized = quantized_reshaped.reshape(num_tokens, hidden_size).to( + torch.float8_e4m3fn + ) + + # Transpose block_scales to [num_blocks, num_tokens] as expected by kernel + block_scales_transposed = 
block_scales.transpose(0, 1).contiguous() + + return quantized, block_scales_transposed + + +def dequantize_fp8( + tensor_fp8: torch.Tensor, + scale: torch.Tensor, + dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + """ + Dequantize FP8 tensor back to high precision. + + Args: + tensor_fp8: FP8 quantized tensor (float8_e4m3fn) + scale: Per-tensor scale factor + dtype: Output dtype + + Returns: + Dequantized tensor in specified dtype + """ + return (tensor_fp8.float() * scale.float()).to(dtype) + + +def dequantize_fp8_block_scale( + tensor_fp8: torch.Tensor, + block_scales: torch.Tensor, + block_size: int = 128, + dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + """ + Dequantize block-scaled FP8 tensor back to high precision. + + Args: + tensor_fp8: FP8 quantized tensor [num_tokens, hidden_size] + block_scales: Block scales [hidden_size // block_size, num_tokens] + block_size: Number of elements per scale block + dtype: Output dtype + + Returns: + Dequantized tensor in specified dtype + """ + num_tokens, hidden_size = tensor_fp8.shape + num_blocks = hidden_size // block_size + + # Reshape tensor for block-wise dequantization + reshaped = tensor_fp8.float().reshape(num_tokens, num_blocks, block_size) + + # Transpose scales from [num_blocks, num_tokens] to [num_tokens, num_blocks] + block_scales_t = block_scales.transpose(0, 1).contiguous() + + # Apply scales + scales_expanded = block_scales_t.unsqueeze(-1) # [num_tokens, num_blocks, 1] + dequantized = reshaped * scales_expanded + + return dequantized.reshape(num_tokens, hidden_size).to(dtype) + + +@triton.jit +def _pack_topk_ids_kernel( + expert_ids_ptr, # [total_tokens, top_k] int32/int64 + expert_weights_ptr, # [total_tokens, top_k] float32 + output_ptr, # [total_tokens, top_k] int32 + local_expert_offset, # scalar int + stride_ids_row, # stride for expert_ids row dimension + stride_ids_col, # stride for expert_ids col dimension + stride_weights_row, # stride for weights row dimension + 
stride_weights_col, # stride for weights col dimension + n_rows, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + """ + Fused kernel to pack expert IDs with actual weights into packed format: + packed = ((expert_id - local_offset) << 16) | (weight_as_bf16 bits) + + This eliminates: + - dtype conversion kernel (float32 -> bf16) + - subtraction, shift, view, cast, bitwise_or kernels + All fused into a single kernel. + """ + pid = tl.program_id(0) + + # Calculate row and column from linear index + linear_idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + row_idx = linear_idx // n_cols + col_idx = linear_idx % n_cols + + mask = linear_idx < (n_rows * n_cols) + + # Compute actual memory offset using strides for expert_ids + ids_offset = row_idx * stride_ids_row + col_idx * stride_ids_col + weights_offset = row_idx * stride_weights_row + col_idx * stride_weights_col + + # Load expert IDs and compute local IDs + expert_ids = tl.load(expert_ids_ptr + ids_offset, mask=mask) + local_ids = expert_ids - local_expert_offset + + # Load weights as float32 and convert to bf16 bits + weights_f32 = tl.load(expert_weights_ptr + weights_offset, mask=mask) + # Convert to bf16, then reinterpret as int16 + weights_bf16 = weights_f32.to(tl.bfloat16) + weights_int16 = weights_bf16.to(tl.int16, bitcast=True) + weights_int32 = weights_int16.to(tl.int32) & 0xFFFF + + # Pack: (local_id << 16) | weight_bits + packed = (local_ids.to(tl.int32) << 16) | weights_int32 + + # Output is always contiguous + tl.store(output_ptr + linear_idx, packed, mask=mask) + + +def pack_topk_ids_triton( + expert_ids: torch.Tensor, + expert_weights: torch.Tensor, + local_expert_offset: int, + output: torch.Tensor = None, +) -> torch.Tensor: + """ + Pack expert IDs with actual weights into packed format using a fused Triton kernel. 
def calculate_moe_tflops(
    num_tokens: int,
    hidden_size: int,
    intermediate_size: int,
    num_experts: int,
    top_k: int,
    time_ms: float,
) -> float:
    """
    Return the achieved TFLOPS of one MoE forward pass.

    Only the two expert GEMMs are counted (2 FLOPs per multiply-accumulate):
      - GEMM1 maps hidden_size -> 2 * intermediate_size (gate and up halves)
      - GEMM2 maps intermediate_size -> hidden_size
    Each token is processed by exactly top_k experts, so the total number of
    experts does not influence the result.

    Args:
        num_tokens: Number of input tokens
        hidden_size: Hidden dimension size
        intermediate_size: Intermediate FFN dimension
        num_experts: Total number of experts (unused; kept for backward compatibility)
        top_k: Number of experts per token
        time_ms: Execution time in milliseconds

    Returns:
        TFLOPS value
    """
    del num_experts  # accepted only so existing call sites keep working

    gemm1_flops = 2 * hidden_size * 2 * intermediate_size
    gemm2_flops = 2 * intermediate_size * hidden_size
    routed_pairs = num_tokens * top_k  # (token, expert) pairs actually computed

    total_flops = routed_pairs * (gemm1_flops + gemm2_flops)
    elapsed_s = time_ms * 1e-3
    return total_flops / elapsed_s / 1e12
+ + Args: + num_tokens: Number of input tokens + hidden_size: Hidden dimension size + intermediate_size: Intermediate FFN dimension + num_experts: Total number of experts + top_k: Number of experts per token + time_ms: Execution time in milliseconds + input_dtype: Data type of input + weight_dtype: Data type of weights + input_format: Override for input representation; None uses dtype.itemsize + weight_format: Override for weight representation; None uses dtype.itemsize + routing_logits_dtype: Dtype for routing logits memory accounting (default float32) + active_experts: Number of active experts (if known) + verbose: Verbosity level + + Returns: + Bandwidth in TB/sec + """ + + # Get effective byte sizes + def get_effective_bytes( + dtype: torch.dtype, fmt: Optional[str], is_weight: bool = False + ) -> float: + if fmt == "nvfp4": + # 1 e4m3 + 1 e4m3 scale per 16-element block + return 0.5 + 1 / 16 + elif fmt == "mxfp4": + # 1 e2m1 + 1 ue8m0 scale per 32-element block + return 0.5 + 1 / 32 + elif fmt == "fp8": + # 1 e4m3 + return 1.0 + elif fmt == "fp8_block_scale": + granularity = 128 * 128 if is_weight else 128 + # 1 e4m3 + 1 float32 scale factor per block + return 1.0 + (4 / granularity) + return dtype.itemsize + + input_bytes_per_element = get_effective_bytes(input_dtype, input_format) + weight_bytes_per_element = get_effective_bytes( + weight_dtype, weight_format, is_weight=True + ) + + # Input memory: hidden states + routing logits + routing_logits_bytes = ( + 0 if routing_logits_dtype is None else routing_logits_dtype.itemsize + ) + input_bytes = ( + # Count hidden states once; kernels typically reuse inputs for multiple experts + num_tokens * hidden_size * input_bytes_per_element + + num_tokens * num_experts * routing_logits_bytes + ) + + # Weight memory + weight_bytes_per_expert = ( + 2 * intermediate_size * hidden_size * weight_bytes_per_element # gemm1 + + hidden_size * intermediate_size * weight_bytes_per_element # gemm2 + ) + if active_experts is not 
def compute_routing(
    router_logits: torch.Tensor,
    top_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Select the top-k experts per token and renormalize their weights.

    A float32 softmax is taken over the expert dimension, the ``top_k``
    largest probabilities are kept for every token, and the kept
    probabilities are rescaled so they sum to 1 per token.

    Args:
        router_logits: [num_tokens, num_experts] routing scores
        top_k: Number of experts to select per token

    Returns:
        Tuple of (routing_weights, selected_experts)
        - routing_weights: [num_tokens, top_k] normalized routing weights (float32)
        - selected_experts: [num_tokens, top_k] selected expert indices
    """
    probs = router_logits.softmax(dim=1, dtype=torch.float)
    best = probs.topk(top_k, dim=-1)

    weights = best.values
    # Renormalize in place so the selected experts' weights sum to one.
    weights.div_(weights.sum(dim=-1, keepdim=True))

    return weights.float(), best.indices
def create_moe_output_scale_scalars(
    num_experts: int,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build the three per-expert output scale tensors, all set to 1.0.

    Each returned tensor is a separate allocation, so callers may modify one
    without affecting the others.

    Args:
        num_experts: Number of experts
        device: Device to create tensors on

    Returns:
        Tuple of (output1_scale_scalar, output1_scale_gate_scalar, output2_scale_scalar)
        All tensors are float32 with shape [num_experts], initialized to 1.0
    """

    def _unit_scales() -> torch.Tensor:
        # A fresh tensor per call keeps the three outputs independent.
        return torch.ones(num_experts, device=device, dtype=torch.float32)

    return _unit_scales(), _unit_scales(), _unit_scales()
+ + Returns: + Tuple of (quantized_packed, block_scales, global_scale) + - quantized_packed: uint8 tensor [num_tokens, hidden_size // 2] + - block_scales: float8_e4m3fn tensor [num_tokens, hidden_size // 16] + - global_scale: float32 scalar tensor + """ + sf_vec_size = 16 + num_tokens, hidden_size = tensor.shape + + # Quantize using the standard FP4 quantization + quantized, block_scales, global_scale = quantize_fp4( + tensor, global_scale, use_ue8m0, is_sf_swizzled_layout + ) + + # Pack 2 FP4 values into 1 byte + quantized_packed = quantized.view(torch.uint8).reshape(num_tokens, hidden_size // 2) + + # Reshape block scales + block_scales_reshaped = block_scales.view(torch.float8_e4m3fn).reshape( + num_tokens, hidden_size // sf_vec_size + ) + + # Validate scale shape + expected_scale_elems = (num_tokens * hidden_size) // sf_vec_size + assert block_scales_reshaped.numel() == expected_scale_elems, "Invalid scale shape" + + return quantized_packed, block_scales_reshaped, global_scale diff --git a/benchmarks/samples/sample_testlist.txt b/benchmarks/samples/sample_testlist.txt index 64593c226e..03a33f33ec 100644 --- a/benchmarks/samples/sample_testlist.txt +++ b/benchmarks/samples/sample_testlist.txt @@ -49,6 +49,20 @@ --routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 2 --top_k 2 --cutlass_variant nvfp4 --quantized_input --input_dtype float16 -vv --generate_repro_command --case_tag "cutlass_moe_nvfp4_weights_quantized" --routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 8 --top_k 2 --cutlass_variant base --input_dtype float16 --tp_size 2 --tp_rank 0 --ep_size 4 --ep_rank 0 -vv --generate_repro_command --case_tag "cutlass_moe_nvfp4_ep_tp" +## MoE Communication (requires mpirun, e.g.: mpirun -np 8 python benchmarks/flashinfer_benchmark.py ...) 
+# Basic A2A dispatch+combine without quantization +#--routine moe_a2a_dispatch_combine --num_tokens 1024 --hidden_size 7168 --num_experts 256 --top_k 8 -vv --generate_repro_command --case_tag "moe_a2a_basic" +# With FP8 per-tensor quantization +#--routine moe_a2a_dispatch_combine --num_tokens 1024 --hidden_size 7168 --num_experts 256 --top_k 8 --quant_dtype fp8 -vv --generate_repro_command --case_tag "moe_a2a_fp8" +# With NVFP4 block-scale quantization +#--routine moe_a2a_dispatch_combine --num_tokens 1024 --hidden_size 7168 --num_experts 256 --top_k 8 --quant_dtype nvfp4 -vv --generate_repro_command --case_tag "moe_a2a_nvfp4" +# With FP8 block-scale quantization +#--routine moe_a2a_dispatch_combine --num_tokens 1024 --hidden_size 7168 --num_experts 256 --top_k 8 --quant_dtype fp8_block_scale -vv --generate_repro_command --case_tag "moe_a2a_fp8_block_scale" +# With real MoE kernel (NVFP4) +#--routine moe_a2a_dispatch_combine --num_tokens 1024 --hidden_size 7168 --num_experts 256 --top_k 8 --quant_dtype nvfp4 --real_math --intermediate_size 18432 --per_phase_timing -vv --generate_repro_command --case_tag "moe_a2a_nvfp4_real_math" +# With real MoE kernel (FP8 block-scale) +#--routine moe_a2a_dispatch_combine --num_tokens 1024 --hidden_size 7168 --num_experts 256 --top_k 8 --quant_dtype fp8_block_scale --real_math --intermediate_size 18432 --per_phase_timing -vv --generate_repro_command --case_tag "moe_a2a_fp8_bs_real_math" + ## RMSNorm # Basic RMSNorm with 2D input shape --routine rmsnorm --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag "rmsnorm_llama_hidden" diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 97a980c0d6..cecf4efb7a 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -714,16 +714,20 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { public: static constexpr std::array 
mSupportedTileNums = {8, 16, 32, 64, 128}; - Fp8BlockScaleLauncher(TensorView const& routing_logits, Optional const& routing_bias, - TensorView const& hidden_states, TensorView const& hidden_states_scale, - TensorView const& gemm1_weights, TensorView const& gemm1_weights_scale, - TensorView const& gemm2_weights, TensorView const& gemm2_weights_scale) - : FusedMoeLauncher(Optional(routing_logits), routing_bias, hidden_states, - gemm1_weights, Optional(), Optional(), - gemm2_weights, Optional()), + Fp8BlockScaleLauncher(Optional const& routing_logits, + Optional const& routing_bias, TensorView const& hidden_states, + TensorView const& hidden_states_scale, TensorView const& gemm1_weights, + TensorView const& gemm1_weights_scale, TensorView const& gemm2_weights, + TensorView const& gemm2_weights_scale, TensorView const& expert_indices, + TensorView const& expert_weights) + : FusedMoeLauncher(routing_logits, routing_bias, hidden_states, gemm1_weights, + Optional(), Optional(), gemm2_weights, + Optional()), hidden_states_scale(hidden_states_scale), gemm1_weights_scale(gemm1_weights_scale), - gemm2_weights_scale(gemm2_weights_scale) {} + gemm2_weights_scale(gemm2_weights_scale), + expert_indices(expert_indices), + expert_weights(expert_weights) {} void init(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, @@ -752,6 +756,18 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { } void check_routing() const override { + // Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr + if (expert_indices.ndim() == 2 && expert_indices.size(0) > 0) { + // Pre-computed routing: expert_indices is a packed tensor + // Format: (expert_id << 16) | (weight_bf16.view(int16)) + TVM_FFI_ICHECK_EQ(expert_indices.ndim(), 2) << "expert_indices must be 2D."; + TVM_FFI_ICHECK_EQ(expert_indices.size(0), hidden_states.size(0)) + << "expert_indices and hidden_states must have same number of tokens."; + 
TVM_FFI_ICHECK_EQ(expert_indices.size(1), args->top_k) + << "expert_indices dim1 must match top_k."; + TVM_FFI_ICHECK_EQ(expert_indices.dtype(), dl_int32) << "expert_indices must be int32."; + } + FusedMoeLauncher::check_routing_common(); if (args->n_group != 0) { @@ -801,15 +817,30 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { } args->mUseDeepSeekFp8 = true; - args->routing_logits = static_cast(routing_logits.value().data_ptr()); + // Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr + bool has_precomputed_indices = expert_indices.ndim() == 2 && expert_indices.size(0) > 0; + if (has_precomputed_indices) { + // Use expert_indices directly + workspace.routing_expert_indexes = + static_cast(const_cast(expert_indices.data_ptr())); + } else { + // Use routing_logits directly + args->routing_logits = static_cast(routing_logits.value().data_ptr()); + } // Set expert weights dtype based on routing bias auto const routing_bias_dtype = routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16; mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? 
btg::Dtype::Bfloat16 : btg::Dtype::Fp32; - - expert_weights = - alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device()); - workspace.expert_weights = expert_weights.data_ptr(); + // Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr + bool has_precomputed_weights = expert_weights.ndim() == 2 && expert_weights.size(0) > 0; + if (!has_precomputed_weights) { + // Allocate expert_weights buffer for routing output + FusedMoeLauncher::expert_weights = + alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device()); + workspace.expert_weights = FusedMoeLauncher::expert_weights.data_ptr(); + } else { + workspace.expert_weights = const_cast(expert_weights.data_ptr()); + } } void check_moe() const override { @@ -908,8 +939,53 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { TensorView gemm2_weights_scale; Tensor gemm1_output_scale; Tensor activation_output_scale; + TensorView expert_indices; + TensorView expert_weights; public: + // Override to handle pre-computed routing + Array run(int64_t moe_tactic, bool enable_pdl = true, + bool use_routing_scales_on_input = false, + bool use_deep_seek_fp8 = false) override { + check_routing(); + prepare_routing(); + + cudaStream_t routing_stream = get_stream(hidden_states.device()); + tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim); + + // Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr + bool use_precomputed = expert_indices.ndim() == 2 && expert_indices.size(0) > 0; + // When using pre-computed routing, pass nullptr as routing_logits to tell the + // routing runner to use the pre-computed expert indices from workspace.routing_expert_indexes + routing_runner.run( + use_precomputed ? 
nullptr : args->routing_logits, args->routing_bias, args->num_tokens, + args->num_experts, args->top_k, args->n_group, args->topk_group, args->local_expert_offset, + args->local_num_experts, args->routed_scaling_factor, workspace.routing_expert_indexes, + static_cast(expert_count_histogram.data_ptr()), + static_cast(total_num_padded_tokens.data_ptr()), + static_cast(expanded_idx_to_permuted_idx.data_ptr()), + nullptr /*permuted_idx_to_expanded_idx.data_ptr()*/, + static_cast(permuted_idx_to_token_idx.data_ptr()), workspace.expert_weights, + static_cast(num_tokens_per_expert.data_ptr()), + static_cast(cta_idx_xy_to_batch_idx.data_ptr()), + static_cast(cta_idx_xy_to_mn_limit.data_ptr()), + static_cast(num_non_exiting_ctas.data_ptr()), args->mDtypeElt, mRoutingBiasDtype, + use_routing_scales_on_input, use_deep_seek_fp8, + static_cast(routing_method_type), routing_stream); + + check_moe(); + prepare_moe(moe_tactic); + + cudaStream_t moe_stream = get_stream(hidden_states.device()); + moe_runner->run(*args, workspace, hidden_states.device().device_id, moe_stream, moe_tactic, + enable_pdl); + + if (args->do_finalize) { + return {output}; + } + return {gemm2_output, expanded_idx_to_permuted_idx}; + } + static Array> getValidConfigs(int64_t top_k, int64_t hidden_size, int64_t intermediate_size, int64_t num_local_experts, int64_t num_tokens, bool use_shuffled_weight, @@ -1565,19 +1641,34 @@ Tensor trtllm_fp8_per_tensor_scale_moe( } Tensor trtllm_fp8_block_scale_moe( - TensorView routing_logits, Optional routing_bias, TensorView hidden_states, - TensorView hidden_states_scale, TensorView gemm1_weights, TensorView gemm1_weights_scale, - TensorView gemm2_weights, TensorView gemm2_weights_scale, TensorView output, - int64_t num_experts, int64_t top_k, Optional n_group, Optional topk_group, - int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, - Optional routed_scaling_factor, int64_t routing_method_type, bool use_shuffled_weight, - int64_t 
weight_layout, bool enable_pdl, Array config_index) { + Optional routing_logits, TensorView expert_indices, TensorView expert_weights, + Optional routing_bias, TensorView hidden_states, TensorView hidden_states_scale, + TensorView gemm1_weights, TensorView gemm1_weights_scale, TensorView gemm2_weights, + TensorView gemm2_weights_scale, TensorView output, int64_t num_experts, int64_t top_k, + Optional n_group, Optional topk_group, int64_t intermediate_size, + int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, + int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool enable_pdl, + Array config_index) { // Basic type validation auto dtype = hidden_states.dtype(); - if (static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3) { - TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_float32) << "routing_logits must be float."; - } else { - TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16."; + + // Either routing_logits or expert_indices must be provided + // expert_indices is a packed tensor: (expert_id << 16) | (weight_bf16.view(int16)) + bool use_routing_logits = routing_logits.has_value(); + // Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr + bool use_precomputed_routing = expert_indices.ndim() == 2 && expert_indices.size(0) > 0; + + TVM_FFI_ICHECK(use_routing_logits || use_precomputed_routing) + << "Either routing_logits or expert_indices must be provided."; + + if (use_routing_logits) { + if (static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3) { + TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32) + << "routing_logits must be float."; + } else { + TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_bfloat16) + << "routing_logits must be bfloat16."; + } } TVM_FFI_ICHECK(dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) << "FP8 block scale MoE: hidden_states must be fp16, 
bf16, or fp8."; @@ -1621,7 +1712,7 @@ Tensor trtllm_fp8_block_scale_moe( // Create and initialize launcher for this tile size auto launcher = std::make_unique( routing_logits, routing_bias, hidden_states, hidden_states_scale, gemm1_weights, - gemm1_weights_scale, gemm2_weights, gemm2_weights_scale); + gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, expert_indices, expert_weights); launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight, weight_layout); @@ -1648,7 +1739,7 @@ Tensor trtllm_fp8_block_scale_moe( } Array trtllm_fp4_block_scale_moe( - Optional routing_logits, TensorView topk_ids, TensorView expert_weights, + Optional routing_logits, TensorView expert_indices, TensorView expert_weights, Optional routing_bias, TensorView hidden_states, Optional hidden_states_scale, TensorView gemm1_weights, TensorView gemm1_weights_scale, Optional gemm1_bias, @@ -1762,7 +1853,7 @@ Array trtllm_fp4_block_scale_moe( routing_logits, routing_bias, hidden_states, hidden_states_scale, gemm1_weights, gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit, gemm2_weights, gemm2_weights_scale, gemm2_bias, output1_scales_scalar, output1_scales_gate_scalar, - output2_scales_scalar, topk_ids, expert_weights); + output2_scales_scalar, expert_indices, expert_weights); launcher->init(std::move(args), curr_tile_N, routing_method_type, /*use_shuffled_weight=*/true, /*weight_layout=*/0, gated_act_type, mDtypeAct, mDtypeWeights); diff --git a/flashinfer/fused_moe/__init__.py b/flashinfer/fused_moe/__init__.py index a34d37f149..f7886fe400 100644 --- a/flashinfer/fused_moe/__init__.py +++ b/flashinfer/fused_moe/__init__.py @@ -29,6 +29,7 @@ trtllm_fp4_block_scale_moe, trtllm_fp4_block_scale_routed_moe, trtllm_fp8_block_scale_moe, + trtllm_fp8_block_scale_routed_moe, trtllm_fp8_per_tensor_scale_moe, trtllm_bf16_moe, trtllm_mxint4_block_scale_moe, @@ -54,6 +55,7 @@ "trtllm_fp4_block_scale_moe", "trtllm_fp4_block_scale_routed_moe", 
"trtllm_fp8_block_scale_moe", + "trtllm_fp8_block_scale_routed_moe", "trtllm_fp8_per_tensor_scale_moe", "trtllm_mxint4_block_scale_moe", "fused_topk_deepseek", diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 45d5d11bb0..0e9d643b4c 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -1131,6 +1131,8 @@ def forward( ) moe_op.trtllm_fp8_block_scale_moe( routing_logits, + topk_ids, + expert_weights, kwargs["routing_bias"], hidden_states, current_hidden_states_scale, @@ -1542,7 +1544,9 @@ def _fake_trtllm_fp8_per_tensor_scale_moe( mutates_args=(""), ) def trtllm_fp8_block_scale_moe_op( - routing_logits: torch.Tensor, + routing_logits: Optional[torch.Tensor], + topk_ids: Optional[torch.Tensor], + expert_weights: Optional[torch.Tensor], routing_bias: Optional[torch.Tensor], hidden_states: torch.Tensor, hidden_states_scale: torch.Tensor, @@ -1565,6 +1569,16 @@ def trtllm_fp8_block_scale_moe_op( enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, ) -> torch.Tensor: + # Determine routing mode: compute from logits or use pre-computed + if routing_logits is None: + assert topk_ids is not None, ( + "either topk_ids or routing_logits must be provided." + ) + assert topk_ids.dtype == torch.int32, "topk_ids must be an int32 tensor." 
+ routing_dtype = torch.bfloat16 + else: + routing_dtype = routing_logits.dtype + if enable_pdl is None: enable_pdl = device_support_pdl(hidden_states.device) @@ -1579,12 +1593,22 @@ def trtllm_fp8_block_scale_moe_op( output = torch.empty( num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device ) - topk_ids = torch.empty( - num_tokens, top_k, dtype=torch.int32, device=hidden_states.device - ) - expert_weights = torch.empty( - num_tokens, top_k, dtype=routing_logits.dtype, device=hidden_states.device - ) + if routing_logits is not None: + # When routing_logits is provided, we must pass topk_ids/expert_weights with no allocation + topk_ids = torch.empty(0, dtype=torch.int32, device=hidden_states.device) + expert_weights = torch.empty( + 0, dtype=routing_dtype, device=hidden_states.device + ) + else: + # When routing_logits is not provided, we either have topk_ids/expert_weights, + # packed into a single tensor as topk_ids + # or have them individually as topk_ids and expert_weights respectively + topk_ids = topk_ids + expert_weights = ( + expert_weights + if expert_weights is not None + else torch.empty(0, dtype=routing_dtype, device=hidden_states.device) + ) dtype_act = DtypeTrtllmGen.E4m3 # FP8 activation dtype_weights = DtypeTrtllmGen.E4m3 # FP8 weights @@ -1634,6 +1658,8 @@ def trtllm_fp8_block_scale_moe_op( # Call the C++ function for block scale MoE result = moe_op.trtllm_fp8_block_scale_moe( routing_logits, + topk_ids, + expert_weights, routing_bias, hidden_states, hidden_states_scale, @@ -1661,7 +1687,9 @@ def trtllm_fp8_block_scale_moe_op( @register_fake_op("flashinfer::trtllm_fp8_block_scale_moe") def _fake_trtllm_fp8_block_scale_moe( - routing_logits: torch.Tensor, + routing_logits: Optional[torch.Tensor], + topk_ids: Optional[torch.Tensor], + expert_weights: Optional[torch.Tensor], routing_bias: Optional[torch.Tensor], hidden_states: torch.Tensor, hidden_states_scale: torch.Tensor, @@ -1682,6 +1710,7 @@ def _fake_trtllm_fp8_block_scale_moe(
use_shuffled_weight: bool = False, weight_layout: int = 0, enable_pdl: Optional[bool] = None, + tune_max_num_tokens: int = 8192, ): seq_len = hidden_states.shape[0] hidden_size = hidden_states.shape[1] @@ -2293,6 +2322,97 @@ def trtllm_fp8_block_scale_moe( ) return get_trtllm_moe_sm100_module().trtllm_fp8_block_scale_moe( routing_logits, + None, # topk_ids - will be computed from routing_logits + None, # expert_weights - will be computed from routing_logits + routing_bias, + hidden_states, + hidden_states_scale, + gemm1_weights, + gemm1_weights_scale, + gemm2_weights, + gemm2_weights_scale, + output, + num_experts, + top_k, + n_group, + topk_group, + intermediate_size, + local_expert_offset, + local_num_experts, + routed_scaling_factor, + routing_method_type, + use_shuffled_weight, + weight_layout, + enable_pdl, + tune_max_num_tokens, + ) + + +@flashinfer_api +def trtllm_fp8_block_scale_routed_moe( + topk_ids: torch.Tensor, + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + hidden_states_scale: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm1_weights_scale: torch.Tensor, + gemm2_weights: torch.Tensor, + gemm2_weights_scale: torch.Tensor, + num_experts: int, + top_k: int, + n_group: Optional[int], + topk_group: Optional[int], + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + routed_scaling_factor: Optional[float], + routing_method_type: int = 0, + use_shuffled_weight: bool = False, + weight_layout: int = 0, + enable_pdl: Optional[bool] = None, + output: Optional[torch.Tensor] = None, + tune_max_num_tokens: int = 8192, +) -> torch.Tensor: + """FP8 block scale MoE operation with pre-computed routing (packed format). + + This function is used when routing decisions have already been computed + and packed into a single tensor. 
This is useful for: + - CUDA Graph capture (avoids CPU-GPU sync from routing_logits processing) + - Distributed MoE where routing is computed elsewhere + + Args: + topk_ids: [seq_len, top_k] tensor of packed expert indices and weights (int32). + Format: (expert_id << 16) | (weight_bf16.view(int16)) + Can be created as: (expert_ids.to(int32) << 16) | (expert_weights.to(bfloat16).view(int16).to(int32) & 0xFFFF) + routing_bias: [num_experts] tensor of routing bias (can be None) + hidden_states: [seq_len, hidden_size] tensor of input hidden states + hidden_states_scale: [hidden_size//128, seq_len] tensor of hidden states block scales + gemm1_weights: [num_experts, 2*intermediate_size, hidden_size] tensor of first layer weights + gemm1_weights_scale: [num_experts, 2*intermediate_size//128, hidden_size//128] tensor of first layer block scales + gemm2_weights: [num_experts, hidden_size, intermediate_size] tensor of second layer weights + gemm2_weights_scale: [num_experts, hidden_size//128, intermediate_size//128] tensor of second layer block scales + num_experts: Total number of experts + top_k: Number of experts to route to per token + n_group: Number of expert groups + topk_group: Number of groups to consider for top-k routing + intermediate_size: Size of intermediate layer + local_expert_offset: Offset of local experts in global expert space + local_num_experts: Number of experts handled by this device + routed_scaling_factor: Scaling factor for routing + routing_method_type: Type of routing method to use (default: 0) + use_shuffled_weight: Whether to use shuffled weights + weight_layout: Weight layout (0 = MajorK, 2 = BlockMajorK) + enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90. + output (Optional[torch.Tensor]): shape [seq_len, hidden_size] + Optional inplace output tensor. + tune_max_num_tokens(int): Maximum number of tokens for tuning.
(default: 8192)
+    Returns:
+        torch.Tensor: Output tensor of shape [seq_len, hidden_size]
+    """
+    return get_trtllm_moe_sm100_module().trtllm_fp8_block_scale_moe(
+        None,  # routing_logits
+        topk_ids,
+        None,  # expert_weights
         routing_bias,
         hidden_states,
         hidden_states_scale,
diff --git a/tests/moe/test_trtllm_gen_routed_fused_moe.py b/tests/moe/test_trtllm_gen_routed_fused_moe.py
index fb3feba4b7..7a47444081 100644
--- a/tests/moe/test_trtllm_gen_routed_fused_moe.py
+++ b/tests/moe/test_trtllm_gen_routed_fused_moe.py
@@ -27,6 +27,8 @@
 from flashinfer.fused_moe import (
     trtllm_fp4_block_scale_moe,
     trtllm_fp4_block_scale_routed_moe,
+    trtllm_fp8_block_scale_moe,
+    trtllm_fp8_block_scale_routed_moe,
 )
 from flashinfer.utils import device_support_pdl
 
@@ -245,3 +247,140 @@ def test_trtllm_gen_routed_fused_moe(
     # mismatch percentage
     mismatch_pct = (~mask).float().mean().item() * 100
     assert mismatch_pct < 6, f"Mismatch percentage is {mismatch_pct:.2f}"
+
+
+@pytest.mark.parametrize("num_tokens", [8, 64])
+@pytest.mark.parametrize("hidden_size", [1024, 2048])
+@pytest.mark.parametrize("intermediate_size", [1024, 2048])
+@pytest.mark.parametrize("num_experts", [8, 16])
+@pytest.mark.parametrize("top_k", [2, 4])
+@pytest.mark.parametrize(
+    "routing_method_type",
+    [
+        RoutingMethodType.Renormalize,
+    ],
+)
+def test_trtllm_gen_fp8_routed_fused_moe(
+    num_tokens: int,
+    hidden_size: int,
+    intermediate_size: int,
+    top_k: int,
+    num_experts: int,
+    routing_method_type: RoutingMethodType,
+):
+    """Check that trtllm_fp8_block_scale_routed_moe (packed, pre-computed top-k) agrees with trtllm_fp8_block_scale_moe (routing_logits) on identical inputs."""
+    compute_capability = get_compute_capability(torch.device(device="cuda"))
+    if compute_capability[0] not in [10]:  # only compute-capability major version 10 runs this test
+        pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.")
+    torch.manual_seed(42)  # fixed seed so the random inputs are reproducible
+    device = torch.device("cuda:0")
+    enable_pdl = device_support_pdl(device)
+
+    # Generate random routing logits for reference (uniform in [0, 1), cast to bfloat16)
+    routing_logits = torch.rand(num_tokens, num_experts, device=device).to(
+        torch.bfloat16
+    )
+
+    # Generate random hidden states in FP8 (normal samples scaled by 0.1, via bfloat16)
+    hidden_states_bf16 = (
+        torch.randn(num_tokens, hidden_size, device=device).to(torch.bfloat16) * 0.1
+    )
+    hidden_states = hidden_states_bf16.to(torch.float8_e4m3fn)
+
+    # Generate block scales for hidden states: [hidden_size // 128, num_tokens] (all ones)
+    hidden_states_scale = torch.ones(
+        hidden_size // 128, num_tokens, device=device, dtype=torch.float32
+    )
+
+    # Generate FP8 weights (normal samples cast directly to float8_e4m3fn)
+    gemm1_weights = torch.randn(
+        num_experts, 2 * intermediate_size, hidden_size, device=device
+    ).to(torch.float8_e4m3fn)
+    gemm2_weights = torch.randn(
+        num_experts, hidden_size, intermediate_size, device=device
+    ).to(torch.float8_e4m3fn)
+
+    # Generate block scales for weights (all ones, one entry per 128x128 block)
+    gemm1_weights_scale = torch.ones(
+        num_experts,
+        2 * intermediate_size // 128,
+        hidden_size // 128,
+        device=device,
+        dtype=torch.float32,
+    )
+    gemm2_weights_scale = torch.ones(
+        num_experts,
+        hidden_size // 128,
+        intermediate_size // 128,
+        device=device,
+        dtype=torch.float32,
+    )
+
+    # Run reference with routing_logits (the kernel performs the routing itself)
+    reference_output = trtllm_fp8_block_scale_moe(
+        routing_logits,
+        None,  # routing_bias
+        hidden_states,
+        hidden_states_scale,
+        gemm1_weights,
+        gemm1_weights_scale,
+        gemm2_weights,
+        gemm2_weights_scale,
+        num_experts,
+        top_k,
+        None,  # n_group
+        None,  # topk_group
+        intermediate_size,
+        0,  # local_expert_offset
+        num_experts,
+        None,  # routed_scaling_factor
+        routing_method_type.value,
+        False,  # use_shuffled_weight
+        0,  # weight_layout
+        enable_pdl,
+    ).to(torch.float)
+
+    # Compute routing using reference implementation
+    permute_info, expert_weights_ref = routing_reference_renormalize(
+        routing_logits, top_k, num_experts, 8  # NOTE(review): meaning of the trailing 8 is not visible here - confirm against routing_reference_renormalize
+    )
+    topk_ids = permute_info["topKIndices"].to(torch.int32)
+    expert_weights = expert_weights_ref.view(num_tokens, num_experts)[
+        torch.arange(num_tokens, device=device).unsqueeze(1), topk_ids  # for each token row, pick the weights at its selected expert ids
+    ].to(torch.bfloat16)
+
+    # Pack topk_ids and expert_weights into single tensor
+    # Format: (expert_id << 16) | (weight_bf16.view(int16))
+    packed_topk_ids = (topk_ids << 16) | expert_weights.view(torch.int16).to(
+        torch.int32  # NOTE(review): int16 -> int32 sign-extends; presumably the weights are non-negative so the expert-id bits stay intact - confirm
+    )
+
+    # Run with pre-computed routing (packed format)
+    output = trtllm_fp8_block_scale_routed_moe(
+        topk_ids=packed_topk_ids,
+        routing_bias=None,
+        hidden_states=hidden_states,
+        hidden_states_scale=hidden_states_scale,
+        gemm1_weights=gemm1_weights,
+        gemm1_weights_scale=gemm1_weights_scale,
+        gemm2_weights=gemm2_weights,
+        gemm2_weights_scale=gemm2_weights_scale,
+        num_experts=num_experts,
+        top_k=top_k,
+        n_group=None,
+        topk_group=None,
+        intermediate_size=intermediate_size,
+        local_expert_offset=0,
+        local_num_experts=num_experts,
+        routed_scaling_factor=None,
+        routing_method_type=routing_method_type.value,
+        use_shuffled_weight=False,
+        weight_layout=0,
+        enable_pdl=enable_pdl,
+    ).to(torch.float)
+
+    mask = torch.isclose(output, reference_output, rtol=1e-2, atol=1e-2)  # True where the two outputs agree within tolerance
+
+    # mismatch percentage (share of elements outside tolerance; the test allows just under 10%)
+    mismatch_pct = (~mask).float().mean().item() * 100
+    assert mismatch_pct < 10, f"Mismatch percentage is {mismatch_pct:.2f}%"