Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 265 additions & 0 deletions tests/moe/test_trtllm_gen_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1435,6 +1435,8 @@ def __init__(
permute_info,
use_routing_scales_on_input,
activation_type,
gemm1_bias=None,
gemm2_bias=None,
):
self.num_tokens = num_tokens
self.num_experts = num_experts
Expand All @@ -1455,6 +1457,8 @@ def __init__(
self.permute_info = permute_info
self.use_routing_scales_on_input = use_routing_scales_on_input
self.activation_type = activation_type
self.gemm1_bias = gemm1_bias
self.gemm2_bias = gemm2_bias


class moe_args_dequant:
Expand All @@ -1476,6 +1480,8 @@ def __init__(
use_routing_scales_on_input,
activation_type,
hidden_states_scale=None,
gemm1_bias=None,
gemm2_bias=None,
):
self.num_tokens = num_tokens
self.num_experts = num_experts
Expand All @@ -1491,6 +1497,8 @@ def __init__(
self.use_routing_scales_on_input = use_routing_scales_on_input
self.activation_type = activation_type
self.hidden_states_scale = hidden_states_scale
self.gemm1_bias = gemm1_bias
self.gemm2_bias = gemm2_bias


def routing_reference(expertLogits, topK, padding):
Expand Down Expand Up @@ -1929,6 +1937,8 @@ def run_moe_dequant(args, quant_mode: QuantMode):
my_a = permute_output[i : i + my_num_tokens]
my_b = args.gemm1_weights[expert_idx]
my_c = my_a @ my_b.t()
if args.gemm1_bias is not None:
my_c = my_c + args.gemm1_bias[expert_idx].to(torch.float)
gemm1_output[i : i + my_num_tokens] = my_c
i += my_num_tokens
i = (i + args.padding - 1) // args.padding * args.padding
Expand Down Expand Up @@ -2018,6 +2028,8 @@ def run_moe_dequant(args, quant_mode: QuantMode):
my_a = activation_output[i : i + my_num_tokens]
my_b = args.gemm2_weights[expert_idx]
my_c = my_a @ my_b.t()
if args.gemm2_bias is not None:
my_c = my_c + args.gemm2_bias[expert_idx].to(torch.float)
gemm2_output[i : i + my_num_tokens] = my_c
i += my_num_tokens
i = (i + args.padding - 1) // args.padding * args.padding
Expand Down Expand Up @@ -2100,6 +2112,8 @@ def run_moe_reference_fp4(args, quant_mode: QuantMode):
args.permute_info,
args.use_routing_scales_on_input,
args.activation_type,
gemm1_bias=args.gemm1_bias,
gemm2_bias=args.gemm2_bias,
)

return run_moe_dequant(args_dequant, quant_mode), args_dequant
Expand Down Expand Up @@ -2165,6 +2179,8 @@ def dequant_reference_dsfp8(input, scale, transpose_scale, block_m, block_n):
args.permute_info,
args.use_routing_scales_on_input,
args.activation_type,
gemm1_bias=args.gemm1_bias,
gemm2_bias=args.gemm2_bias,
)

return run_moe_dequant(args_dequant, QuantMode.FP8_BLOCK_SCALE), args_dequant
Expand Down Expand Up @@ -2202,6 +2218,8 @@ def run_moe_reference_per_tensor_scale_fp8(args):
args.permute_info,
args.use_routing_scales_on_input,
args.activation_type,
gemm1_bias=args.gemm1_bias,
gemm2_bias=args.gemm2_bias,
)

return run_moe_dequant(args_dequant, QuantMode.FP8_PER_TENSOR), args_dequant
Expand Down Expand Up @@ -2233,6 +2251,8 @@ def run_moe_reference_bf16(args):
args.permute_info,
args.use_routing_scales_on_input,
args.activation_type,
gemm1_bias=args.gemm1_bias,
gemm2_bias=args.gemm2_bias,
)

return run_moe_dequant(args_dequant, QuantMode.BF16), args_dequant
Expand Down Expand Up @@ -2284,6 +2304,8 @@ def dequantize(weights, scales):
args.permute_info,
args.use_routing_scales_on_input,
args.activation_type,
gemm1_bias=args.gemm1_bias,
gemm2_bias=args.gemm2_bias,
)

return run_moe_dequant(args_dequant, QuantMode.MXINT4_BF16_BF16), args_dequant
Expand Down Expand Up @@ -3029,3 +3051,246 @@ def test_llama4_routing(
activation_type,
cache_permute_indices,
)


# ====================================================================================
# Bias Support Tests for NvFP4 MoE
# ====================================================================================


def _run_fp4_moe_with_bias(
    num_tokens,
    hidden_size,
    intermediate_size,
    num_experts,
    top_k,
    gemm1_bias=None,
    gemm2_bias=None,
):
    """Run the trtllm-gen NvFP4 MoE kernel alongside the dequantized reference.

    Optional per-expert biases are threaded into both paths via ``moe_args``
    and the kernel's ``gemm1_bias``/``gemm2_bias`` arguments.
    NOTE(review): callers pass gemm1_bias as [num_experts, 2*intermediate_size]
    and gemm2_bias as [num_experts, hidden_size], float32 — confirm against the
    kernel contract.

    Returns:
        (kernel_output, ref_output) for comparison by the caller.
    """
    device = "cuda"
    act_type = ActivationType.Swiglu
    pad = 8

    impl = FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4)
    impl._cache_permute_indices = {}

    # Random bf16 operands. Keep the generation order stable: callers seed
    # torch immediately before invoking this helper.
    hidden_states = 2 * torch.randn(
        (num_tokens, hidden_size), device=device, dtype=torch.bfloat16
    )
    gemm1_weights = torch.randn(
        (num_experts, 2 * intermediate_size, hidden_size),
        device=device,
        dtype=torch.bfloat16,
    )
    gemm2_weights = torch.randn(
        (num_experts, hidden_size, intermediate_size),
        device=device,
        dtype=torch.bfloat16,
    )
    expert_logits = torch.randn(
        (num_tokens, num_experts), device=device, dtype=torch.bfloat16
    )

    # Quantize weights first (produces the global activation scale), then
    # the activations; merge both dicts for the reference-args constructor.
    w_data = impl.quantize_weights(gemm1_weights, gemm2_weights, hidden_states)
    x_data = impl.quantize_inputs(
        hidden_states, w_data["hidden_states_scale_global"]
    )
    quant = {**w_data, **x_data}

    permute_info, scores = routing_reference_renormalize(
        expert_logits, top_k, num_experts, pad
    )
    ref_args = moe_args(
        num_tokens,
        num_experts,
        hidden_size,
        intermediate_size,
        top_k,
        pad,
        quant["hidden_states"],
        quant["hidden_states_scale"],
        quant["hidden_states_scale_global"],
        scores,
        quant["gemm1_weights"],
        quant["gemm1_scales"],
        quant["gemm1_scales_global"],
        quant["gemm2_weights"],
        quant["gemm2_scales"],
        quant["gemm2_scales_global"],
        permute_info,
        False,
        act_type,
        gemm1_bias=gemm1_bias,
        gemm2_bias=gemm2_bias,
    )

    ref_output, args_dequant = impl.compute_reference(ref_args)

    # Shuffle the static weights into the layout the kernel expects.
    static_data = impl.prepare_static_weights_for_kernel(
        args_dequant,
        ref_args,
        gemm1_weights,
        gemm2_weights,
        hidden_size,
        intermediate_size,
        num_experts,
        {"shuffle": True, "layout": WeightLayout.MajorK},
    )

    # The kernel path takes non-swizzled activation scales.
    kernel_inputs = impl.quantize_inputs(
        hidden_states,
        w_data["hidden_states_scale_global"],
        is_swizzling=False,
    )

    kernel_output = trtllm_fp4_block_scale_moe(
        routing_logits=expert_logits,
        routing_bias=None,
        hidden_states=kernel_inputs["hidden_states"],
        hidden_states_scale=kernel_inputs["hidden_states_scale"],
        gemm1_weights=static_data["gemm1_weights_fp4_shuffled"],
        gemm1_weights_scale=static_data["gemm1_scales_fp4_shuffled"],
        gemm1_bias=gemm1_bias,
        gemm1_alpha=None,
        gemm1_beta=None,
        gemm1_clamp_limit=None,
        gemm2_weights=static_data["gemm2_weights_fp4_shuffled"],
        gemm2_weights_scale=static_data["gemm2_scales_fp4_shuffled"],
        gemm2_bias=gemm2_bias,
        output1_scale_scalar=static_data["scale_c_fc1"],
        output1_scale_gate_scalar=static_data["scale_gate_fc1"],
        output2_scale_scalar=static_data["scale_c_fc2"],
        num_experts=num_experts,
        top_k=top_k,
        n_group=None,
        topk_group=None,
        intermediate_size=intermediate_size,
        local_expert_offset=0,
        local_num_experts=num_experts,
        routed_scaling_factor=None,
        routing_method_type=RoutingMethodType.Renormalize.value,
        do_finalize=True,
        activation_type=act_type.value,
        tune_max_num_tokens=TUNE_MAX_NUM_TOKENS,
    )

    return kernel_output, ref_output


@pytest.mark.parametrize("num_tokens", [32, 768, 3072])
@pytest.mark.parametrize("hidden_size", [1024])
@pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 512])
def test_nvfp4_moe_gemm2_bias(num_tokens, hidden_size, intermediate_size):
    """NvFP4 MoE with a per-expert bias on GEMM2 only.

    The kernel output must match the dequantized reference within the
    FP4 accuracy tolerances.
    """
    from flashinfer.utils import get_compute_capability

    cc = get_compute_capability(torch.device("cuda"))
    if cc[0] not in [10]:
        pytest.skip("Requires SM100/SM103 GPU")

    num_experts, top_k = 8, 2
    device = "cuda"

    # Seed BEFORE any random tensor is created so the bias (and the inputs
    # generated inside the helper) are deterministic across runs.
    torch.random.manual_seed(0)

    # gemm2_bias shape: [num_experts, hidden_size], dtype float32
    gemm2_bias = torch.randn(
        (num_experts, hidden_size), device=device, dtype=torch.float32
    )

    kernel_output, ref_output = _run_fp4_moe_with_bias(
        num_tokens,
        hidden_size,
        intermediate_size,
        num_experts,
        top_k,
        gemm2_bias=gemm2_bias,
    )

    tolerances = FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4).get_tolerances()
    check_accuracy(
        ref_output,
        kernel_output[0].to(torch.float),
        atol=tolerances["atol"],
        rtol=tolerances["rtol"],
        percent=tolerances["percent"],
    )


@pytest.mark.parametrize("num_tokens", [32, 768, 3072])
@pytest.mark.parametrize("hidden_size", [1024])
@pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 512])
def test_nvfp4_moe_gemm1_bias(num_tokens, hidden_size, intermediate_size):
    """NvFP4 MoE with a per-expert bias on GEMM1 only.

    The kernel output must match the dequantized reference within the
    FP4 accuracy tolerances.
    """
    from flashinfer.utils import get_compute_capability

    cc = get_compute_capability(torch.device("cuda"))
    if cc[0] not in [10]:
        pytest.skip("Requires SM100/SM103 GPU")

    num_experts, top_k = 8, 2
    device = "cuda"

    # Seed BEFORE any random tensor is created so the bias (and the inputs
    # generated inside the helper) are deterministic across runs.
    torch.random.manual_seed(0)

    # gemm1_bias shape: [num_experts, 2 * intermediate_size], dtype float32
    # (factor of 2 because of gated activation -- SwiGLU has gate + value)
    gemm1_bias = torch.randn(
        (num_experts, 2 * intermediate_size), device=device, dtype=torch.float32
    )

    kernel_output, ref_output = _run_fp4_moe_with_bias(
        num_tokens,
        hidden_size,
        intermediate_size,
        num_experts,
        top_k,
        gemm1_bias=gemm1_bias,
    )

    tolerances = FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4).get_tolerances()
    check_accuracy(
        ref_output,
        kernel_output[0].to(torch.float),
        atol=tolerances["atol"],
        rtol=tolerances["rtol"],
        percent=tolerances["percent"],
    )


@pytest.mark.parametrize("num_tokens", [32, 768, 3072])
@pytest.mark.parametrize("hidden_size", [1024])
@pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 512])
def test_nvfp4_moe_both_biases(num_tokens, hidden_size, intermediate_size):
    """NvFP4 MoE with per-expert biases on both GEMM1 and GEMM2.

    The kernel output must match the dequantized reference within the
    FP4 accuracy tolerances.
    """
    from flashinfer.utils import get_compute_capability

    cc = get_compute_capability(torch.device("cuda"))
    if cc[0] not in [10]:
        pytest.skip("Requires SM100/SM103 GPU")

    num_experts, top_k = 8, 2
    device = "cuda"

    # Seed BEFORE any random tensor is created so the biases (and the inputs
    # generated inside the helper) are deterministic across runs.
    torch.random.manual_seed(0)

    gemm1_bias = torch.randn(
        (num_experts, 2 * intermediate_size), device=device, dtype=torch.float32
    )
    gemm2_bias = torch.randn(
        (num_experts, hidden_size), device=device, dtype=torch.float32
    )

    kernel_output, ref_output = _run_fp4_moe_with_bias(
        num_tokens,
        hidden_size,
        intermediate_size,
        num_experts,
        top_k,
        gemm1_bias=gemm1_bias,
        gemm2_bias=gemm2_bias,
    )

    tolerances = FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4).get_tolerances()
    check_accuracy(
        ref_output,
        kernel_output[0].to(torch.float),
        atol=tolerances["atol"],
        rtol=tolerances["rtol"],
        percent=tolerances["percent"],
    )
Loading