diff --git a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py index 8ff7036dec..f9314c98fd 100644 --- a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py +++ b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py @@ -2,6 +2,7 @@ from typing import Optional, Literal import torch import numpy as np +from functools import partial from flashinfer import ( RoutingMethodType, ActivationType, @@ -10,6 +11,7 @@ ) from flashinfer.fused_moe import ( trtllm_fp4_block_scale_moe, + trtllm_mxint4_block_scale_moe, trtllm_fp8_per_tensor_scale_moe, trtllm_fp8_block_scale_moe, WeightLayout, @@ -23,13 +25,32 @@ FLOAT4_E2M1_MAX = 6.0 -def fp8_quantize(x): +def fp8_quantize(x) -> tuple[torch.Tensor, torch.Tensor]: max = x.abs().max().float() scale = FLOAT8_E4M3_MAX / max x = (x * scale).to(torch.float8_e4m3fn) return x, 1.0 / scale +def mxint4_quantize( + x: torch.Tensor, sf_vec_size: int = 32 +) -> tuple[torch.Tensor, torch.Tensor]: + x_reshaped = x.reshape(-1, sf_vec_size) + x_max = x_reshaped.max(dim=-1, keepdim=True)[0].to(torch.float32) + x_min = x_reshaped.min(dim=-1, keepdim=True)[0].to(torch.float32) + x_max = x_max * 8.0 / 7.0 + amax = torch.where(x_max > -x_min, x_max, -x_min) + scales = amax / 8.0 + x_scaled = x_reshaped * scales.reciprocal() + x_int8 = ( + x_scaled.round().clamp(-8, 7).to(torch.int8).reshape(-1, sf_vec_size // 2, 2) + ) + x_int4 = (x_int8[..., 0] & 0x0F) | ((x_int8[..., 1] & 0x0F) << 4) + return x_int4.reshape(*x.shape[:-1], x.shape[-1] // 2).view( + torch.uint8 + ), scales.reshape(-1, sf_vec_size) + + def bench_trtllm_gen_fused_moe_autotuner_fp8( tune_max_num_tokens: Optional[int], quant_mode: Literal["Fp8-Per-Tensor", "Fp8-Block"], @@ -40,7 +61,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8( top_k: int, warmups: int, iterations: int, - activation_type: ActivationType, + activation_type: int, ): device = torch.device("cuda:0") enable_pdl = device_support_pdl(device) @@ -50,7 +71,7 @@ def 
bench_trtllm_gen_fused_moe_autotuner_fp8( hidden_states = torch.randn(num_tokens, hidden_size, device=device).to( torch.bfloat16 ) - routing_bias = torch.randn(num_experts, device="cuda", dtype=torch.bfloat16) + routing_bias = torch.randn(num_experts, device=device, dtype=torch.bfloat16) w13 = torch.randn( num_experts, intermediate_size * 2, hidden_size, device=device ).to(torch.bfloat16) @@ -99,67 +120,74 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8( ) if is_block_scale: - if activation_type != ActivationType.Swiglu: - raise ValueError( - "Only Swiglu activation is supported for FP8 block scale MoE." - ) - fn = lambda: trtllm_fp8_block_scale_moe( - routing_logits, - routing_bias, - hidden_states, - hidden_states_scale, - w13, - w13_scale, - w2, - w2_scale, - num_experts, - top_k, - 8, # n_group - 4, # topk_group - intermediate_size, - 0, # local_expert_offset - num_experts, - 2.5, # routed_scaling_factor - RoutingMethodType.DeepSeekV3.value, - True, # use_shuffled_weight - WeightLayout.BlockMajorK.value, # weight_layout + assert activation_type == ActivationType.Swiglu.value, ( + "Only Swiglu activation is supported for FP8 block scale MoE." 
+ ) + fn = partial( + trtllm_fp8_block_scale_moe, + routing_logits=routing_logits, + routing_bias=routing_bias, + num_experts=num_experts, + top_k=top_k, + n_group=8, + topk_group=4, + intermediate_size=intermediate_size, + local_expert_offset=0, + local_num_experts=num_experts, + routed_scaling_factor=2.5, + routing_method_type=RoutingMethodType.DeepSeekV3.value, + use_shuffled_weight=False, + weight_layout=WeightLayout.MajorK.value, # weight_layout enable_pdl=enable_pdl, tune_max_num_tokens=num_tokens if tune_max_num_tokens is None else tune_max_num_tokens, ) else: - fn = lambda: trtllm_fp8_per_tensor_scale_moe( - routing_logits, - None, # routing_bias - hidden_states, - w13, - output1_scale_scalar, - output1_scales_gate_scalar, - w2, - output2_scale_scalar, - num_experts, - top_k, - None, # n_group - None, # topk_group - intermediate_size, - 0, # local_expert_offset - num_experts, - 1.0, # routed_scaling_factor - False, # use_routing_scales_on_input - RoutingMethodType.TopK.value, - enable_pdl, - num_tokens if tune_max_num_tokens is None else tune_max_num_tokens, - activation_type.value, + fn = partial( + trtllm_fp8_per_tensor_scale_moe, + routing_logits=routing_logits.to(torch.bfloat16), + routing_bias=None, + output1_scales_scalar=output1_scale_scalar, + output1_scales_gate_scalar=output1_scales_gate_scalar, + output2_scales_scalar=output2_scale_scalar, + num_experts=num_experts, + top_k=top_k, + n_group=None, + topk_group=None, + intermediate_size=intermediate_size, + local_expert_offset=0, + local_num_experts=num_experts, + routed_scaling_factor=1.0, + use_routing_scales_on_input=False, + routing_method_type=RoutingMethodType.TopK.value, + enable_pdl=enable_pdl, + tune_max_num_tokens=num_tokens + if tune_max_num_tokens is None + else tune_max_num_tokens, + activation_type=activation_type, ) + input_kwargs = { + "hidden_states": hidden_states, + "gemm1_weights": w13, + "gemm2_weights": w2, + } + if is_block_scale: + input_kwargs["hidden_states_scale"] = 
hidden_states_scale + input_kwargs["gemm1_weights_scale"] = w13_scale + input_kwargs["gemm2_weights_scale"] = w2_scale def bench(do_autotune): with autotune(do_autotune): - fn() + fn(**input_kwargs) ms_list = bench_gpu_time( fn, dry_run_iters=warmups, repeat_iters=iterations, + enable_cupti=True, + use_cuda_graph=True, + input_kwargs=input_kwargs, + cold_l2_cache=True, ) median_ms = np.median(ms_list) return median_ms @@ -182,7 +210,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4( top_k: int, warmups: int, iterations: int, - activation_type: ActivationType, + activation_type: int, ): device = torch.device("cuda:0") enable_pdl = device_support_pdl(device) @@ -242,10 +270,9 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4( w13_global_scale = 1.0 / 448.0 / 6.0 w2_global_scale = 1.0 / 448.0 / 6.0 else: - if activation_type == ActivationType.Relu2: - raise ValueError( - "Relu2 activation is supported for FP4 only with 'NvFP4xNvFP4' quant mode" - ) + assert activation_type != ActivationType.Relu2.value, ( + "Relu2 activation is supported for FP4 only with 'NvFP4xNvFP4' quant mode" + ) w13, w13_scale = fp4_quantize( w13, torch.tensor([1.0], device=device), sf_vec_size=32, sf_use_ue8m0=True ) @@ -272,46 +299,152 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4( output2_scale_scalar = torch.tensor( [hidden_states_global_scale * w2_global_scale] * num_experts, device=device ) - fn = lambda: trtllm_fp4_block_scale_moe( - routing_logits, - None, # routing_bias - hidden_states, - hidden_states_scale, - w13, - w13_scale, - bias13, - None, # gemm1_alpha - None, # gemm1_beta - None, # gemm1_clamp_limit - w2, - w2_scale, - bias2, - output1_scale_scalar, - output1_scale_gate_scalar, - output2_scale_scalar, + fn = partial( + trtllm_fp4_block_scale_moe, + routing_logits=routing_logits, + routing_bias=None, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + output1_scale_scalar=output1_scale_scalar, + output1_scale_gate_scalar=output1_scale_gate_scalar, + 
output2_scale_scalar=output2_scale_scalar, + num_experts=num_experts, + top_k=top_k, + n_group=None, + topk_group=None, + intermediate_size=intermediate_size, + local_expert_offset=0, + local_num_experts=num_experts, + routed_scaling_factor=None, + routing_method_type=RoutingMethodType.Renormalize.value, + do_finalize=True, + enable_pdl=enable_pdl, + activation_type=activation_type, + output=None, + tune_max_num_tokens=num_tokens + if tune_max_num_tokens is None + else tune_max_num_tokens, + ) + + input_kwargs = { + "hidden_states": hidden_states, + "hidden_states_scale": hidden_states_scale, + "gemm1_weights": w13, + "gemm1_weights_scale": w13_scale, + "gemm2_weights": w2, + "gemm2_weights_scale": w2_scale, + "gemm1_bias": bias13, + "gemm2_bias": bias2, + } + + def bench(do_autotune): + with autotune(do_autotune): + fn(**input_kwargs) + ms_list = bench_gpu_time( + fn, + dry_run_iters=warmups, + repeat_iters=iterations, + enable_cupti=True, + use_cuda_graph=True, + input_kwargs=input_kwargs, + cold_l2_cache=True, + ) + median_ms = np.median(ms_list) + return median_ms + + ms = bench(do_autotune=False) + ms_tuned = bench(do_autotune=True) + print( + f"num tokens: {num_tokens}, num experts: {num_experts}, hidden size: {hidden_size}, intermediate size: {intermediate_size}, top k: {top_k}" + ) + print(f"No autotune: {ms:.3f} ms; with autotune: {ms_tuned:.3f} ms") + + +def bench_trtllm_gen_fused_moe_autotuner_mxint4( + tune_max_num_tokens: Optional[int], + quant_mode: Literal["MxInt4xBf16"], + num_tokens: int, + num_experts: int, + hidden_size: int, + intermediate_size: int, + top_k: int, + warmups: int, + iterations: int, + activation_type: int, +): + device = torch.device("cuda:0") + enable_pdl = device_support_pdl(device) + routing_logits = torch.rand(num_tokens, num_experts, device=device).float() + routing_bias = torch.randn(num_experts, device=device, dtype=torch.bfloat16) + hidden_states = torch.randn(num_tokens, hidden_size, device=device).to( + torch.bfloat16 + 
) + + w13 = torch.randn( + num_experts, intermediate_size * 2, hidden_size, device=device + ).to(torch.bfloat16) + w2 = torch.randn(num_experts, hidden_size, intermediate_size, device=device).to( + torch.bfloat16 + ) + w13, w13_scale = mxint4_quantize(w13, 32) + w13_scale = w13_scale.to(torch.bfloat16).reshape( num_experts, - top_k, - None, # n_group - None, # topk_group - intermediate_size, - 0, # local_expert_offset + 2 * intermediate_size, + hidden_size // 32, + ) + w2, w2_scale = mxint4_quantize(w2, 32) + w2_scale = w2_scale.to(torch.bfloat16).reshape( num_experts, - None, # routed_scaling_factor - RoutingMethodType.Renormalize.value, - True, - enable_pdl, - activation_type.value, # act_type - None, - num_tokens if tune_max_num_tokens is None else tune_max_num_tokens, + hidden_size, + intermediate_size // 32, ) + assert activation_type == ActivationType.Swiglu.value, ( + "only SwiGlu activation is supported for MxInt4 MoE currently" + ) + fn = partial( + trtllm_mxint4_block_scale_moe, + routing_logits=routing_logits, + routing_bias=routing_bias, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + num_experts=num_experts, + top_k=top_k, + n_group=1, + topk_group=1, + intermediate_size=intermediate_size, + local_expert_offset=0, + local_num_experts=num_experts, + routed_scaling_factor=None, + routing_method_type=RoutingMethodType.DeepSeekV3.value, + enable_pdl=enable_pdl, + output=None, + tune_max_num_tokens=num_tokens + if tune_max_num_tokens is None + else tune_max_num_tokens, + ) + + input_kwargs = { + "hidden_states": hidden_states, + "gemm1_weights": w13, + "gemm1_weights_scale": w13_scale, + "gemm2_weights": w2, + "gemm2_weights_scale": w2_scale, + } + def bench(do_autotune): with autotune(do_autotune): - fn() + fn(**input_kwargs) ms_list = bench_gpu_time( fn, dry_run_iters=warmups, repeat_iters=iterations, + enable_cupti=True, + use_cuda_graph=True, + input_kwargs=input_kwargs, + cold_l2_cache=True, ) median_ms = np.median(ms_list) return 
median_ms @@ -334,6 +467,7 @@ def bench(do_autotune): "NvFP4xNvFP4", "MxFP4xMxFP8", "MxFP4xBf16", + "MxInt4xBf16", "Fp8-Per-Tensor", "Fp8-Block", ], @@ -369,29 +503,22 @@ def bench(do_autotune): help=f"Type of activation function: {[e.name for e in ActivationType]}", ) args = parser.parse_args() - if args.quant_mode in ["Fp8-Per-Tensor", "Fp8-Block"]: - bench_trtllm_gen_fused_moe_autotuner_fp8( - args.tune_max_num_tokens, - args.quant_mode, - args.num_tokens, - args.num_experts, - args.hidden_size, - args.intermediate_size, - args.top_k, - args.warmups, - args.iterations, - args.activation_type, - ) - else: - bench_trtllm_gen_fused_moe_autotuner_fp4( - args.tune_max_num_tokens, - args.quant_mode, - args.num_tokens, - args.num_experts, - args.hidden_size, - args.intermediate_size, - args.top_k, - args.warmups, - args.iterations, - args.activation_type, - ) + fn = ( + bench_trtllm_gen_fused_moe_autotuner_fp8 + if args.quant_mode in ["Fp8-Per-Tensor", "Fp8-Block"] + else bench_trtllm_gen_fused_moe_autotuner_mxint4 + if args.quant_mode == "MxInt4xBf16" + else bench_trtllm_gen_fused_moe_autotuner_fp4 + ) + fn( + args.tune_max_num_tokens, + args.quant_mode, + args.num_tokens, + args.num_experts, + args.hidden_size, + args.intermediate_size, + args.top_k, + args.warmups, + args.iterations, + args.activation_type, + ) diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 2821ce829a..2f403078bd 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -2144,10 +2144,8 @@ def trtllm_bf16_moe( Must be bfloat16 if provided. hidden_states: [seq_len, hidden_size] tensor of input hidden states. Must be bfloat16. - gemm1_weights: [num_experts, 2*intermediate_size, hidden_size] tensor of first layer weights. - Must be bfloat16. - gemm2_weights: [num_experts, hidden_size, intermediate_size] tensor of second layer weights. - Must be bfloat16. 
+ gemm1_weights: [num_experts, 2*intermediate_size // 128, hidden_size, 128] tensor of first layer weights. Must be bfloat16. + gemm2_weights: [num_experts, hidden_size//128, intermediate_size, 128] tensor of second layer weights. Must be bfloat16. num_experts: Total number of experts. top_k: Number of experts to route to per token. n_group: Number of expert groups. @@ -2163,10 +2161,7 @@ def trtllm_bf16_moe( - 3: Llama4 (Top1 -> Sigmoid) - 4: RenormalizeNaive (Softmax -> TopK -> Renormalize) use_shuffled_weight: Whether to use shuffled weight layout for optimization (default: True). - weight_layout: Weight layout format (default: WeightLayout.BlockMajorK). - - 0: MajorK - K-major layout [Mn, K] - - 1: MajorMn - M-major for A and N-major for B [K, Mn] - - 2: BlockMajorK - Blocked along K dimension [K/blockK, Mn, blockK] + weight_layout: Weight layout format. Must be WeightLayout.BlockMajorK ([K/blockK, Mn, blockK]) enable_pdl: Whether to enable Programmatic Dependent Launch. Auto-enabled for >= sm90. tune_max_num_tokens: Maximum number of tokens for autotuning (default: 8192). 
@@ -2311,9 +2306,13 @@ def trtllm_fp8_block_scale_moe( routing_bias: [num_experts] tensor of routing bias hidden_states: [seq_len, hidden_size] tensor of input hidden states hidden_states_scale: [hidden_size//128, seq_len] tensor of hidden states block scales - gemm1_weights: [num_experts, 2*intermediate_size, hidden_size] tensor of first layer weights + gemm1_weights: tensor of first layer weights + - [num_experts, 2*intermediate_size, hidden_size] if weight_layout == WeightLayout.MajorK + - [num_experts, 2*intermediate_size // 128, hidden_size, 128] if weight_layout == WeightLayout.BlockMajorK gemm1_weights_scale: [num_experts, 2*intermediate_size//128, hidden_size//128] tensor of first layer block scales - gemm2_weights: [num_experts, hidden_size, intermediate_size] tensor of second layer weights + gemm2_weights: tensor of second layer weights + - [num_experts, hidden_size, intermediate_size] if weight_layout == WeightLayout.MajorK + - [num_experts, hidden_size//128, intermediate_size, 128] if weight_layout == WeightLayout.BlockMajorK gemm2_weights_scale: [num_experts, hidden_size//128, intermediate_size//128] tensor of second layer block scales num_experts: Total number of experts top_k: Number of experts to route to per token @@ -2324,6 +2323,9 @@ def trtllm_fp8_block_scale_moe( local_num_experts: Number of experts handled by this device routed_scaling_factor: Scaling factor for routing routing_method_type: Type of routing method to use (default: 0) + weight_layout: Weight layout format (default: WeightLayout.MajorK). Supported layouts: + - 0: MajorK - K-major layout [Mn, K] + - 2: BlockMajorK - Blocked along K dimension [K/blockK, Mn, blockK] enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90. tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192) Returns: