Merged
15 changes: 9 additions & 6 deletions csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh
@@ -4230,10 +4230,11 @@ std::map<std::string, std::pair<size_t, size_t>> GemmProfilerBackend::getProfile
       mWType == nvinfer1::DataType::kINT64;
   TLLM_CHECK_WITH_INFO(!is_4bit_act || is_4bit_weight,
                        "Cannot have 4-bit activation with non-4-bit weight");
-  float dtype_bytes =
-      is_4bit_act ? 0.5f
-                  : static_cast<float>(mWType == nvinfer1::DataType::kINT4 ? getDTypeSize(mOType)
-                                                                           : getDTypeSize(mDType));
+  float dtype_bytes = is_4bit_act ? 0.5f
+                                  : static_cast<float>((mWType == nvinfer1::DataType::kINT4 ||
+                                                        mWType == nvinfer1::DataType::kUINT8)
+                                                           ? getDTypeSize(mOType)
+                                                           : getDTypeSize(mDType));
   float weight_bytes = is_4bit_weight ? 0.5f : static_cast<float>(getDTypeSize(mWType));
   size_t output_bytes = getDTypeSize(mOType);
   size_t gemm_output_bytes =
@@ -4282,10 +4283,12 @@ std::map<std::string, std::pair<size_t, size_t>> GemmProfilerBackend::getProfile

   // TODO Make quant 2 & 4 bigger for FP8 if we ever change to scaling per expert
   bool is_int_w_quant =
-      (mWType == nvinfer1::DataType::kINT8 || mWType == nvinfer1::DataType::kINT4) &&
+      (mWType == nvinfer1::DataType::kINT8 || mWType == nvinfer1::DataType::kINT4 ||
+       mWType == nvinfer1::DataType::kUINT8) &&
       mGroupSize <= 0;
   bool is_int_groupwise_w_quant =
-      (mWType == nvinfer1::DataType::kINT8 || mWType == nvinfer1::DataType::kINT4) &&
+      (mWType == nvinfer1::DataType::kINT8 || mWType == nvinfer1::DataType::kINT4 ||
+       mWType == nvinfer1::DataType::kUINT8) &&
       mGroupSize > 0;
Comment on lines 4285 to 4292
Contributor
medium

To improve readability and avoid repeating the same condition, you could introduce a boolean variable for the integer quantization type check.

  bool is_int_quant_type = (mWType == nvinfer1::DataType::kINT8 || mWType == nvinfer1::DataType::kINT4 ||
                            mWType == nvinfer1::DataType::kUINT8);
  bool is_int_w_quant = is_int_quant_type && mGroupSize <= 0;
  bool is_int_groupwise_w_quant = is_int_quant_type && mGroupSize > 0;




Contributor Author
I'm fine with implementing this if you want but I didn't want to override the established style too much.

Collaborator
sounds reasonable

simplifying conditions and adding intermediate var names for readability is encouraged (tho we don't strictly impose the style). thx

bool is_fp8_act_quant = mDType == nvinfer1::DataType::kFP8;
bool is_fp8_w_quant = mWType == nvinfer1::DataType::kFP8;
174 changes: 174 additions & 0 deletions tests/moe/test_trtllm_cutlass_fused_moe.py
@@ -23,6 +23,7 @@

 import flashinfer.fused_moe as fused_moe
 from flashinfer import (
+    autotune,
     fp4_quantize,
     mxfp4_dequantize,
     mxfp4_quantize,
@@ -1632,5 +1633,178 @@ def interleave_weights(w: torch.Tensor, dim: int) -> torch.Tensor:
    torch.testing.assert_close(ref_output, flash_output, rtol=1e-2, atol=1e-1)


# Reduced parameter set vs test_moe_w4a8: autotuning is slow, so we test
# fewer combinations. The goal is to verify the profiler workspace allocation
# doesn't crash, not to sweep all shapes. See: github.com/flashinfer-ai/flashinfer/issues/2501
@pytest.mark.parametrize("batch_size", [1, 16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_experts", [8])
@pytest.mark.parametrize("top_k", [2])
@pytest.mark.parametrize("intermediate_size", [512])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_moe_w4a8_autotune(
    batch_size: int,
    hidden_size: int,
    num_experts: int,
    top_k: int,
    intermediate_size: int,
    dtype: torch.dtype,
):
    """Test MoE with W4A8 quantization and autotuning enabled (regression test for #2501).

    Identical to test_moe_w4a8 except:
    - Kernel call is wrapped in autotune(True) to exercise the profiling path
    - Smaller parameter set since autotuning is slower
    """
    if torch.cuda.get_device_capability()[0] != 9:
        pytest.skip("W4A8 is only supported on SM90")
    if top_k > num_experts:
@coderabbitai Bot, Feb 14, 2026
⚠️ Potential issue | 🟡 Minor

Use flashinfer.utils capability helpers for SM90 skips.
The direct torch.cuda.get_device_capability() check should use the project’s helper(s) for consistency and portability.

🔧 Proposed fix
-    if torch.cuda.get_device_capability()[0] != 9:
+    if get_compute_capability()[0] != 9:
         pytest.skip("W4A8 is only supported on SM90")
 from flashinfer import (
     autotune,
     fp4_quantize,
     mxfp4_dequantize,
     mxfp4_quantize,
     mxfp8_dequantize_host,
     mxfp8_quantize,
     mxfp4_dequantize_host,
 )
+from flashinfer.utils import get_compute_capability

As per coding guidelines, Test files must use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, etc.) to skip tests on unsupported GPU architectures.

🤖 Prompt for AI Agents
In `@tests/moe/test_trtllm_cutlass_fused_moe.py` around lines 1659 - 1661, Replace
the direct torch.cuda.get_device_capability() check with the project's
helper(s): import and call is_sm90a_supported() (or use get_compute_capability()
if you need major/minor) and use pytest.skip when unsupported; specifically
change the conditional that currently uses torch.cuda.get_device_capability()[0]
!= 9 to use is_sm90a_supported() (or check get_compute_capability()) and keep
the existing pytest.skip message, leaving the subsequent top_k > num_experts
logic unchanged.

Contributor Author
I copied the entire moe w4a8 test as-is and just wrapped it in the autotuner; if you'd like, the underlying code can be encapsulated in a function, up to you.

Contributor
@ipnon, thanks for the context! You're absolutely right that copying the test as-is makes sense for consistency — I can see the entire file uses torch.cuda.get_device_capability() directly throughout (lines 808, 990, 1223, 1290, 1549, etc.).

My suggestion was based on the coding guideline about using flashinfer.utils helpers like get_compute_capability(), but since this pattern exists file-wide, changing just the new test would create inconsistency. Fixing all occurrences would be out of scope for this bug fix PR.

Feel free to keep the current approach for consistency with the rest of the file. As for refactoring the shared logic into a function — that's a good idea but probably better suited for a separate cleanup PR rather than mixing it with this bug fix.


🧠 Learnings used
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2026-02-02T07:46:39.459Z
Learning: Applies to tests/**/*.py : Test files must use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, etc.) to skip tests on unsupported GPU architectures

Learnt from: yzh119
Repo: flashinfer-ai/flashinfer PR: 2370
File: tests/gdn/conftest.py:25-34
Timestamp: 2026-01-21T21:26:07.831Z
Learning: FlashInfer does not support CPU-only environments. Tests assume CUDA is available and do not require torch.cuda.is_available() guards in pytest fixtures.

        pytest.skip("top_k must be <= num_experts")

    torch.manual_seed(42)
    group_size = 128
    e = num_experts
    m = batch_size
    n = intermediate_size
    k = hidden_size
    affine_coeff = 0.005

    x = torch.randn(m, k, dtype=dtype, device="cuda")
    router_logits = torch.randn(m, e, dtype=dtype, device="cuda")
    w1_weight = torch.randint(0, 256, (e, n, k // 2), dtype=torch.uint8, device="cuda")
    w2_weight = torch.randint(0, 256, (e, k, n // 2), dtype=torch.uint8, device="cuda")
    w3_weight = torch.randint(0, 256, (e, n, k // 2), dtype=torch.uint8, device="cuda")

    w1_scale = (
        torch.randn(e, n, k // group_size, dtype=dtype, device="cuda") * affine_coeff
    )
    w2_scale = (
        torch.randn(e, k, n // group_size, dtype=dtype, device="cuda") * affine_coeff
    )
    w3_scale = (
        torch.randn(e, n, k // group_size, dtype=dtype, device="cuda") * affine_coeff
    )

    w1_pre_quant_scale = torch.rand(e, k, dtype=dtype, device="cuda") * 0.1 + 0.95
    w2_pre_quant_scale = torch.rand(e, n, dtype=dtype, device="cuda") * 0.1 + 0.95
    w3_pre_quant_scale = torch.rand(e, k, dtype=dtype, device="cuda") * 0.1 + 0.95

    input_scale = torch.rand(e, 1, dtype=torch.float32, device="cuda") * 0.2 + 0.1
    weight_scale_2 = torch.ones(e, 1, dtype=torch.float32, device="cuda")

    fc1_weights = torch.cat([w3_weight, w1_weight], dim=1)
    fc2_weights = w2_weight

    def interleave_weights(w: torch.Tensor, dim: int) -> torch.Tensor:
        interleave_factor = 4 if dim % 512 == 0 else (2 if dim % 256 == 0 else 1)
        s = w.shape
        w_interleaved = (
            w.reshape(s[0], s[1], s[2] // interleave_factor, interleave_factor)
            .permute(0, 2, 1, 3)
            .reshape(s[0], s[2] // interleave_factor, s[1] * interleave_factor)
            .contiguous()
        )
        return w_interleaved

    w3_w1_scales = torch.cat([w3_scale, w1_scale], dim=1)
    w3_w1_scales_int = interleave_weights(w3_w1_scales, k)
    w2_scales_int = interleave_weights(w2_scale, n)

    w3_w1_pre_quant_max = torch.max(w1_pre_quant_scale, w3_pre_quant_scale)
    w3_w1_input_scale_max = input_scale.max()
    fc31_act_scale = (w3_w1_pre_quant_max / w3_w1_input_scale_max).to(dtype)
    fc2_act_scale = (w2_pre_quant_scale / input_scale).to(dtype).unsqueeze(-1)

    fc31_alpha = (weight_scale_2.squeeze(-1) * w3_w1_input_scale_max).float()
    fc2_alpha = (weight_scale_2.squeeze(-1) * input_scale.squeeze(-1)).float()

    zero_1 = torch.empty(0, dtype=dtype, device="cuda")
    zero_2 = torch.empty(0, dtype=dtype, device="cuda")

    sm = (
        torch.cuda.get_device_capability()[0] * 10
        + torch.cuda.get_device_capability()[1]
    )
    if sm >= 90:
        w3_w1_scales_out = w3_w1_scales_int.to(torch.bfloat16).view(dtype)
        w2_scales_out = w2_scales_int.to(torch.bfloat16).view(dtype)
        fc31_act_out = fc31_act_scale.to(torch.bfloat16).view(dtype)
        fc2_act_out = fc2_act_scale.to(torch.bfloat16).view(dtype)
    else:
        w3_w1_scales_out = w3_w1_scales_int.to(dtype)
        w2_scales_out = w2_scales_int.to(dtype)
        fc31_act_out = fc31_act_scale
        fc2_act_out = fc2_act_scale

    quant_scales = (
        w3_w1_scales_out,
        w2_scales_out,
        fc31_act_out,
        fc2_act_out,
        zero_1,
        zero_2,
        fc31_alpha,
        fc2_alpha,
    )

    routing_weights, selected_experts = compute_routing(router_logits, top_k)
    selected_experts_int32 = selected_experts.to(torch.int32)

    flash_output = torch.zeros_like(x)
    with autotune(True):
        _ = fused_moe.cutlass_fused_moe(
            x,
            selected_experts_int32,
            routing_weights,
            fc1_weights.view(torch.uint8),
            fc2_weights.view(torch.uint8),
            dtype,
            quant_scales=quant_scales,
            use_w4_group_scaling=True,
            output=flash_output,
            use_packed_weights=True,
        )

    w31_weight_list = []
    w2_weight_list = []

    for e_idx in range(num_experts):
        w1_w = w1_weight[e_idx]
        w3_w = w3_weight[e_idx]
        w2_w = w2_weight[e_idx]
        w1_s = w1_scale[e_idx]
        w3_s = w3_scale[e_idx]
        w2_s = w2_scale[e_idx]
        ws2 = weight_scale_2[e_idx]

        w1_dequant = dequantize_int4_to_dtype(w1_w, w1_s, group_size, dtype, ws2)
        w3_dequant = dequantize_int4_to_dtype(w3_w, w3_s, group_size, dtype, ws2)
        w2_dequant = dequantize_int4_to_dtype(w2_w, w2_s, group_size, dtype, ws2)

        w31 = torch.cat([w3_dequant, w1_dequant], dim=0)

        w31_weight_list.append(w31)
        w2_weight_list.append(w2_dequant)

    w31_weight_dequant = torch.stack(w31_weight_list, dim=0)
    w2_weight_dequant = torch.stack(w2_weight_list, dim=0)

    ref_output = torch_moe_w4a8(
        num_experts,
        x,
        w31_weight_dequant,
        w2_weight_dequant,
        selected_experts,
        routing_weights,
        fc1_input_scale=input_scale.squeeze(-1),
        fc2_input_scale=input_scale.squeeze(-1),
        fc1_pre_quant_scale=torch.max(w1_pre_quant_scale, w3_pre_quant_scale),
        fc2_pre_quant_scale=w2_pre_quant_scale,
        fc1_weight_scale_2=weight_scale_2.squeeze(-1),
        fc2_weight_scale_2=weight_scale_2.squeeze(-1),
    )
    torch.testing.assert_close(ref_output, flash_output, rtol=1e-2, atol=1e-1)
Contributor
medium

This new test function test_moe_w4a8_autotune has a significant amount of duplicated code from the existing test_moe_w4a8 test. To improve maintainability and reduce redundancy, consider the following refactorings:

  1. Extract interleave_weights: The interleave_weights helper function is defined inside both test functions. It could be extracted to the module level to be shared between them (a sketch of the hoisted helper follows this list).
  2. Create a common test helper: The main logic of both tests is identical, except for the autotune context manager and the parameter sets. You could create a helper function that takes an autotune_enabled flag and contains the common test logic. Then, test_moe_w4a8 and test_moe_w4a8_autotune can simply call this helper with the appropriate flag and their respective pytest parameterizations.
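A minimal sketch of suggestion (1), assuming only that the helper is hoisted unchanged to module level; the body is copied from the test above, and torch comes from the module's existing imports:

import torch


def interleave_weights(w: torch.Tensor, dim: int) -> torch.Tensor:
    # Same body as the helper currently defined inside each test: pick an
    # interleave factor from the contraction dim, then fold groups of that
    # size into the last axis.
    interleave_factor = 4 if dim % 512 == 0 else (2 if dim % 256 == 0 else 1)
    s = w.shape
    w_interleaved = (
        w.reshape(s[0], s[1], s[2] // interleave_factor, interleave_factor)
        .permute(0, 2, 1, 3)
        .reshape(s[0], s[2] // interleave_factor, s[1] * interleave_factor)
        .contiguous()
    )
    return w_interleaved

Both tests could then drop their local copies and call the shared helper directly.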

Contributor Author
Again, just let me know your preference and I'll update.

Contributor
Could you adopt suggestion (2)? In the original test_moe_w4a8 test, we can add an additional autotune parameter, and when autotune=True we can wrap the call with the autotune context.
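A minimal sketch of that shape, not part of this PR: it assumes a parametrized use_autotune flag on test_moe_w4a8 and uses contextlib.nullcontext() for the non-autotuned case, while the setup, kernel-call arguments, and reference check stay exactly as shown above.

import contextlib

import pytest
import torch

import flashinfer.fused_moe as fused_moe
from flashinfer import autotune


@pytest.mark.parametrize("use_autotune", [False, True])
def test_moe_w4a8(use_autotune: bool):
    # ... existing batch_size/hidden_size/num_experts/top_k/intermediate_size/
    # dtype parametrization plus the weight, scale, quant_scales and routing
    # setup stay exactly as in the current test ...

    # Only use_autotune=True enters the profiling path; nullcontext keeps the
    # default case identical to today's behavior.
    ctx = autotune(True) if use_autotune else contextlib.nullcontext()
    with ctx:
        _ = fused_moe.cutlass_fused_moe(
            x,
            selected_experts_int32,
            routing_weights,
            fc1_weights.view(torch.uint8),
            fc2_weights.view(torch.uint8),
            dtype,
            quant_scales=quant_scales,
            use_w4_group_scaling=True,
            output=flash_output,
            use_packed_weights=True,
        )

    # ... reference dequantization and torch.testing.assert_close unchanged ...

The use_autotune=True case could also be restricted to the reduced shape set used by test_moe_w4a8_autotune to keep the autotuning runtime in check.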



if __name__ == "__main__":
    pytest.main([__file__, "-v"])