
fix: W4A8 autotune crash in cutlass_fused_moe profiler workspace#2564

Merged
jimmyzho merged 2 commits into flashinfer-ai:main from ipnon:fix/w4a8-autotune-profiler-workspace-2501
Feb 18, 2026
Conversation


@ipnon ipnon commented Feb 14, 2026

Description

Problem

cutlass_fused_moe crashes with RuntimeError: Assertion failed: quant_1 && quant_2 when called inside autotune() with W4A8 quantization. Other quant modes (FP8 block scaling, BF16+MXFP4) are unaffected.

Root Cause

getProfilerWorkspaces() in cutlass_fused_moe_kernels.cuh does not recognize kUINT8 as an integer-quantized weight type. W4A8 uses kUINT8 (packed INT4 pairs), but the three type-check sites only match kINT8 and kINT4:

  1. is_int_w_quant — controls per-tensor scale buffer allocation
  2. is_int_groupwise_w_quant — controls groupwise scale buffer allocation
  3. dtype_bytes ternary — controls scale factor byte width

With kUINT8 unrecognized, all three evaluate false, so the profiler allocates zero-size quant buffers. When prepareQuantParams() later tries to use them, it asserts on the resulting nullptrs.
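The mismatch is easy to reproduce in miniature. Below is a hedged Python sketch (hypothetical names; the real logic is C++ in getProfilerWorkspaces() and prepareQuantParams()) of an allocator whose type set is missing one member of the type family, while the consumer asserts the buffer exists:

```python
# Minimal sketch of the allocator-consumer mismatch. Names are stand-ins
# for the C++ booleans in getProfilerWorkspaces(), not flashinfer's API.
INT_WEIGHT_TYPES = {"int8", "int4"}  # bug: "uint8" (packed INT4 pairs) missing

def alloc_quant_buffer(w_type: str, num_experts: int) -> list:
    # Allocator side: unrecognized weight type -> zero-size buffer.
    size = num_experts if w_type in INT_WEIGHT_TYPES else 0
    return [0.0] * size

def prepare_quant_params(buf: list) -> list:
    # Consumer side: asserts the buffer is non-empty, like quant_1 && quant_2.
    assert buf, "Assertion failed: quant_1 && quant_2"
    return buf

buf = alloc_quant_buffer("uint8", num_experts=8)  # W4A8 weights arrive as uint8
# prepare_quant_params(buf) would raise AssertionError here, mirroring the crash
```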

The consumer side (prepareQuantParams) already handles kUINT8 correctly — only the allocator side was missing it. Notably, is_wfp4a16_quant in the same function does check for kUINT8, but only when activations are HALF/BF16 — not FP8. So the workspace code was partially aware of kUINT8; the gap was specifically in the FP8 activation pairing that W4A8 uses.

Fix

Add kUINT8 to all three sites in getProfilerWorkspaces(), matching the existing kINT4 behavior (since kUINT8 is the packed representation of INT4 weights).

This cascades correctly through the existing logic:

  • is_int_groupwise_w_quant becomes true (for groupwise path)
  • is_w4afp8_quant = is_int_groupwise_w_quant && is_fp8_act_quant becomes true
  • w4a8_alpha_size = num_experts_per_node * sizeof(float) becomes non-zero
  • ADD(w4a8_alpha) registers the buffer in the workspace map
  • GET_WS_PTR(float const*, w4a8_alpha) in prepareQuantParams() returns a valid pointer instead of nullptr
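The cascade above can be sketched as follows, using Python stand-ins for the C++ booleans with illustrative sizes (this is not the real kernel code):

```python
# Illustrative recomputation of the cascade once "uint8" is recognized.
# Variable names mirror the C++ code; values and types are stand-ins.
def w4a8_alpha_bytes(w_type: str, act_is_fp8: bool,
                     group_size: int, num_experts_per_node: int) -> int:
    is_int_quant_type = w_type in {"int8", "int4", "uint8"}  # fix adds "uint8"
    is_int_groupwise_w_quant = is_int_quant_type and group_size > 0
    is_w4afp8_quant = is_int_groupwise_w_quant and act_is_fp8
    sizeof_float = 4
    # Non-zero only when the W4A8 pairing is recognized.
    return num_experts_per_node * sizeof_float if is_w4afp8_quant else 0

# W4A8: uint8 packed weights, FP8 activations, groupwise scales
size = w4a8_alpha_bytes("uint8", act_is_fp8=True, group_size=128,
                        num_experts_per_node=8)
```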

Testing

Added test_moe_w4a8_autotune regression test — identical to test_moe_w4a8 but wrapped in autotune(True) with a reduced parameter set (autotuning is slow due to JIT compilation of 163 SM90 CUTLASS kernels).

Tested on H100 (SM90): 2 passed in 25:14.

Related Issues

Fixes #2501

Pull Request Checklist

Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Root cause class: allocator-consumer mismatch from missing case in non-exhaustive type dispatch. prepareQuantParams() was updated for W4A8 support but getProfilerWorkspaces() was not — the two sites use independent boolean checks rather than a shared dispatch table, so adding a new weight type requires updating both manually.
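A minimal sketch of the shared-dispatch alternative hinted at here, with hypothetical names (the merged fix keeps the existing independent C++ checks rather than introducing such a table):

```python
# Sketch: one predicate consulted by BOTH allocator and consumer, so adding
# a new weight type requires a single edit and the two sides cannot diverge.
INT_QUANT_WEIGHT_TYPES = frozenset({"int8", "int4", "uint8"})

def is_int_quant_weight(w_type: str) -> bool:
    return w_type in INT_QUANT_WEIGHT_TYPES

def allocate(w_type: str, n: int) -> list:
    # Allocator side uses the shared predicate.
    return [0.0] * (n if is_int_quant_weight(w_type) else 0)

def consume(w_type: str, buf: list) -> list:
    # Consumer side uses the same predicate, so if it expects a buffer,
    # the allocator necessarily produced one.
    if is_int_quant_weight(w_type):
        assert buf
    return buf
```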

Summary by CodeRabbit

  • New Features

    • Added support for UINT8 weight quantization in CUTLASS MoE backend for improved model compatibility
  • Tests

    • Enhanced MoE test suite with optional autotuning execution path for better performance validation

…shinfer-ai#2501)

getProfilerWorkspaces() did not recognize kUINT8 as an integer quant
weight type, so it allocated zero-size buffers for quant params.
prepareQuantParams() then asserted on the resulting nullptrs.

Three changes:
- Add kUINT8 to is_int_w_quant and is_int_groupwise_w_quant booleans
- Add kUINT8 to dtype_bytes ternary so scale factor buffers use output
  type size (matching kINT4 behavior, since kUINT8 is packed INT4)
- Add test_moe_w4a8_autotune regression test

Co-Authored-By: Claude Opus 4.6 <[email protected]>

coderabbitai bot commented Feb 14, 2026

Walkthrough

Widened the fused MoE profiler logic to treat UINT8 weights like other integer weight types (INT4/INT8) for dtype size selection and quantization flags; added a test path to run the W4A8 fused_moe autotune scenario.

Changes

  • Profiler UINT8 handling (csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh): Include kUINT8 in dtype byte-size selection and in the integer-quantization checks (is_int_w_quant, is_int_groupwise_w_quant), expanding how UINT8 weights are classified during profiling.
  • W4A8 autotune test (tests/moe/test_trtllm_cutlass_fused_moe.py): Add a use_autotune parameter and run the W4A8 test under an autotune(True) context when enabled, exercising the autotuned profiling path for W4A8.

Sequence Diagram(s)

(omitted — changes are localized algorithmic adjustments and a test addition; no multi-component flow diagram required)

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Suggested labels

run-ci

Suggested reviewers

  • yzh119
  • djmmoss
  • cyx-6
  • yongwww
  • bkryu

Pre-merge checks (5 passed)

  • Title check: Passed. The title clearly and concisely summarizes the main change: fixing a W4A8 autotune crash in the profiler workspace.
  • Description check: Passed. The PR description covers the problem statement, root cause analysis, fix details, and testing; all required template sections are addressed.
  • Linked Issues check: Passed. The changes address issue #2501 by adding kUINT8 recognition to three sites in getProfilerWorkspaces(), enabling proper W4A8 quant buffer allocation during autotune; a regression test verifies the fix.
  • Out of Scope Changes check: Passed. All changes are in scope: the cutlass_fused_moe_kernels.cuh change handles the W4A8 allocator-side type dispatch gap, and the test change adds W4A8 autotune coverage.
  • Docstring Coverage: Passed. Docstring coverage is 100.00%, above the 80.00% threshold.


@gemini-code-assist

Summary of Changes

Hello @ipnon, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request addresses a critical crash occurring in the cutlass_fused_moe profiler when attempting to autotune with W4A8 quantization. The issue stemmed from the profiler's workspace allocation logic failing to correctly identify kUINT8 as an integer-quantized weight type, leading to improper buffer sizing and subsequent runtime errors. The solution involves extending the type recognition in the getProfilerWorkspaces() function to include kUINT8, thereby ensuring correct buffer allocation. A new regression test has been added to validate the fix and prevent future regressions.

Highlights

  • Fixes W4A8 Autotune Crash: Resolved a RuntimeError in cutlass_fused_moe when using W4A8 quantization within autotune() due to incorrect buffer allocation.
  • Corrected Quantization Buffer Allocation: Updated getProfilerWorkspaces() to properly recognize kUINT8 as a valid integer-quantized weight type, preventing the allocation of zero-sized quant buffers.
  • Introduced Regression Test: Added test_moe_w4a8_autotune to specifically verify the fix for W4A8 autotuning and prevent future regressions, using a reduced parameter set for efficiency.


Changelog
  • csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh
    • Modified dtype_bytes calculation to include nvinfer1::DataType::kUINT8 when determining scale factor byte width.
    • Updated is_int_w_quant and is_int_groupwise_w_quant boolean checks to recognize nvinfer1::DataType::kUINT8 for integer weight quantization.
  • tests/moe/test_trtllm_cutlass_fused_moe.py
    • Imported autotune from flashinfer.
    • Added test_moe_w4a8_autotune to specifically test W4A8 quantization with autotuning enabled, using a reduced parameter set for efficiency.
Activity
  • The author identified the problem, root cause, and implemented a fix.
  • Pre-commit checks were installed and run, with reported issues fixed.
  • New tests were added, and all existing tests are passing.
  • The author provided detailed reviewer notes explaining the root cause as an allocator-consumer mismatch from a missing case in non-exhaustive type dispatch.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request addresses a crash in the cutlass_fused_moe profiler for W4A8 quantization by correctly handling the kUINT8 data type during workspace allocation. The fix is sound and correctly identifies and resolves the issue. The addition of the test_moe_w4a8_autotune regression test is excellent for ensuring this bug does not reappear. My review includes suggestions to refactor duplicated code in both the C++ implementation and the new Python test to improve maintainability.

Comment on lines 4285 to 4292
```diff
   bool is_int_w_quant =
-      (mWType == nvinfer1::DataType::kINT8 || mWType == nvinfer1::DataType::kINT4) &&
+      (mWType == nvinfer1::DataType::kINT8 || mWType == nvinfer1::DataType::kINT4 ||
+       mWType == nvinfer1::DataType::kUINT8) &&
       mGroupSize <= 0;
   bool is_int_groupwise_w_quant =
-      (mWType == nvinfer1::DataType::kINT8 || mWType == nvinfer1::DataType::kINT4) &&
+      (mWType == nvinfer1::DataType::kINT8 || mWType == nvinfer1::DataType::kINT4 ||
+       mWType == nvinfer1::DataType::kUINT8) &&
       mGroupSize > 0;
```
Contributor

medium

To improve readability and avoid repeating the same condition, you could introduce a boolean variable for the integer quantization type check.

```cpp
bool is_int_quant_type = (mWType == nvinfer1::DataType::kINT8 || mWType == nvinfer1::DataType::kINT4 ||
                          mWType == nvinfer1::DataType::kUINT8);
bool is_int_w_quant = is_int_quant_type && mGroupSize <= 0;
bool is_int_groupwise_w_quant = is_int_quant_type && mGroupSize > 0;
```




Contributor Author

I'm fine with implementing this if you want but I didn't want to override the established style too much.

Collaborator

sounds reasonable

simplifying conditions and adding intermediate var names for readability is encouraged (though we don't strictly impose the style). Thanks!

Comment on lines 1645 to 1806
```python
def test_moe_w4a8_autotune(
    batch_size: int,
    hidden_size: int,
    num_experts: int,
    top_k: int,
    intermediate_size: int,
    dtype: torch.dtype,
):
    """Test MoE with W4A8 quantization and autotuning enabled (regression test for #2501).

    Identical to test_moe_w4a8 except:
    - Kernel call is wrapped in autotune(True) to exercise the profiling path
    - Smaller parameter set since autotuning is slower
    """
    if torch.cuda.get_device_capability()[0] != 9:
        pytest.skip("W4A8 is only supported on SM90")
    if top_k > num_experts:
        pytest.skip("top_k must be <= num_experts")

    torch.manual_seed(42)
    group_size = 128
    e = num_experts
    m = batch_size
    n = intermediate_size
    k = hidden_size
    affine_coeff = 0.005

    x = torch.randn(m, k, dtype=dtype, device="cuda")
    router_logits = torch.randn(m, e, dtype=dtype, device="cuda")
    w1_weight = torch.randint(0, 256, (e, n, k // 2), dtype=torch.uint8, device="cuda")
    w2_weight = torch.randint(0, 256, (e, k, n // 2), dtype=torch.uint8, device="cuda")
    w3_weight = torch.randint(0, 256, (e, n, k // 2), dtype=torch.uint8, device="cuda")

    w1_scale = (
        torch.randn(e, n, k // group_size, dtype=dtype, device="cuda") * affine_coeff
    )
    w2_scale = (
        torch.randn(e, k, n // group_size, dtype=dtype, device="cuda") * affine_coeff
    )
    w3_scale = (
        torch.randn(e, n, k // group_size, dtype=dtype, device="cuda") * affine_coeff
    )

    w1_pre_quant_scale = torch.rand(e, k, dtype=dtype, device="cuda") * 0.1 + 0.95
    w2_pre_quant_scale = torch.rand(e, n, dtype=dtype, device="cuda") * 0.1 + 0.95
    w3_pre_quant_scale = torch.rand(e, k, dtype=dtype, device="cuda") * 0.1 + 0.95

    input_scale = torch.rand(e, 1, dtype=torch.float32, device="cuda") * 0.2 + 0.1
    weight_scale_2 = torch.ones(e, 1, dtype=torch.float32, device="cuda")

    fc1_weights = torch.cat([w3_weight, w1_weight], dim=1)
    fc2_weights = w2_weight

    def interleave_weights(w: torch.Tensor, dim: int) -> torch.Tensor:
        interleave_factor = 4 if dim % 512 == 0 else (2 if dim % 256 == 0 else 1)
        s = w.shape
        w_interleaved = (
            w.reshape(s[0], s[1], s[2] // interleave_factor, interleave_factor)
            .permute(0, 2, 1, 3)
            .reshape(s[0], s[2] // interleave_factor, s[1] * interleave_factor)
            .contiguous()
        )
        return w_interleaved

    w3_w1_scales = torch.cat([w3_scale, w1_scale], dim=1)
    w3_w1_scales_int = interleave_weights(w3_w1_scales, k)
    w2_scales_int = interleave_weights(w2_scale, n)

    w3_w1_pre_quant_max = torch.max(w1_pre_quant_scale, w3_pre_quant_scale)
    w3_w1_input_scale_max = input_scale.max()
    fc31_act_scale = (w3_w1_pre_quant_max / w3_w1_input_scale_max).to(dtype)
    fc2_act_scale = (w2_pre_quant_scale / input_scale).to(dtype).unsqueeze(-1)

    fc31_alpha = (weight_scale_2.squeeze(-1) * w3_w1_input_scale_max).float()
    fc2_alpha = (weight_scale_2.squeeze(-1) * input_scale.squeeze(-1)).float()

    zero_1 = torch.empty(0, dtype=dtype, device="cuda")
    zero_2 = torch.empty(0, dtype=dtype, device="cuda")

    sm = (
        torch.cuda.get_device_capability()[0] * 10
        + torch.cuda.get_device_capability()[1]
    )
    if sm >= 90:
        w3_w1_scales_out = w3_w1_scales_int.to(torch.bfloat16).view(dtype)
        w2_scales_out = w2_scales_int.to(torch.bfloat16).view(dtype)
        fc31_act_out = fc31_act_scale.to(torch.bfloat16).view(dtype)
        fc2_act_out = fc2_act_scale.to(torch.bfloat16).view(dtype)
    else:
        w3_w1_scales_out = w3_w1_scales_int.to(dtype)
        w2_scales_out = w2_scales_int.to(dtype)
        fc31_act_out = fc31_act_scale
        fc2_act_out = fc2_act_scale

    quant_scales = (
        w3_w1_scales_out,
        w2_scales_out,
        fc31_act_out,
        fc2_act_out,
        zero_1,
        zero_2,
        fc31_alpha,
        fc2_alpha,
    )

    routing_weights, selected_experts = compute_routing(router_logits, top_k)
    selected_experts_int32 = selected_experts.to(torch.int32)

    flash_output = torch.zeros_like(x)
    with autotune(True):
        _ = fused_moe.cutlass_fused_moe(
            x,
            selected_experts_int32,
            routing_weights,
            fc1_weights.view(torch.uint8),
            fc2_weights.view(torch.uint8),
            dtype,
            quant_scales=quant_scales,
            use_w4_group_scaling=True,
            output=flash_output,
            use_packed_weights=True,
        )

    w31_weight_list = []
    w2_weight_list = []

    for e_idx in range(num_experts):
        w1_w = w1_weight[e_idx]
        w3_w = w3_weight[e_idx]
        w2_w = w2_weight[e_idx]
        w1_s = w1_scale[e_idx]
        w3_s = w3_scale[e_idx]
        w2_s = w2_scale[e_idx]
        ws2 = weight_scale_2[e_idx]

        w1_dequant = dequantize_int4_to_dtype(w1_w, w1_s, group_size, dtype, ws2)
        w3_dequant = dequantize_int4_to_dtype(w3_w, w3_s, group_size, dtype, ws2)
        w2_dequant = dequantize_int4_to_dtype(w2_w, w2_s, group_size, dtype, ws2)

        w31 = torch.cat([w3_dequant, w1_dequant], dim=0)

        w31_weight_list.append(w31)
        w2_weight_list.append(w2_dequant)

    w31_weight_dequant = torch.stack(w31_weight_list, dim=0)
    w2_weight_dequant = torch.stack(w2_weight_list, dim=0)

    ref_output = torch_moe_w4a8(
        num_experts,
        x,
        w31_weight_dequant,
        w2_weight_dequant,
        selected_experts,
        routing_weights,
        fc1_input_scale=input_scale.squeeze(-1),
        fc2_input_scale=input_scale.squeeze(-1),
        fc1_pre_quant_scale=torch.max(w1_pre_quant_scale, w3_pre_quant_scale),
        fc2_pre_quant_scale=w2_pre_quant_scale,
        fc1_weight_scale_2=weight_scale_2.squeeze(-1),
        fc2_weight_scale_2=weight_scale_2.squeeze(-1),
    )
    torch.testing.assert_close(ref_output, flash_output, rtol=1e-2, atol=1e-1)
```
Contributor

medium

This new test function test_moe_w4a8_autotune has a significant amount of duplicated code from the existing test_moe_w4a8 test. To improve maintainability and reduce redundancy, consider the following refactorings:

  1. Extract interleave_weights: The interleave_weights helper function is defined inside both test functions. It could be extracted to the module level to be shared between them.
  2. Create a common test helper: The main logic of both tests is identical, except for the autotune context manager and the parameter sets. You could create a helper function that takes an autotune_enabled flag and contains the common test logic. Then, test_moe_w4a8 and test_moe_w4a8_autotune can simply call this helper with the appropriate flag and their respective pytest parameterizations.
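The helper pattern in (2.) can be sketched as follows, using a no-op stand-in for flashinfer's autotune context manager (all names besides autotune are hypothetical placeholders, not the project's API):

```python
import contextlib


@contextlib.contextmanager
def autotune(enabled: bool):
    # No-op stand-in for flashinfer's autotune context manager.
    yield


def run_moe_case(x: list, use_autotune: bool) -> list:
    # Common test body; the only difference between variants is the context.
    # In the real test this would be driven by
    # @pytest.mark.parametrize("use_autotune", [False, True]).
    ctx = autotune(True) if use_autotune else contextlib.nullcontext()
    with ctx:
        return [2 * v for v in x]  # placeholder for the cutlass_fused_moe call
```

Both parameterizations then share one body, and the autotuned variant only differs by the wrapping context.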

Contributor Author
Again, just let me know your preference and I'll update.

Contributor
Could you adopt the suggestion in (2.)? In the original test_moe_w4a8 test, we can add an additional parameter autotune and when autotune=True we can wrap the call with the autotune context.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@tests/moe/test_trtllm_cutlass_fused_moe.py`:
- Around line 1659-1661: Replace the direct torch.cuda.get_device_capability()
check with the project's helper(s): import and call is_sm90a_supported() (or use
get_compute_capability() if you need major/minor) and use pytest.skip when
unsupported; specifically change the conditional that currently uses
torch.cuda.get_device_capability()[0] != 9 to use is_sm90a_supported() (or check
get_compute_capability()) and keep the existing pytest.skip message, leaving the
subsequent top_k > num_experts logic unchanged.

Comment on lines 1659 to 1661
```python
    if torch.cuda.get_device_capability()[0] != 9:
        pytest.skip("W4A8 is only supported on SM90")
    if top_k > num_experts:
```
@coderabbitai coderabbitai bot Feb 14, 2026

⚠️ Potential issue | 🟡 Minor

Use flashinfer.utils capability helpers for SM90 skips.
The direct torch.cuda.get_device_capability() check should use the project’s helper(s) for consistency and portability.

🔧 Proposed fix

```diff
-    if torch.cuda.get_device_capability()[0] != 9:
+    if get_compute_capability()[0] != 9:
         pytest.skip("W4A8 is only supported on SM90")
```

```diff
 from flashinfer import (
     autotune,
     fp4_quantize,
     mxfp4_dequantize,
     mxfp4_quantize,
     mxfp8_dequantize_host,
     mxfp8_quantize,
     mxfp4_dequantize_host,
 )
+from flashinfer.utils import get_compute_capability
```

As per coding guidelines, Test files must use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, etc.) to skip tests on unsupported GPU architectures.


Contributor Author

I copied the entire moe w4a8 test as-is and just wrapped it in the autotuner, if you'd like the underlying code can be encapsulated in a function, up to you.

Contributor

@ipnon, thanks for the context! You're absolutely right that copying the test as-is makes sense for consistency — I can see the entire file uses torch.cuda.get_device_capability() directly throughout (lines 808, 990, 1223, 1290, 1549, etc.).

My suggestion was based on the coding guideline about using flashinfer.utils helpers like get_compute_capability(), but since this pattern exists file-wide, changing just the new test would create inconsistency. Fixing all occurrences would be out of scope for this bug fix PR.

Feel free to keep the current approach for consistency with the rest of the file. As for refactoring the shared logic into a function — that's a good idea but probably better suited for a separate cleanup PR rather than mixing it with this bug fix.


🧠 Learnings used
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2026-02-02T07:46:39.459Z
Learning: Applies to tests/**/*.py : Test files must use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, etc.) to skip tests on unsupported GPU architectures

Learnt from: yzh119
Repo: flashinfer-ai/flashinfer PR: 2370
File: tests/gdn/conftest.py:25-34
Timestamp: 2026-01-21T21:26:07.831Z
Learning: FlashInfer does not support CPU-only environments. Tests assume CUDA is available and do not require torch.cuda.is_available() guards in pytest fixtures.


ipnon commented Feb 14, 2026

This is my first PR here. I attempted to thoroughly review the problem and steelman any potential review issues; let me know what you think, thanks.

@aleozlx aleozlx self-assigned this Feb 14, 2026
Contributor

@jimmyzho jimmyzho left a comment

Looking good, just a small nit in the test. Thanks for the fix!

Comment on lines 1645 to 1806
def test_moe_w4a8_autotune(
batch_size: int,
hidden_size: int,
num_experts: int,
top_k: int,
intermediate_size: int,
dtype: torch.dtype,
):
"""Test MoE with W4A8 quantization and autotuning enabled (regression test for #2501).

Identical to test_moe_w4a8 except:
- Kernel call is wrapped in autotune(True) to exercise the profiling path
- Smaller parameter set since autotuning is slower
"""
if torch.cuda.get_device_capability()[0] != 9:
pytest.skip("W4A8 is only supported on SM90")
if top_k > num_experts:
pytest.skip("top_k must be <= num_experts")

torch.manual_seed(42)
group_size = 128
e = num_experts
m = batch_size
n = intermediate_size
k = hidden_size
affine_coeff = 0.005

x = torch.randn(m, k, dtype=dtype, device="cuda")
router_logits = torch.randn(m, e, dtype=dtype, device="cuda")
w1_weight = torch.randint(0, 256, (e, n, k // 2), dtype=torch.uint8, device="cuda")
w2_weight = torch.randint(0, 256, (e, k, n // 2), dtype=torch.uint8, device="cuda")
w3_weight = torch.randint(0, 256, (e, n, k // 2), dtype=torch.uint8, device="cuda")

w1_scale = (
torch.randn(e, n, k // group_size, dtype=dtype, device="cuda") * affine_coeff
)
w2_scale = (
torch.randn(e, k, n // group_size, dtype=dtype, device="cuda") * affine_coeff
)
w3_scale = (
torch.randn(e, n, k // group_size, dtype=dtype, device="cuda") * affine_coeff
)

w1_pre_quant_scale = torch.rand(e, k, dtype=dtype, device="cuda") * 0.1 + 0.95
w2_pre_quant_scale = torch.rand(e, n, dtype=dtype, device="cuda") * 0.1 + 0.95
w3_pre_quant_scale = torch.rand(e, k, dtype=dtype, device="cuda") * 0.1 + 0.95

input_scale = torch.rand(e, 1, dtype=torch.float32, device="cuda") * 0.2 + 0.1
weight_scale_2 = torch.ones(e, 1, dtype=torch.float32, device="cuda")

fc1_weights = torch.cat([w3_weight, w1_weight], dim=1)
fc2_weights = w2_weight

def interleave_weights(w: torch.Tensor, dim: int) -> torch.Tensor:
interleave_factor = 4 if dim % 512 == 0 else (2 if dim % 256 == 0 else 1)
s = w.shape
w_interleaved = (
w.reshape(s[0], s[1], s[2] // interleave_factor, interleave_factor)
.permute(0, 2, 1, 3)
.reshape(s[0], s[2] // interleave_factor, s[1] * interleave_factor)
.contiguous()
)
return w_interleaved

w3_w1_scales = torch.cat([w3_scale, w1_scale], dim=1)
w3_w1_scales_int = interleave_weights(w3_w1_scales, k)
w2_scales_int = interleave_weights(w2_scale, n)

w3_w1_pre_quant_max = torch.max(w1_pre_quant_scale, w3_pre_quant_scale)
w3_w1_input_scale_max = input_scale.max()
fc31_act_scale = (w3_w1_pre_quant_max / w3_w1_input_scale_max).to(dtype)
fc2_act_scale = (w2_pre_quant_scale / input_scale).to(dtype).unsqueeze(-1)

fc31_alpha = (weight_scale_2.squeeze(-1) * w3_w1_input_scale_max).float()
fc2_alpha = (weight_scale_2.squeeze(-1) * input_scale.squeeze(-1)).float()

zero_1 = torch.empty(0, dtype=dtype, device="cuda")
zero_2 = torch.empty(0, dtype=dtype, device="cuda")

sm = (
torch.cuda.get_device_capability()[0] * 10
+ torch.cuda.get_device_capability()[1]
)
if sm >= 90:
w3_w1_scales_out = w3_w1_scales_int.to(torch.bfloat16).view(dtype)
w2_scales_out = w2_scales_int.to(torch.bfloat16).view(dtype)
fc31_act_out = fc31_act_scale.to(torch.bfloat16).view(dtype)
fc2_act_out = fc2_act_scale.to(torch.bfloat16).view(dtype)
else:
w3_w1_scales_out = w3_w1_scales_int.to(dtype)
w2_scales_out = w2_scales_int.to(dtype)
fc31_act_out = fc31_act_scale
fc2_act_out = fc2_act_scale

quant_scales = (
w3_w1_scales_out,
w2_scales_out,
fc31_act_out,
fc2_act_out,
zero_1,
zero_2,
fc31_alpha,
fc2_alpha,
)

routing_weights, selected_experts = compute_routing(router_logits, top_k)
selected_experts_int32 = selected_experts.to(torch.int32)

flash_output = torch.zeros_like(x)
with autotune(True):
_ = fused_moe.cutlass_fused_moe(
x,
selected_experts_int32,
routing_weights,
fc1_weights.view(torch.uint8),
fc2_weights.view(torch.uint8),
dtype,
quant_scales=quant_scales,
use_w4_group_scaling=True,
output=flash_output,
use_packed_weights=True,
)

w31_weight_list = []
w2_weight_list = []

for e_idx in range(num_experts):
w1_w = w1_weight[e_idx]
w3_w = w3_weight[e_idx]
w2_w = w2_weight[e_idx]
w1_s = w1_scale[e_idx]
w3_s = w3_scale[e_idx]
w2_s = w2_scale[e_idx]
ws2 = weight_scale_2[e_idx]

w1_dequant = dequantize_int4_to_dtype(w1_w, w1_s, group_size, dtype, ws2)
w3_dequant = dequantize_int4_to_dtype(w3_w, w3_s, group_size, dtype, ws2)
w2_dequant = dequantize_int4_to_dtype(w2_w, w2_s, group_size, dtype, ws2)

w31 = torch.cat([w3_dequant, w1_dequant], dim=0)

w31_weight_list.append(w31)
w2_weight_list.append(w2_dequant)

w31_weight_dequant = torch.stack(w31_weight_list, dim=0)
w2_weight_dequant = torch.stack(w2_weight_list, dim=0)

ref_output = torch_moe_w4a8(
num_experts,
x,
w31_weight_dequant,
w2_weight_dequant,
selected_experts,
routing_weights,
fc1_input_scale=input_scale.squeeze(-1),
fc2_input_scale=input_scale.squeeze(-1),
fc1_pre_quant_scale=torch.max(w1_pre_quant_scale, w3_pre_quant_scale),
fc2_pre_quant_scale=w2_pre_quant_scale,
fc1_weight_scale_2=weight_scale_2.squeeze(-1),
fc2_weight_scale_2=weight_scale_2.squeeze(-1),
)
torch.testing.assert_close(ref_output, flash_output, rtol=1e-2, atol=1e-1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you adopt the suggestion in (2.)? In the original test_moe_w4a8 test, we can add an additional parameter autotune and when autotune=True we can wrap the call with the autotune context.

Address review feedback from jimmyzho: instead of a separate
test_moe_w4a8_autotune function duplicating 170 lines, add a
use_autotune parametrize to the existing test and conditionally
wrap the kernel call with autotune(True).

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Contributor

@coderabbitai coderabbitai bot left a comment

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@tests/moe/test_trtllm_cutlass_fused_moe.py`:
- Around line 1481-1482: This duplicate direct call to
torch.cuda.get_device_capability() (the if block that calls pytest.skip("W4A8 is
only supported on SM90")) should be made consistent with the rest of the file:
remove this isolated check and replace it with the shared helper or predicate
used elsewhere (e.g., use the existing is_sm90()/skip_if_not_sm90() helper or
the common capability-check utility used in other tests) so all tests perform
the SM90 guard via the same function; ensure you import or reference that helper
and replace occurrences of torch.cuda.get_device_capability()[0] in this test
with that helper and keep the pytest.skip call only inside the centralized
guard.


ipnon commented Feb 18, 2026

Done — folded test_moe_w4a8_autotune into test_moe_w4a8 via a use_autotune parametrize. The kernel call is now conditionally wrapped with autotune(True) when use_autotune=True, no more duplicated function.

Verified on H100 (SM90) — 4/4 passed (bf16 + fp16, with and without autotune):

collected 4 items

test_moe_w4a8[False-dtype0-128-2-2-128-1] PASSED
test_moe_w4a8[False-dtype1-128-2-2-128-1] PASSED
test_moe_w4a8[True-dtype0-128-2-2-128-1]  PASSED
test_moe_w4a8[True-dtype1-128-2-2-128-1]  PASSED


aleozlx commented Feb 18, 2026

/bot run

@flashinfer-bot

GitLab MR !322 has been created, and the CI pipeline #44270384 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot

[FAILED] Pipeline #44270384: 15/20 passed

@jimmyzho jimmyzho merged commit 127699c into flashinfer-ai:main Feb 18, 2026
21 checks passed


Development

Successfully merging this pull request may close these issues.

[Bug] Autotune Fails for W4A8 in cutlass_fused_moe

4 participants