Skip to content

Commit 3000467

Browse files
authored
tests: add bias testing to nvfp4 moe (#2585)
<!-- .github/pull_request_template.md --> ## 📌 Description ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Tests** * Added a test exercising GEMM bias combinations for low-precision (FP4) mixture-of-experts (bias on GEMM1, GEMM2, or both). * Extended the test harness to validate GEMM biases across production, CUDA-graph, and reference/dequant paths. * **Refactor** * Extended argument/config surfaces to accept and propagate GEMM biases through all runtime and reference paths, ensuring biases are applied where relevant. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 9d733dd commit 3000467

File tree

1 file changed

+87
-2
lines changed

1 file changed

+87
-2
lines changed

tests/moe/test_trtllm_gen_fused_moe.py

Lines changed: 87 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,13 +192,13 @@ def _run_moe_computation(self, runtime_args):
192192
hidden_states_scale=input_quantized["hidden_states_scale"],
193193
gemm1_weights=self.static_data["gemm1_weights_fp4_shuffled"],
194194
gemm1_weights_scale=self.static_data["gemm1_scales_fp4_shuffled"],
195-
gemm1_bias=None,
195+
gemm1_bias=self.config["gemm1_bias"],
196196
gemm1_alpha=None,
197197
gemm1_beta=None,
198198
gemm1_clamp_limit=None,
199199
gemm2_weights=self.static_data["gemm2_weights_fp4_shuffled"],
200200
gemm2_weights_scale=self.static_data["gemm2_scales_fp4_shuffled"],
201-
gemm2_bias=None,
201+
gemm2_bias=self.config["gemm2_bias"],
202202
output1_scale_scalar=self.static_data["scale_c_fc1"],
203203
output1_scale_gate_scalar=self.static_data["scale_gate_fc1"],
204204
output2_scale_scalar=self.static_data["scale_c_fc2"],
@@ -570,6 +570,8 @@ def call_moe(
570570
activation_type = kwargs["activation_type"]
571571
routing_method_type = kwargs["routing_method_type"]
572572
enable_autotune = kwargs.get("enable_autotune", True)
573+
gemm1_bias = kwargs["gemm1_bias"]
574+
gemm2_bias = kwargs["gemm2_bias"]
573575

574576
# Create CUDA graph configuration
575577
config = {
@@ -583,6 +585,8 @@ def call_moe(
583585
"activation_type": activation_type,
584586
"routing_method_type": routing_method_type,
585587
"enable_autotune": enable_autotune,
588+
"gemm1_bias": gemm1_bias,
589+
"gemm2_bias": gemm2_bias,
586590
}
587591

588592
runtime_args = {
@@ -1561,6 +1565,8 @@ def __init__(
15611565
permute_info,
15621566
use_routing_scales_on_input,
15631567
activation_type,
1568+
gemm1_bias=None,
1569+
gemm2_bias=None,
15641570
):
15651571
self.num_tokens = num_tokens
15661572
self.num_experts = num_experts
@@ -1581,6 +1587,8 @@ def __init__(
15811587
self.permute_info = permute_info
15821588
self.use_routing_scales_on_input = use_routing_scales_on_input
15831589
self.activation_type = activation_type
1590+
self.gemm1_bias = gemm1_bias
1591+
self.gemm2_bias = gemm2_bias
15841592

15851593

15861594
class moe_args_dequant:
@@ -1602,6 +1610,8 @@ def __init__(
16021610
use_routing_scales_on_input,
16031611
activation_type,
16041612
hidden_states_scale=None,
1613+
gemm1_bias=None,
1614+
gemm2_bias=None,
16051615
):
16061616
self.num_tokens = num_tokens
16071617
self.num_experts = num_experts
@@ -1617,6 +1627,8 @@ def __init__(
16171627
self.use_routing_scales_on_input = use_routing_scales_on_input
16181628
self.activation_type = activation_type
16191629
self.hidden_states_scale = hidden_states_scale
1630+
self.gemm1_bias = gemm1_bias
1631+
self.gemm2_bias = gemm2_bias
16201632

16211633

16221634
def routing_reference(expertLogits, topK, padding):
@@ -2088,6 +2100,8 @@ def run_moe_dequant(args, quant_mode: QuantMode):
20882100
my_a = permute_output[i : i + my_num_tokens]
20892101
my_b = args.gemm1_weights[expert_idx]
20902102
my_c = my_a @ my_b.t()
2103+
if args.gemm1_bias is not None:
2104+
my_c = my_c + args.gemm1_bias[expert_idx].to(torch.float)
20912105
gemm1_output[i : i + my_num_tokens] = my_c
20922106
i += my_num_tokens
20932107
i = (i + args.padding - 1) // args.padding * args.padding
@@ -2180,6 +2194,8 @@ def run_moe_dequant(args, quant_mode: QuantMode):
21802194
my_a = activation_output[i : i + my_num_tokens]
21812195
my_b = args.gemm2_weights[expert_idx]
21822196
my_c = my_a @ my_b.t()
2197+
if args.gemm2_bias is not None:
2198+
my_c = my_c + args.gemm2_bias[expert_idx].to(torch.float)
21832199
gemm2_output[i : i + my_num_tokens] = my_c
21842200
i += my_num_tokens
21852201
i = (i + args.padding - 1) // args.padding * args.padding
@@ -2262,6 +2278,8 @@ def run_moe_reference_fp4(args, quant_mode: QuantMode):
22622278
args.permute_info,
22632279
args.use_routing_scales_on_input,
22642280
args.activation_type,
2281+
gemm1_bias=args.gemm1_bias,
2282+
gemm2_bias=args.gemm2_bias,
22652283
)
22662284

22672285
return run_moe_dequant(args_dequant, quant_mode), args_dequant
@@ -2365,6 +2383,8 @@ def dequant_reference_dsfp8(input, scale, transpose_scale, block_m, block_n):
23652383
args.permute_info,
23662384
args.use_routing_scales_on_input,
23672385
args.activation_type,
2386+
gemm1_bias=args.gemm1_bias,
2387+
gemm2_bias=args.gemm2_bias,
23682388
)
23692389

23702390
return run_moe_dequant(
@@ -2404,6 +2424,8 @@ def run_moe_reference_per_tensor_scale_fp8(args):
24042424
args.permute_info,
24052425
args.use_routing_scales_on_input,
24062426
args.activation_type,
2427+
gemm1_bias=args.gemm1_bias,
2428+
gemm2_bias=args.gemm2_bias,
24072429
)
24082430

24092431
return run_moe_dequant(args_dequant, QuantMode.FP8_PER_TENSOR), args_dequant
@@ -2435,6 +2457,8 @@ def run_moe_reference_bf16(args):
24352457
args.permute_info,
24362458
args.use_routing_scales_on_input,
24372459
args.activation_type,
2460+
gemm1_bias=args.gemm1_bias,
2461+
gemm2_bias=args.gemm2_bias,
24382462
)
24392463

24402464
return run_moe_dequant(args_dequant, QuantMode.BF16), args_dequant
@@ -2486,6 +2510,8 @@ def dequantize(weights, scales):
24862510
args.permute_info,
24872511
args.use_routing_scales_on_input,
24882512
args.activation_type,
2513+
gemm1_bias=args.gemm1_bias,
2514+
gemm2_bias=args.gemm2_bias,
24892515
)
24902516

24912517
return run_moe_dequant(args_dequant, QuantMode.MXINT4_BF16_BF16), args_dequant
@@ -2523,6 +2549,8 @@ def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs):
25232549
"hidden_states_scale": args.hidden_states_scale,
25242550
"hidden_states_quant": kwargs["hidden_states_quant"],
25252551
"enable_autotune": kwargs.get("enable_autotune", True),
2552+
"gemm1_bias": args.gemm1_bias,
2553+
"gemm2_bias": args.gemm2_bias,
25262554
}
25272555

25282556
return moe_impl.call_moe(
@@ -2550,6 +2578,8 @@ def run_moe_test(
25502578
activation_type,
25512579
cache_permute_indices,
25522580
zero_hidden_states=False,
2581+
gemm1_bias=None,
2582+
gemm2_bias=None,
25532583
):
25542584
"""Common test logic for all routing methods."""
25552585
skip_checks(
@@ -2699,6 +2729,8 @@ def run_moe_test(
26992729
permute_info,
27002730
use_routing_scales_on_input,
27012731
activation_type,
2732+
gemm1_bias=gemm1_bias,
2733+
gemm2_bias=gemm2_bias,
27022734
)
27032735

27042736
# Compute reference output
@@ -3245,3 +3277,56 @@ def test_llama4_routing(
32453277
activation_type,
32463278
cache_permute_indices,
32473279
)
3280+
3281+
3282+
@pytest.mark.parametrize("num_tokens", [32, 768, 3072])
@pytest.mark.parametrize("hidden_size", [1024])
@pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 512])
@pytest.mark.parametrize("bias", ["gemm2", "gemm1", "gemm1_and_gemm2"])
def test_nvfp4_moe_gemm_bias(
    num_tokens, hidden_size, intermediate_size, bias, cache_permute_indices
):
    """Exercise the NvFP4 MoE path with per-expert GEMM biases enabled.

    The ``bias`` parameter selects which GEMMs get a bias tensor
    ("gemm1", "gemm2", or both via "gemm1_and_gemm2"); the unselected
    bias is left as ``None``.
    """
    num_experts = 8
    top_k = 2
    device = "cuda"

    # Build a random float32 per-expert bias only for the GEMMs named in
    # ``bias``. The GEMM1 bias spans 2 * intermediate_size columns —
    # presumably because GEMM1 fuses the gate and up projections for the
    # Swiglu activation used below.
    gemm1_bias = (
        torch.randn(
            (num_experts, 2 * intermediate_size), device=device, dtype=torch.float32
        )
        if "gemm1" in bias
        else None
    )
    gemm2_bias = (
        torch.randn((num_experts, hidden_size), device=device, dtype=torch.float32)
        if "gemm2" in bias
        else None
    )

    run_moe_test(
        num_tokens=num_tokens,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        moe_impl=FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4),
        routing_config={
            "num_experts": num_experts,
            "top_k": top_k,
            "padding": 8,
            "n_groups": None,
            "top_k_groups": None,
            "routed_scaling": None,
            "has_routing_bias": False,
            "routing_method_type": RoutingMethodType.Renormalize,
            "compatible_moe_impls": [FP4Moe],
            "compatible_intermediate_size": [512, 768, 1024, 2048],
            "enable_autotune": True,
        },
        weight_processing={
            "use_shuffled_weight": True,
            "layout": WeightLayout.MajorK,
            "compatible_moe_impls": [FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe],
        },
        activation_type=ActivationType.Swiglu,
        cache_permute_indices=cache_permute_indices,
        gemm1_bias=gemm1_bias,
        gemm2_bias=gemm2_bias,
    )

0 commit comments

Comments (0)