Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 87 additions & 2 deletions tests/moe/test_trtllm_gen_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,13 @@ def _run_moe_computation(self, runtime_args):
hidden_states_scale=input_quantized["hidden_states_scale"],
gemm1_weights=self.static_data["gemm1_weights_fp4_shuffled"],
gemm1_weights_scale=self.static_data["gemm1_scales_fp4_shuffled"],
gemm1_bias=None,
gemm1_bias=self.config["gemm1_bias"],
gemm1_alpha=None,
gemm1_beta=None,
gemm1_clamp_limit=None,
gemm2_weights=self.static_data["gemm2_weights_fp4_shuffled"],
gemm2_weights_scale=self.static_data["gemm2_scales_fp4_shuffled"],
gemm2_bias=None,
gemm2_bias=self.config["gemm2_bias"],
output1_scale_scalar=self.static_data["scale_c_fc1"],
output1_scale_gate_scalar=self.static_data["scale_gate_fc1"],
output2_scale_scalar=self.static_data["scale_c_fc2"],
Expand Down Expand Up @@ -568,6 +568,8 @@ def call_moe(
activation_type = kwargs["activation_type"]
routing_method_type = kwargs["routing_method_type"]
enable_autotune = kwargs.get("enable_autotune", True)
gemm1_bias = kwargs["gemm1_bias"]
gemm2_bias = kwargs["gemm2_bias"]

# Create CUDA graph configuration
config = {
Expand All @@ -581,6 +583,8 @@ def call_moe(
"activation_type": activation_type,
"routing_method_type": routing_method_type,
"enable_autotune": enable_autotune,
"gemm1_bias": gemm1_bias,
"gemm2_bias": gemm2_bias,
}

runtime_args = {
Expand Down Expand Up @@ -1435,6 +1439,8 @@ def __init__(
permute_info,
use_routing_scales_on_input,
activation_type,
gemm1_bias=None,
gemm2_bias=None,
):
self.num_tokens = num_tokens
self.num_experts = num_experts
Expand All @@ -1455,6 +1461,8 @@ def __init__(
self.permute_info = permute_info
self.use_routing_scales_on_input = use_routing_scales_on_input
self.activation_type = activation_type
self.gemm1_bias = gemm1_bias
self.gemm2_bias = gemm2_bias


class moe_args_dequant:
Expand All @@ -1476,6 +1484,8 @@ def __init__(
use_routing_scales_on_input,
activation_type,
hidden_states_scale=None,
gemm1_bias=None,
gemm2_bias=None,
):
self.num_tokens = num_tokens
self.num_experts = num_experts
Expand All @@ -1491,6 +1501,8 @@ def __init__(
self.use_routing_scales_on_input = use_routing_scales_on_input
self.activation_type = activation_type
self.hidden_states_scale = hidden_states_scale
self.gemm1_bias = gemm1_bias
self.gemm2_bias = gemm2_bias


def routing_reference(expertLogits, topK, padding):
Expand Down Expand Up @@ -1929,6 +1941,8 @@ def run_moe_dequant(args, quant_mode: QuantMode):
my_a = permute_output[i : i + my_num_tokens]
my_b = args.gemm1_weights[expert_idx]
my_c = my_a @ my_b.t()
if args.gemm1_bias is not None:
my_c = my_c + args.gemm1_bias[expert_idx].to(torch.float)
gemm1_output[i : i + my_num_tokens] = my_c
i += my_num_tokens
i = (i + args.padding - 1) // args.padding * args.padding
Expand Down Expand Up @@ -2018,6 +2032,8 @@ def run_moe_dequant(args, quant_mode: QuantMode):
my_a = activation_output[i : i + my_num_tokens]
my_b = args.gemm2_weights[expert_idx]
my_c = my_a @ my_b.t()
if args.gemm2_bias is not None:
my_c = my_c + args.gemm2_bias[expert_idx].to(torch.float)
gemm2_output[i : i + my_num_tokens] = my_c
i += my_num_tokens
i = (i + args.padding - 1) // args.padding * args.padding
Expand Down Expand Up @@ -2100,6 +2116,8 @@ def run_moe_reference_fp4(args, quant_mode: QuantMode):
args.permute_info,
args.use_routing_scales_on_input,
args.activation_type,
gemm1_bias=args.gemm1_bias,
gemm2_bias=args.gemm2_bias,
)

return run_moe_dequant(args_dequant, quant_mode), args_dequant
Expand Down Expand Up @@ -2165,6 +2183,8 @@ def dequant_reference_dsfp8(input, scale, transpose_scale, block_m, block_n):
args.permute_info,
args.use_routing_scales_on_input,
args.activation_type,
gemm1_bias=args.gemm1_bias,
gemm2_bias=args.gemm2_bias,
)

return run_moe_dequant(args_dequant, QuantMode.FP8_BLOCK_SCALE), args_dequant
Expand Down Expand Up @@ -2202,6 +2222,8 @@ def run_moe_reference_per_tensor_scale_fp8(args):
args.permute_info,
args.use_routing_scales_on_input,
args.activation_type,
gemm1_bias=args.gemm1_bias,
gemm2_bias=args.gemm2_bias,
)

return run_moe_dequant(args_dequant, QuantMode.FP8_PER_TENSOR), args_dequant
Expand Down Expand Up @@ -2233,6 +2255,8 @@ def run_moe_reference_bf16(args):
args.permute_info,
args.use_routing_scales_on_input,
args.activation_type,
gemm1_bias=args.gemm1_bias,
gemm2_bias=args.gemm2_bias,
)

return run_moe_dequant(args_dequant, QuantMode.BF16), args_dequant
Expand Down Expand Up @@ -2284,6 +2308,8 @@ def dequantize(weights, scales):
args.permute_info,
args.use_routing_scales_on_input,
args.activation_type,
gemm1_bias=args.gemm1_bias,
gemm2_bias=args.gemm2_bias,
)

return run_moe_dequant(args_dequant, QuantMode.MXINT4_BF16_BF16), args_dequant
Expand Down Expand Up @@ -2321,6 +2347,8 @@ def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs):
"hidden_states_scale": args.hidden_states_scale,
"hidden_states_quant": kwargs["hidden_states_quant"],
"enable_autotune": kwargs.get("enable_autotune", True),
"gemm1_bias": args.gemm1_bias,
"gemm2_bias": args.gemm2_bias,
}

return moe_impl.call_moe(
Expand Down Expand Up @@ -2348,6 +2376,8 @@ def run_moe_test(
activation_type,
cache_permute_indices,
zero_hidden_states=False,
gemm1_bias=None,
gemm2_bias=None,
):
"""Common test logic for all routing methods."""
skip_checks(
Expand Down Expand Up @@ -2497,6 +2527,8 @@ def run_moe_test(
permute_info,
use_routing_scales_on_input,
activation_type,
gemm1_bias=gemm1_bias,
gemm2_bias=gemm2_bias,
)

# Compute reference output
Expand Down Expand Up @@ -3029,3 +3061,56 @@ def test_llama4_routing(
activation_type,
cache_permute_indices,
)


@pytest.mark.parametrize("num_tokens", [32, 768, 3072])
@pytest.mark.parametrize("hidden_size", [1024])
@pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 512])
@pytest.mark.parametrize("bias", ["gemm2", "gemm1", "gemm1_and_gemm2"])
def test_nvfp4_moe_gemm_bias(
    num_tokens, hidden_size, intermediate_size, bias, cache_permute_indices
):
    """Exercise the NvFP4 MoE path with per-expert GEMM bias tensors.

    The `bias` parameter selects which of the two GEMMs receives a bias:
    "gemm1", "gemm2", or both ("gemm1_and_gemm2").
    """
    expert_count = 8
    experts_per_token = 2
    dev = "cuda"

    # GEMM1 output is gated (2x intermediate width); GEMM2 projects back
    # to hidden_size. Keep gemm1 first so RNG draw order is stable.
    gemm1_bias = (
        torch.randn(
            (expert_count, 2 * intermediate_size), device=dev, dtype=torch.float32
        )
        if "gemm1" in bias
        else None
    )
    gemm2_bias = (
        torch.randn((expert_count, hidden_size), device=dev, dtype=torch.float32)
        if "gemm2" in bias
        else None
    )

    routing_config = {
        "num_experts": expert_count,
        "top_k": experts_per_token,
        "padding": 8,
        "n_groups": None,
        "top_k_groups": None,
        "routed_scaling": None,
        "has_routing_bias": False,
        "routing_method_type": RoutingMethodType.Renormalize,
        "compatible_moe_impls": [FP4Moe],
        "compatible_intermediate_size": [512, 768, 1024, 2048],
        "enable_autotune": True,
    }
    weight_processing = {
        "use_shuffled_weight": True,
        "layout": WeightLayout.MajorK,
        "compatible_moe_impls": [FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe],
    }

    run_moe_test(
        num_tokens=num_tokens,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        moe_impl=FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4),
        routing_config=routing_config,
        weight_processing=weight_processing,
        activation_type=ActivationType.Swiglu,
        cache_permute_indices=cache_permute_indices,
        gemm1_bias=gemm1_bias,
        gemm2_bias=gemm2_bias,
    )
Loading