@@ -192,13 +192,13 @@ def _run_moe_computation(self, runtime_args):
192192 hidden_states_scale = input_quantized ["hidden_states_scale" ],
193193 gemm1_weights = self .static_data ["gemm1_weights_fp4_shuffled" ],
194194 gemm1_weights_scale = self .static_data ["gemm1_scales_fp4_shuffled" ],
195- gemm1_bias = None ,
195+ gemm1_bias = self . config [ "gemm1_bias" ] ,
196196 gemm1_alpha = None ,
197197 gemm1_beta = None ,
198198 gemm1_clamp_limit = None ,
199199 gemm2_weights = self .static_data ["gemm2_weights_fp4_shuffled" ],
200200 gemm2_weights_scale = self .static_data ["gemm2_scales_fp4_shuffled" ],
201- gemm2_bias = None ,
201+ gemm2_bias = self . config [ "gemm2_bias" ] ,
202202 output1_scale_scalar = self .static_data ["scale_c_fc1" ],
203203 output1_scale_gate_scalar = self .static_data ["scale_gate_fc1" ],
204204 output2_scale_scalar = self .static_data ["scale_c_fc2" ],
@@ -570,6 +570,8 @@ def call_moe(
570570 activation_type = kwargs ["activation_type" ]
571571 routing_method_type = kwargs ["routing_method_type" ]
572572 enable_autotune = kwargs .get ("enable_autotune" , True )
573+ gemm1_bias = kwargs ["gemm1_bias" ]
574+ gemm2_bias = kwargs ["gemm2_bias" ]
573575
574576 # Create CUDA graph configuration
575577 config = {
@@ -583,6 +585,8 @@ def call_moe(
583585 "activation_type" : activation_type ,
584586 "routing_method_type" : routing_method_type ,
585587 "enable_autotune" : enable_autotune ,
588+ "gemm1_bias" : gemm1_bias ,
589+ "gemm2_bias" : gemm2_bias ,
586590 }
587591
588592 runtime_args = {
@@ -1561,6 +1565,8 @@ def __init__(
15611565 permute_info ,
15621566 use_routing_scales_on_input ,
15631567 activation_type ,
1568+ gemm1_bias = None ,
1569+ gemm2_bias = None ,
15641570 ):
15651571 self .num_tokens = num_tokens
15661572 self .num_experts = num_experts
@@ -1581,6 +1587,8 @@ def __init__(
15811587 self .permute_info = permute_info
15821588 self .use_routing_scales_on_input = use_routing_scales_on_input
15831589 self .activation_type = activation_type
1590+ self .gemm1_bias = gemm1_bias
1591+ self .gemm2_bias = gemm2_bias
15841592
15851593
15861594class moe_args_dequant :
@@ -1602,6 +1610,8 @@ def __init__(
16021610 use_routing_scales_on_input ,
16031611 activation_type ,
16041612 hidden_states_scale = None ,
1613+ gemm1_bias = None ,
1614+ gemm2_bias = None ,
16051615 ):
16061616 self .num_tokens = num_tokens
16071617 self .num_experts = num_experts
@@ -1617,6 +1627,8 @@ def __init__(
16171627 self .use_routing_scales_on_input = use_routing_scales_on_input
16181628 self .activation_type = activation_type
16191629 self .hidden_states_scale = hidden_states_scale
1630+ self .gemm1_bias = gemm1_bias
1631+ self .gemm2_bias = gemm2_bias
16201632
16211633
16221634def routing_reference (expertLogits , topK , padding ):
@@ -2088,6 +2100,8 @@ def run_moe_dequant(args, quant_mode: QuantMode):
20882100 my_a = permute_output [i : i + my_num_tokens ]
20892101 my_b = args .gemm1_weights [expert_idx ]
20902102 my_c = my_a @ my_b .t ()
2103+ if args .gemm1_bias is not None :
2104+ my_c = my_c + args .gemm1_bias [expert_idx ].to (torch .float )
20912105 gemm1_output [i : i + my_num_tokens ] = my_c
20922106 i += my_num_tokens
20932107 i = (i + args .padding - 1 ) // args .padding * args .padding
@@ -2180,6 +2194,8 @@ def run_moe_dequant(args, quant_mode: QuantMode):
21802194 my_a = activation_output [i : i + my_num_tokens ]
21812195 my_b = args .gemm2_weights [expert_idx ]
21822196 my_c = my_a @ my_b .t ()
2197+ if args .gemm2_bias is not None :
2198+ my_c = my_c + args .gemm2_bias [expert_idx ].to (torch .float )
21832199 gemm2_output [i : i + my_num_tokens ] = my_c
21842200 i += my_num_tokens
21852201 i = (i + args .padding - 1 ) // args .padding * args .padding
@@ -2262,6 +2278,8 @@ def run_moe_reference_fp4(args, quant_mode: QuantMode):
22622278 args .permute_info ,
22632279 args .use_routing_scales_on_input ,
22642280 args .activation_type ,
2281+ gemm1_bias = args .gemm1_bias ,
2282+ gemm2_bias = args .gemm2_bias ,
22652283 )
22662284
22672285 return run_moe_dequant (args_dequant , quant_mode ), args_dequant
@@ -2365,6 +2383,8 @@ def dequant_reference_dsfp8(input, scale, transpose_scale, block_m, block_n):
23652383 args .permute_info ,
23662384 args .use_routing_scales_on_input ,
23672385 args .activation_type ,
2386+ gemm1_bias = args .gemm1_bias ,
2387+ gemm2_bias = args .gemm2_bias ,
23682388 )
23692389
23702390 return run_moe_dequant (
@@ -2404,6 +2424,8 @@ def run_moe_reference_per_tensor_scale_fp8(args):
24042424 args .permute_info ,
24052425 args .use_routing_scales_on_input ,
24062426 args .activation_type ,
2427+ gemm1_bias = args .gemm1_bias ,
2428+ gemm2_bias = args .gemm2_bias ,
24072429 )
24082430
24092431 return run_moe_dequant (args_dequant , QuantMode .FP8_PER_TENSOR ), args_dequant
@@ -2435,6 +2457,8 @@ def run_moe_reference_bf16(args):
24352457 args .permute_info ,
24362458 args .use_routing_scales_on_input ,
24372459 args .activation_type ,
2460+ gemm1_bias = args .gemm1_bias ,
2461+ gemm2_bias = args .gemm2_bias ,
24382462 )
24392463
24402464 return run_moe_dequant (args_dequant , QuantMode .BF16 ), args_dequant
@@ -2486,6 +2510,8 @@ def dequantize(weights, scales):
24862510 args .permute_info ,
24872511 args .use_routing_scales_on_input ,
24882512 args .activation_type ,
2513+ gemm1_bias = args .gemm1_bias ,
2514+ gemm2_bias = args .gemm2_bias ,
24892515 )
24902516
24912517 return run_moe_dequant (args_dequant , QuantMode .MXINT4_BF16_BF16 ), args_dequant
@@ -2523,6 +2549,8 @@ def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs):
25232549 "hidden_states_scale" : args .hidden_states_scale ,
25242550 "hidden_states_quant" : kwargs ["hidden_states_quant" ],
25252551 "enable_autotune" : kwargs .get ("enable_autotune" , True ),
2552+ "gemm1_bias" : args .gemm1_bias ,
2553+ "gemm2_bias" : args .gemm2_bias ,
25262554 }
25272555
25282556 return moe_impl .call_moe (
@@ -2550,6 +2578,8 @@ def run_moe_test(
25502578 activation_type ,
25512579 cache_permute_indices ,
25522580 zero_hidden_states = False ,
2581+ gemm1_bias = None ,
2582+ gemm2_bias = None ,
25532583):
25542584 """Common test logic for all routing methods."""
25552585 skip_checks (
@@ -2699,6 +2729,8 @@ def run_moe_test(
26992729 permute_info ,
27002730 use_routing_scales_on_input ,
27012731 activation_type ,
2732+ gemm1_bias = gemm1_bias ,
2733+ gemm2_bias = gemm2_bias ,
27022734 )
27032735
27042736 # Compute reference output
@@ -3245,3 +3277,56 @@ def test_llama4_routing(
32453277 activation_type ,
32463278 cache_permute_indices ,
32473279 )
3280+
3281+
@pytest.mark.parametrize("num_tokens", [32, 768, 3072])
@pytest.mark.parametrize("hidden_size", [1024])
@pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 512])
@pytest.mark.parametrize("bias", ["gemm2", "gemm1", "gemm1_and_gemm2"])
def test_nvfp4_moe_gemm_bias(
    num_tokens, hidden_size, intermediate_size, bias, cache_permute_indices
):
    """Test NvFP4 MoE with per-expert GEMM bias support.

    Exercises the gemm1_bias / gemm2_bias plumbing added through
    ``run_moe_test`` by enabling one or both biases depending on the
    ``bias`` parameter ("gemm1", "gemm2", or "gemm1_and_gemm2").
    """
    num_experts = 8
    top_k = 2
    device = "cuda"

    gemm1_bias = None
    gemm2_bias = None
    # NOTE: "gemm1" is a substring of "gemm1_and_gemm2", so the substring
    # checks below enable each bias for both the single and combined cases.
    if "gemm1" in bias:
        # gemm1 produces a fused [gate, up] output, hence 2 * intermediate_size.
        gemm1_bias = torch.randn(
            (num_experts, 2 * intermediate_size), device=device, dtype=torch.float32
        )
    if "gemm2" in bias:
        # gemm2 projects back to the model width, hence hidden_size.
        gemm2_bias = torch.randn(
            (num_experts, hidden_size), device=device, dtype=torch.float32
        )

    run_moe_test(
        num_tokens=num_tokens,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        moe_impl=FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4),
        routing_config={
            "num_experts": num_experts,
            "top_k": top_k,
            "padding": 8,
            "n_groups": None,
            "top_k_groups": None,
            "routed_scaling": None,
            "has_routing_bias": False,
            "routing_method_type": RoutingMethodType.Renormalize,
            "compatible_moe_impls": [FP4Moe],
            "compatible_intermediate_size": [512, 768, 1024, 2048],
            "enable_autotune": True,
        },
        weight_processing={
            "use_shuffled_weight": True,
            "layout": WeightLayout.MajorK,
            "compatible_moe_impls": [FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe],
        },
        activation_type=ActivationType.Swiglu,
        cache_permute_indices=cache_permute_indices,
        gemm1_bias=gemm1_bias,
        gemm2_bias=gemm2_bias,
    )