Skip to content

Commit f9373aa

Browse files
authored
chore: MoE benchmark effective BW fix for trtllm_block_scale_moe (#2341)
<!-- .github/pull_request_template.md --> ## 📌 Description The MoE benchmark script overestimates the num bytes loaded by assuming all experts are active. I saw effective BW exceed 3x the peak BW of some system as a result. The fix is to calculate the routed experts (topk_ids) on the host side and count the unique number of experts, the same logic `cutlass_fused_moe` does. While investigating the above issue, I also found data init of routing_bias using `rand()` results in very skewed expert distribution (repro cmd below gives 18 active out of 128 experts). I'd like to change it to `ones()*0.1` for smoother expert distribution (now giving 114 out of 128), while maintaining the same load/compute behavior in the kernels. ``` python3 flashinfer_benchmark.py --routine trtllm_fp4_block_scale_moe --num_tokens 32 --hidden_size 7168 --intermediate_size 2048 --num_experts 128 --routing_method deepseek_v3 --top_k 8 --n_group 8 --topk_group 4 --routed_scaling_factor 2.5 --use_routing_bias --use_shuffled_weight --generate_repro_command -vv ``` <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. 
--> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added support for nvfp4 and mxfp4 quantization formats in bandwidth calculations. * Introduced routing support for DeepSeekV3 method. * **Improvements** * Enhanced routing bias initialization for more consistent expert distribution. * Expanded routing computation utilities for greater flexibility. * **Tests** * Updated benchmark test data to align with new routing and quantization logic. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent cc1a362 commit f9373aa

File tree

1 file changed

+126
-20
lines changed

1 file changed

+126
-20
lines changed

benchmarks/routines/moe.py

Lines changed: 126 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
trtllm_fp8_per_tensor_scale_moe,
1414
cutlass_fused_moe,
1515
convert_to_block_layout,
16+
fused_topk_deepseek,
1617
)
18+
from flashinfer.fused_moe.core import RoutingMethodType
1719
from flashinfer import fp4_quantize, shuffle_matrix_a
1820
from flashinfer.testing.utils import (
1921
bench_gpu_time,
@@ -316,7 +318,10 @@ def create_trtllm_moe_test_data(
316318
# Create routing bias if needed - always bfloat16
317319
routing_bias = None
318320
if use_routing_bias:
319-
routing_bias = torch.randn(num_experts, device=device, dtype=torch.bfloat16)
321+
# Use uniform routing bias for less skewed expert distribution
322+
routing_bias = (
323+
torch.ones(num_experts, device=device, dtype=torch.bfloat16) * 0.1
324+
)
320325

321326
# Create hidden states - always start with bfloat16 for proper quantization
322327
hidden_states = 2 * torch.randn(
@@ -430,21 +435,24 @@ def calculate_moe_bandwidth(
430435
weight_format: Optional[str] = None,
431436
routing_logits_dtype: Optional[torch.dtype] = torch.float32,
432437
active_experts: Optional[int] = None,
438+
verbose: int = 0,
433439
) -> float:
434440
"""
435441
Calculate memory bandwidth for MOE operation in TB/sec.
436442
437443
Args:
438-
input_format: Override for input representation ("fp8" or "fp4"); None uses dtype.itemsize
439-
weight_format: Override for weight representation ("fp8" or "fp4"); None uses dtype.itemsize
444+
input_format: Override for input representation; None uses dtype.itemsize
445+
weight_format: Override for weight representation; None uses dtype.itemsize
440446
routing_logits_dtype: Dtype for routing logits memory accounting (default float32)
441447
"""
442448

443449
# Get effective byte sizes
444450
def get_effective_bytes(dtype: torch.dtype, fmt: Optional[str]) -> float:
445-
if fmt == "fp4":
446-
return 0.5
447-
if fmt == "fp8":
451+
if fmt == "nvfp4":
452+
return 0.5 + 1 / 16
453+
elif fmt == "mxfp4":
454+
return 0.5 + 1 / 32
455+
elif fmt == "fp8":
448456
return 1.0
449457
return dtype.itemsize
450458

@@ -472,6 +480,9 @@ def get_effective_bytes(dtype: torch.dtype, fmt: Optional[str]) -> float:
472480
num_active_experts = active_experts
473481
else:
474482
num_active_experts = min(num_experts, top_k * num_tokens)
483+
if verbose >= 2:
484+
print(f"[VVERBOSE] num_active_experts = {num_active_experts}")
485+
475486
weight_bytes = num_active_experts * weight_bytes_per_expert
476487

477488
# Output memory (typically full precision)
@@ -490,6 +501,68 @@ def _compute_routing(router_logits: torch.Tensor, top_k: int):
490501
return routing_weights, selected_experts
491502

492503

504+
def _compute_routing_for_method(
505+
routing_logits: torch.Tensor,
506+
routing_bias: Optional[torch.Tensor],
507+
top_k: int,
508+
routing_method_type: int,
509+
n_group: Optional[int] = None,
510+
topk_group: Optional[int] = None,
511+
routed_scaling_factor: Optional[float] = None,
512+
) -> torch.Tensor:
513+
"""
514+
Compute selected experts based on routing method type.
515+
Returns only the selected expert indices tensor.
516+
517+
Args:
518+
routing_logits: [num_tokens, num_experts] routing scores
519+
routing_bias: Optional [num_experts] routing bias
520+
top_k: Number of experts to select per token
521+
routing_method_type: Type of routing method (see RoutingMethodType enum)
522+
n_group: Number of expert groups (for DeepSeekV3)
523+
topk_group: Number of top groups (for DeepSeekV3)
524+
routed_scaling_factor: Scaling factor (for DeepSeekV3)
525+
526+
Returns:
527+
selected_experts: [num_tokens, top_k] tensor of selected expert indices
528+
"""
529+
num_tokens = routing_logits.shape[0]
530+
device = routing_logits.device
531+
532+
if routing_method_type == RoutingMethodType.DeepSeekV3:
533+
# Use fused_topk_deepseek for accurate DeepSeekV3 routing
534+
if n_group is None or topk_group is None or routed_scaling_factor is None:
535+
raise ValueError(
536+
"DeepSeekV3 routing requires n_group, topk_group, and routed_scaling_factor"
537+
)
538+
if routing_bias is None:
539+
routing_bias = torch.zeros(
540+
routing_logits.shape[1], device=device, dtype=routing_logits.dtype
541+
)
542+
543+
# Allocate output tensors
544+
topk_values = torch.empty(num_tokens, top_k, device=device, dtype=torch.float32)
545+
topk_indices = torch.empty(num_tokens, top_k, device=device, dtype=torch.int32)
546+
547+
fused_topk_deepseek(
548+
scores=routing_logits.float(),
549+
bias=routing_bias.float(),
550+
n_group=n_group,
551+
topk_group=topk_group,
552+
topk=top_k,
553+
routed_scaling_factor=routed_scaling_factor,
554+
topk_values=topk_values,
555+
topk_indices=topk_indices,
556+
)
557+
return topk_indices
558+
else:
559+
# For other routing methods, use simple top-k as approximation
560+
# This is accurate for Default, Renormalize, RenormalizeNaive, TopK
561+
# and approximate for Llama4
562+
_, selected_experts = _compute_routing(routing_logits.float(), top_k)
563+
return selected_experts
564+
565+
493566
def _dynamic_per_tensor_fp8_quant(x: torch.Tensor):
494567
fp8_max = torch.finfo(torch.float8_e4m3fn).max
495568
x_max = x.abs().max().float().clamp(min=1e-6)
@@ -588,6 +661,18 @@ def testTrtllmFp4BlockScaleMoe(args):
588661
)
589662
)
590663

664+
# Compute selected experts for accurate bandwidth calculation
665+
# Use the actual routing method to get correct expert assignments
666+
selected_experts = _compute_routing_for_method(
667+
routing_logits=routing_logits,
668+
routing_bias=routing_bias,
669+
top_k=top_k,
670+
routing_method_type=routing_method_type,
671+
n_group=n_group,
672+
topk_group=topk_group,
673+
routed_scaling_factor=routed_scaling_factor,
674+
)
675+
591676
# For FP4, we need to properly quantize weights and create scales
592677
use_ue8m0 = False
593678

@@ -781,9 +866,11 @@ def run_fp4_moe(
781866
median_time,
782867
input_dtype,
783868
weight_dtype,
784-
input_format="fp4",
785-
weight_format="fp4",
869+
input_format="nvfp4",
870+
weight_format="nvfp4",
786871
routing_logits_dtype=routing_logits.dtype,
872+
active_experts=int(selected_experts.unique().numel()),
873+
verbose=args.verbose,
787874
)
788875

789876
print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec)
@@ -1142,20 +1229,11 @@ def run_cutlass(
11421229
median_time,
11431230
input_dtype,
11441231
input_dtype,
1145-
input_format=(
1146-
"fp8"
1147-
if variant == "fp8"
1148-
else (
1149-
"fp4"
1150-
if (variant == "nvfp4" and getattr(args, "quantized_input", False))
1151-
else None
1152-
)
1153-
),
1154-
weight_format=(
1155-
"fp8" if variant == "fp8" else ("fp4" if variant == "nvfp4" else None)
1156-
),
1232+
input_format=variant,
1233+
weight_format=variant,
11571234
routing_logits_dtype=router_logits.dtype,
11581235
active_experts=int(selected_experts.unique().numel()),
1236+
verbose=args.verbose,
11591237
)
11601238

11611239
print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec)
@@ -1278,6 +1356,18 @@ def testTrtllmFp8BlockScaleMoe(args):
12781356
)
12791357
)
12801358

1359+
# Compute selected experts for accurate bandwidth calculation
1360+
# Use the actual routing method to get correct expert assignments
1361+
selected_experts = _compute_routing_for_method(
1362+
routing_logits=routing_logits,
1363+
routing_bias=routing_bias,
1364+
top_k=top_k,
1365+
routing_method_type=routing_method_type,
1366+
n_group=n_group,
1367+
topk_group=topk_group,
1368+
routed_scaling_factor=routed_scaling_factor,
1369+
)
1370+
12811371
# For FP8 block scale, create quantized weights and block scales
12821372
# Quantize to FP8
12831373
gemm1_weights_fp8 = gemm1_weights.to(torch.float8_e4m3fn)
@@ -1412,6 +1502,8 @@ def run_fp8_block_moe(
14121502
input_format="fp8",
14131503
weight_format="fp8",
14141504
routing_logits_dtype=routing_logits.dtype,
1505+
active_experts=int(selected_experts.unique().numel()),
1506+
verbose=args.verbose,
14151507
)
14161508

14171509
backend = "trtllm"
@@ -1533,6 +1625,18 @@ def testTrtllmFp8PerTensorScaleMoe(args):
15331625
)
15341626
)
15351627

1628+
# Compute selected experts for accurate bandwidth calculation
1629+
# Use the actual routing method to get correct expert assignments
1630+
selected_experts = _compute_routing_for_method(
1631+
routing_logits=routing_logits,
1632+
routing_bias=routing_bias,
1633+
top_k=top_k,
1634+
routing_method_type=routing_method_type,
1635+
n_group=n_group,
1636+
topk_group=topk_group,
1637+
routed_scaling_factor=routed_scaling_factor,
1638+
)
1639+
15361640
# For FP8 per-tensor scale, create quantized weights and per-tensor scales
15371641
# Quantize to FP8
15381642
gemm1_weights_fp8 = gemm1_weights.to(torch.float8_e4m3fn)
@@ -1630,6 +1734,8 @@ def run_fp8_per_tensor_moe(
16301734
input_format="fp8",
16311735
weight_format="fp8",
16321736
routing_logits_dtype=routing_logits.dtype,
1737+
active_experts=int(selected_experts.unique().numel()),
1738+
verbose=args.verbose,
16331739
)
16341740

16351741
backend = "trtllm"

0 commit comments

Comments
 (0)