
Commit f9cd034

Feature: Support non-gated activation in cutlass fused MoE nvfp4 (#2011)
## 📌 Description

This PR removes an assertion in the cutlass fused MoE bindings to enable non-gated activations in NVFP4. It also adds a test for this path with the ReLU2 activation.

## 🔍 Related Issues

N/A

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

N/A

## Summary by CodeRabbit

* **New Features**
  * Enhanced quantized Mixture of Experts models to support configurable activation types (Swiglu and ReLU2) in the NVFP4 quantization path.
  * Improved parameter handling to correctly adapt weight shapes and quantization settings based on the selected activation type.

Signed-off-by: Omer Ullman Argov <[email protected]>
1 parent 1181c5d commit f9cd034
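For orientation (an editor's sketch, not part of the commit): a gated activation such as SwiGLU fuses two projections into fc1, so its output dimension is `2 * inter_size`, while a non-gated activation such as ReLU2 (`relu(x)**2`) uses a single projection of size `inter_size`. Below is a minimal CPU-only rendering of the two reference formulas the new test exercises; all names in it are illustrative:

```python
import torch
import torch.nn.functional as F


def expert_ffn_swiglu(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> torch.Tensor:
    # Gated path: w1 stacks the fused gate/up projections, so its leading dim is 2 * inter_size.
    m = w1.shape[0]
    assert m % 2 == 0
    w_a, w_b = w1[m // 2 :, :], w1[: m // 2, :]
    inter = F.silu(a @ w_a.t()) * (a @ w_b.t())
    return inter @ w2.t()


def expert_ffn_relu2(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> torch.Tensor:
    # Non-gated path: a single projection of size inter_size, activation is relu(x)**2.
    inter = F.relu(a @ w1.t()) ** 2
    return inter @ w2.t()


hidden_size, inter_size = 64, 128
a = torch.randn(4, hidden_size)
w2 = torch.randn(hidden_size, inter_size)
out_swiglu = expert_ffn_swiglu(a, torch.randn(2 * inter_size, hidden_size), w2)
out_relu2 = expert_ffn_relu2(a, torch.randn(inter_size, hidden_size), w2)
print(out_swiglu.shape, out_relu2.shape)  # both torch.Size([4, 64])
```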

File tree

2 files changed: +76 lines, -36 lines


csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu

Lines changed: 36 additions & 19 deletions
@@ -361,8 +361,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
         num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
         base_activation_type, parallelism_config, min_latency_mode);
 
-    auto const quant_params =
-        getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
+    auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size,
+                                             quant_scales, base_activation_type);
     kernels::MoeMinLatencyParams min_latency_params{};
 
     // TODO: support lora in the future
@@ -542,8 +542,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
         num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
         base_activation_type, parallelism_config, min_latency_mode);
 
-    auto const quant_params =
-        getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
+    auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size,
+                                             quant_scales, base_activation_type);
 
     // TODO: support lora in the future
     ::tensorrt_llm::kernels::LoraParams lora_params{};
@@ -809,9 +809,10 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
     return info;
   }
 
-  kernels::QuantParams getQuantParams(int64_t num_experts_on_rank, int64_t hidden_size,
-                                      int64_t inter_size,
-                                      Optional<Array<Tensor>> quant_scales) const {
+  kernels::QuantParams getQuantParams(
+      int64_t num_experts_on_rank, int64_t hidden_size, int64_t inter_size,
+      Optional<Array<Tensor>> quant_scales,
+      ActivationType base_activation_type = ActivationType::Swiglu) const {
     if (isFp8Quant()) {
       TVM_FFI_ICHECK(quant_scales.has_value()) << "Expecting quant scales for fp8 quantization";
       TVM_FFI_ICHECK_EQ(quant_scales.value().size(), 4)
@@ -1013,18 +1014,34 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj {
       // Check shapes
       TVM_FFI_ICHECK(fc1_act_global.ndim() == 0 || fc1_act_global.size(0) == num_experts_on_rank)
           << "fc1 act global must be scalar or (num_experts_on_rank,)";
-      TVM_FFI_ICHECK(
-          fc1_weight_block.size(0) == num_experts_on_rank &&
-              fc1_weight_block.size(1) ==
-                  TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
-                      inter_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4) *
-                      2 &&
-              fc1_weight_block.size(2) * FP8_PER_INT32 *
-                      TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize ==
-                  TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
-                      hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4))
-          << "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 "
-             "// block_scale_vector_size)";
+      if (isGatedActivation(base_activation_type)) {
+        TVM_FFI_ICHECK(
+            fc1_weight_block.size(0) == num_experts_on_rank &&
+                fc1_weight_block.size(1) ==
+                    TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
+                        inter_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4) *
+                        2 &&
+                fc1_weight_block.size(2) * FP8_PER_INT32 *
+                        TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize ==
+                    TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
+                        hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4))
+            << "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // "
+               "4 "
+               "// block_scale_vector_size)";
+      } else {
+        TVM_FFI_ICHECK(
+            fc1_weight_block.size(0) == num_experts_on_rank &&
+                fc1_weight_block.size(1) ==
+                    TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
+                        inter_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4) &&
+                fc1_weight_block.size(2) * FP8_PER_INT32 *
+                        TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize ==
+                    TmaWarpSpecializedGroupedGemmInput::alignToSfDim(
+                        hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4))
+            << "fc1 weight block size must be (num_experts_on_rank, inter_size, hidden_size // 4 "
+               "// block_scale_vector_size)";
+      }
+
       TVM_FFI_ICHECK_EQ(fc1_global.size(0), num_experts_on_rank)
           << "fc1 global size must be (num_experts_on_rank,)";
       TVM_FFI_ICHECK(fc2_act_global.ndim() == 0 || fc2_act_global.size(0) == num_experts_on_rank)
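As an aside (an illustration with assumed constants, not code from the repository): the only thing the new branch changes is the expected second dimension of the fc1 weight block-scale tensor. A rough Python rendering of that shape rule follows; the alignment value, the NVFP4 block-scale vector size, and the helper names are assumptions of this sketch:

```python
def ceil_align(x: int, alignment: int) -> int:
    # Stand-in for TmaWarpSpecializedGroupedGemmInput::alignToSfDim in this sketch.
    return (x + alignment - 1) // alignment * alignment


def expected_fc1_weight_block_shape(
    num_experts_on_rank: int,
    hidden_size: int,
    inter_size: int,
    gated: bool,
    sf_alignment: int = 128,   # assumed alignment (the binding uses MinKDimAlignmentNVFP4)
    sf_vector_size: int = 16,  # assumed NVFP4 block-scale vector size
    fp8_per_int32: int = 4,    # four FP8 scale values packed per int32
) -> tuple[int, int, int]:
    # Gated activations (e.g. Swiglu) fuse gate and up projections into fc1, so the
    # checked dimension doubles; non-gated activations (e.g. Relu2) keep inter_size.
    m_dim = ceil_align(inter_size, sf_alignment) * (2 if gated else 1)
    k_dim = ceil_align(hidden_size, sf_alignment) // (fp8_per_int32 * sf_vector_size)
    return (num_experts_on_rank, m_dim, k_dim)


print(expected_fc1_weight_block_shape(8, 1024, 768, gated=True))   # (8, 1536, 16)
print(expected_fc1_weight_block_shape(8, 1024, 768, gated=False))  # (8, 768, 16)
```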

tests/moe/test_trtllm_cutlass_fused_moe.py

Lines changed: 40 additions & 17 deletions
@@ -17,6 +17,7 @@
 from contextlib import nullcontext
 
 import pytest
+from flashinfer.fused_moe.core import ActivationType
 import torch
 from torch.nn import functional as F
 
@@ -137,7 +138,7 @@ def compute_routing(
     return routing_weights, selected_experts
 
 
-def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids):
+def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids, activation_type):
     B, D = a.shape
     a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
     out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
@@ -147,13 +148,26 @@ def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids):
     topk_ids = topk_ids.view(-1)
     # w1 needs to be swapped in terms of gate and up_proj
 
+    if activation_type == ActivationType.Swiglu:
+
+        def act(weight, mask):
+            m = weight.shape[0]
+            assert m % 2 == 0
+            w1_expert, w3_expert = weight[m // 2 :, :], weight[: m // 2, :]
+            return F.silu(a[mask] @ w1_expert.t()) * (a[mask] @ w3_expert.t())
+
+    elif activation_type == ActivationType.Relu2:
+
+        def act(weight, mask):
+            return F.relu(a[mask] @ weight.t()) ** 2
+
+    else:
+        raise ValueError(f"Unsupported activation type {activation_type}")
+
     for i in range(w1.shape[0]):
         mask = topk_ids == i
         if mask.sum():
-            m = w1[i].shape[0]
-            assert m % 2 == 0
-            w1_expert, w3_expert = w1[i][m // 2 :, :], w1[i][: m // 2, :]
-            inter = F.silu(a[mask] @ w1_expert.t()) * (a[mask] @ w3_expert.t())
+            inter = act(w1[i], mask)
             inter_gs = torch.tensor(1.0).cuda()
             inter_q, inter_blockscale = fp4_quantize(inter, inter_gs)
             inter = dequantize_nvfp4_to_dtype(
@@ -363,6 +377,11 @@ def test_moe_fp8(
     [(torch.float16, torch.float8_e4m3fn), (torch.bfloat16, torch.float8_e4m3fn)],
 )
 @pytest.mark.parametrize("quantized_input", [False, True])
+@pytest.mark.parametrize(
+    "activation_type",
+    [ActivationType.Swiglu, ActivationType.Relu2],
+    ids=["swiglu", "relu2"],
+)
 @pytest.mark.skipif(
     torch.cuda.get_device_capability()[0] not in [10, 11, 12],
     reason="NVFP4 is only supported on SM100, SM110 and SM120",
@@ -376,6 +395,7 @@ def test_moe_nvfp4(
     otype,
     wtype,
     quantized_input,
+    activation_type,
 ):
     # Skip invalid configurations
     if top_k > num_experts:
@@ -391,10 +411,10 @@
     n = intermediate_size
     k = hidden_size
 
-    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=otype) / 10
-    w1_cutlass = torch.cat((w1[:, n:, :], w1[:, :n, :]), dim=1).contiguous()
+    w1_n = 2 * n if activation_type == ActivationType.Swiglu else n
+    w1 = torch.randn((e, w1_n, k), device="cuda", dtype=otype) / 10
 
-    sf_w1_2n = round_up(2 * n, 128)
+    sf_w1_2n = round_up(w1_n, 128)
     sf_w1_k = round_up(k // quant_blocksize, 4)
     w1_blockscale = torch.empty(
         (e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn
@@ -409,8 +429,8 @@
     w2_blockscale = torch.empty(
         (e, sf_w2_k, sf_w2_n), device="cuda", dtype=torch.float8_e4m3fn
     )
-    w1_q = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8)
-    w1_q_cutlass = torch.empty((e, 2 * n, k // 2), device="cuda", dtype=torch.uint8)
+    w1_q = torch.empty((e, w1_n, k // 2), device="cuda", dtype=torch.uint8)
+    w1_q_cutlass = torch.empty((e, w1_n, k // 2), device="cuda", dtype=torch.uint8)
     w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8)
     w1_gs = torch.empty((e,), device="cuda", dtype=torch.float32)
     w2_gs = torch.empty((e,), device="cuda", dtype=torch.float32)
@@ -424,7 +444,7 @@
         w1_q[expert], w1_blockscale[expert] = fp4_quantize(w1[expert], w1_gs[expert])
 
         w1_q_cutlass[expert], w1_blockscale_cutlass[expert] = fp4_quantize(
-            w1_cutlass[expert], w1_gs[expert]
+            w1[expert], w1_gs[expert]
         )
 
         w2_q[expert], w2_blockscale[expert] = fp4_quantize(w2[expert], w2_gs[expert])
@@ -469,6 +489,7 @@
         quant_scales=quant_scales,
         input_sf=input_sf,
         output=flash_output,
+        activation_type=activation_type,
     )
 
     # Ref check
@@ -483,7 +504,7 @@
         block_size=quant_blocksize,
     )
 
-    w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=otype)
+    w1_d = torch.empty((e, w1_n, k), device="cuda", dtype=otype)
     w2_d = torch.empty((e, k, n), device="cuda", dtype=otype)
 
     for idx in range(0, e):
@@ -504,12 +525,14 @@
         block_size=quant_blocksize,
     )
 
-    w1_q_cutlass = torch.cat((w1_q[:, n:, :], w1_q[:, :n, :]), dim=1).contiguous()
-    w1_blockscale_cutlass = torch.cat(
-        (w1_blockscale[:, n:, :], w1_blockscale[:, :n, :]), dim=1
-    ).contiguous()
     ref_output = torch_moe_nvfp4(
-        a_in_dtype, w1_d, w2_d, top_k, routing_weights, selected_experts
+        a_in_dtype,
+        w1_d,
+        w2_d,
+        top_k,
+        routing_weights,
+        selected_experts,
+        activation_type,
     )
     torch.testing.assert_close(ref_output, flash_output, rtol=2e-1, atol=2e-1)
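One more editor's note (a sketch under assumptions, not part of the test): on the test side, every fc1-related allocation follows from `w1_n`. Assuming `quant_blocksize` is the NVFP4 block size of 16, the shapes reduce to the helper below:

```python
def round_up(x: int, m: int) -> int:
    return (x + m - 1) // m * m


def nvfp4_test_shapes(e: int, n: int, k: int, gated: bool, quant_blocksize: int = 16):
    # Mirrors the updated test: only the leading dim of w1 depends on the activation type.
    w1_n = 2 * n if gated else n
    return {
        "w1": (e, w1_n, k),         # high-precision reference weights
        "w1_q": (e, w1_n, k // 2),  # packed NVFP4, two 4-bit values per uint8
        "w1_blockscale": (e, round_up(w1_n, 128), round_up(k // quant_blocksize, 4)),
        "w2": (e, k, n),
        "w2_q": (e, k, n // 2),
    }


print(nvfp4_test_shapes(e=8, n=1024, k=2048, gated=False)["w1_blockscale"])  # (8, 1024, 128)
```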
