fix: W4A8 autotune crash in cutlass_fused_moe profiler workspace #2564
Conversation
(…shinfer-ai#2501) getProfilerWorkspaces() did not recognize kUINT8 as an integer quant weight type, so it allocated zero-size buffers for quant params. prepareQuantParams() then asserted on the resulting nullptrs. Three changes:

- Add kUINT8 to the is_int_w_quant and is_int_groupwise_w_quant booleans
- Add kUINT8 to the dtype_bytes ternary so scale factor buffers use the output type size (matching kINT4 behavior, since kUINT8 is packed INT4)
- Add a test_moe_w4a8_autotune regression test

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Code Review
This pull request addresses a crash in the cutlass_fused_moe profiler for W4A8 quantization by correctly handling the kUINT8 data type during workspace allocation. The fix is sound and correctly identifies and resolves the issue. The addition of the test_moe_w4a8_autotune regression test is excellent for ensuring this bug does not reappear. My review includes suggestions to refactor duplicated code in both the C++ implementation and the new Python test to improve maintainability.
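For background on why `kUINT8` appears here at all: "packed INT4" means each `uint8` byte stores two 4-bit weight values. Below is a minimal sketch of one plausible unpacking convention; the nibble order and sign handling are illustrative assumptions, not the actual kernel layout:

```python
def unpack_int4_pair(byte: int) -> tuple[int, int]:
    """Split one uint8 into two signed int4 values (low nibble first, assumed)."""
    low, high = byte & 0x0F, byte >> 4

    def to_signed(v: int) -> int:
        # Sign-extend a 4-bit value from [0, 15] to [-8, 7].
        return v - 16 if v >= 8 else v

    return to_signed(low), to_signed(high)
```

This packing is also why the weight tensors in the new test have shape `(e, n, k // 2)` in `torch.uint8`: each byte carries two INT4 weights.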
```diff
 bool is_int_w_quant =
-    (mWType == nvinfer1::DataType::kINT8 || mWType == nvinfer1::DataType::kINT4) &&
+    (mWType == nvinfer1::DataType::kINT8 || mWType == nvinfer1::DataType::kINT4 ||
+     mWType == nvinfer1::DataType::kUINT8) &&
     mGroupSize <= 0;
 bool is_int_groupwise_w_quant =
-    (mWType == nvinfer1::DataType::kINT8 || mWType == nvinfer1::DataType::kINT4) &&
+    (mWType == nvinfer1::DataType::kINT8 || mWType == nvinfer1::DataType::kINT4 ||
+     mWType == nvinfer1::DataType::kUINT8) &&
     mGroupSize > 0;
```
To improve readability and avoid repeating the same condition, you could introduce a boolean variable for the integer quantization type check.
```cpp
bool is_int_quant_type = (mWType == nvinfer1::DataType::kINT8 || mWType == nvinfer1::DataType::kINT4 ||
                          mWType == nvinfer1::DataType::kUINT8);
bool is_int_w_quant = is_int_quant_type && mGroupSize <= 0;
bool is_int_groupwise_w_quant = is_int_quant_type && mGroupSize > 0;
```
I'm fine with implementing this if you want but I didn't want to override the established style too much.
sounds reasonable
simplifying conditions and adding intermediate var names for readability is encouraged (though we don't strictly impose the style). thx
```python
def test_moe_w4a8_autotune(
    batch_size: int,
    hidden_size: int,
    num_experts: int,
    top_k: int,
    intermediate_size: int,
    dtype: torch.dtype,
):
    """Test MoE with W4A8 quantization and autotuning enabled (regression test for #2501).

    Identical to test_moe_w4a8 except:
    - Kernel call is wrapped in autotune(True) to exercise the profiling path
    - Smaller parameter set since autotuning is slower
    """
    if torch.cuda.get_device_capability()[0] != 9:
        pytest.skip("W4A8 is only supported on SM90")
    if top_k > num_experts:
        pytest.skip("top_k must be <= num_experts")

    torch.manual_seed(42)
    group_size = 128
    e = num_experts
    m = batch_size
    n = intermediate_size
    k = hidden_size
    affine_coeff = 0.005

    x = torch.randn(m, k, dtype=dtype, device="cuda")
    router_logits = torch.randn(m, e, dtype=dtype, device="cuda")
    w1_weight = torch.randint(0, 256, (e, n, k // 2), dtype=torch.uint8, device="cuda")
    w2_weight = torch.randint(0, 256, (e, k, n // 2), dtype=torch.uint8, device="cuda")
    w3_weight = torch.randint(0, 256, (e, n, k // 2), dtype=torch.uint8, device="cuda")

    w1_scale = (
        torch.randn(e, n, k // group_size, dtype=dtype, device="cuda") * affine_coeff
    )
    w2_scale = (
        torch.randn(e, k, n // group_size, dtype=dtype, device="cuda") * affine_coeff
    )
    w3_scale = (
        torch.randn(e, n, k // group_size, dtype=dtype, device="cuda") * affine_coeff
    )

    w1_pre_quant_scale = torch.rand(e, k, dtype=dtype, device="cuda") * 0.1 + 0.95
    w2_pre_quant_scale = torch.rand(e, n, dtype=dtype, device="cuda") * 0.1 + 0.95
    w3_pre_quant_scale = torch.rand(e, k, dtype=dtype, device="cuda") * 0.1 + 0.95

    input_scale = torch.rand(e, 1, dtype=torch.float32, device="cuda") * 0.2 + 0.1
    weight_scale_2 = torch.ones(e, 1, dtype=torch.float32, device="cuda")

    fc1_weights = torch.cat([w3_weight, w1_weight], dim=1)
    fc2_weights = w2_weight

    def interleave_weights(w: torch.Tensor, dim: int) -> torch.Tensor:
        interleave_factor = 4 if dim % 512 == 0 else (2 if dim % 256 == 0 else 1)
        s = w.shape
        w_interleaved = (
            w.reshape(s[0], s[1], s[2] // interleave_factor, interleave_factor)
            .permute(0, 2, 1, 3)
            .reshape(s[0], s[2] // interleave_factor, s[1] * interleave_factor)
            .contiguous()
        )
        return w_interleaved

    w3_w1_scales = torch.cat([w3_scale, w1_scale], dim=1)
    w3_w1_scales_int = interleave_weights(w3_w1_scales, k)
    w2_scales_int = interleave_weights(w2_scale, n)

    w3_w1_pre_quant_max = torch.max(w1_pre_quant_scale, w3_pre_quant_scale)
    w3_w1_input_scale_max = input_scale.max()
    fc31_act_scale = (w3_w1_pre_quant_max / w3_w1_input_scale_max).to(dtype)
    fc2_act_scale = (w2_pre_quant_scale / input_scale).to(dtype).unsqueeze(-1)

    fc31_alpha = (weight_scale_2.squeeze(-1) * w3_w1_input_scale_max).float()
    fc2_alpha = (weight_scale_2.squeeze(-1) * input_scale.squeeze(-1)).float()

    zero_1 = torch.empty(0, dtype=dtype, device="cuda")
    zero_2 = torch.empty(0, dtype=dtype, device="cuda")

    sm = (
        torch.cuda.get_device_capability()[0] * 10
        + torch.cuda.get_device_capability()[1]
    )
    if sm >= 90:
        w3_w1_scales_out = w3_w1_scales_int.to(torch.bfloat16).view(dtype)
        w2_scales_out = w2_scales_int.to(torch.bfloat16).view(dtype)
        fc31_act_out = fc31_act_scale.to(torch.bfloat16).view(dtype)
        fc2_act_out = fc2_act_scale.to(torch.bfloat16).view(dtype)
    else:
        w3_w1_scales_out = w3_w1_scales_int.to(dtype)
        w2_scales_out = w2_scales_int.to(dtype)
        fc31_act_out = fc31_act_scale
        fc2_act_out = fc2_act_scale

    quant_scales = (
        w3_w1_scales_out,
        w2_scales_out,
        fc31_act_out,
        fc2_act_out,
        zero_1,
        zero_2,
        fc31_alpha,
        fc2_alpha,
    )

    routing_weights, selected_experts = compute_routing(router_logits, top_k)
    selected_experts_int32 = selected_experts.to(torch.int32)

    flash_output = torch.zeros_like(x)
    with autotune(True):
        _ = fused_moe.cutlass_fused_moe(
            x,
            selected_experts_int32,
            routing_weights,
            fc1_weights.view(torch.uint8),
            fc2_weights.view(torch.uint8),
            dtype,
            quant_scales=quant_scales,
            use_w4_group_scaling=True,
            output=flash_output,
            use_packed_weights=True,
        )

    w31_weight_list = []
    w2_weight_list = []

    for e_idx in range(num_experts):
        w1_w = w1_weight[e_idx]
        w3_w = w3_weight[e_idx]
        w2_w = w2_weight[e_idx]
        w1_s = w1_scale[e_idx]
        w3_s = w3_scale[e_idx]
        w2_s = w2_scale[e_idx]
        ws2 = weight_scale_2[e_idx]

        w1_dequant = dequantize_int4_to_dtype(w1_w, w1_s, group_size, dtype, ws2)
        w3_dequant = dequantize_int4_to_dtype(w3_w, w3_s, group_size, dtype, ws2)
        w2_dequant = dequantize_int4_to_dtype(w2_w, w2_s, group_size, dtype, ws2)

        w31 = torch.cat([w3_dequant, w1_dequant], dim=0)

        w31_weight_list.append(w31)
        w2_weight_list.append(w2_dequant)

    w31_weight_dequant = torch.stack(w31_weight_list, dim=0)
    w2_weight_dequant = torch.stack(w2_weight_list, dim=0)

    ref_output = torch_moe_w4a8(
        num_experts,
        x,
        w31_weight_dequant,
        w2_weight_dequant,
        selected_experts,
        routing_weights,
        fc1_input_scale=input_scale.squeeze(-1),
        fc2_input_scale=input_scale.squeeze(-1),
        fc1_pre_quant_scale=torch.max(w1_pre_quant_scale, w3_pre_quant_scale),
        fc2_pre_quant_scale=w2_pre_quant_scale,
        fc1_weight_scale_2=weight_scale_2.squeeze(-1),
        fc2_weight_scale_2=weight_scale_2.squeeze(-1),
    )
    torch.testing.assert_close(ref_output, flash_output, rtol=1e-2, atol=1e-1)
```
This new test function test_moe_w4a8_autotune has a significant amount of duplicated code from the existing test_moe_w4a8 test. To improve maintainability and reduce redundancy, consider the following refactorings:
1. Extract `interleave_weights`: The `interleave_weights` helper function is defined inside both test functions. It could be extracted to the module level to be shared between them.
2. Create a common test helper: The main logic of both tests is identical, except for the `autotune` context manager and the parameter sets. You could create a helper function that takes an `autotune_enabled` flag and contains the common test logic. Then, `test_moe_w4a8` and `test_moe_w4a8_autotune` can simply call this helper with the appropriate flag and their respective `pytest` parameterizations.
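Suggestion (2.) could be sketched roughly as below. This is a hedged illustration: `autotune` here is a hypothetical stub standing in for flashinfer's real context manager, and `run_w4a8_case` is a placeholder for the shared test body, not actual test code.

```python
import contextlib


@contextlib.contextmanager
def autotune(enabled: bool):
    # Hypothetical stand-in: the real flashinfer autotune toggles kernel
    # profiling; this stub only models the context-manager shape.
    yield


def run_w4a8_case(use_autotune: bool) -> str:
    # Shared body: wrap the kernel call in autotune(True) only when asked;
    # nullcontext leaves the non-autotune path untouched.
    ctx = autotune(True) if use_autotune else contextlib.nullcontext()
    with ctx:
        return "autotune" if use_autotune else "plain"


# In pytest this would be driven by a parameter, e.g.:
# @pytest.mark.parametrize("use_autotune", [False, True])
# def test_moe_w4a8(use_autotune, ...):
#     run_w4a8_case(use_autotune)
```

The `nullcontext` fallback keeps the non-autotune path byte-for-byte identical to the original test while avoiding an if/else around the kernel call.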
Again, just let me know your preference and I'll update.
Could you adopt the suggestion in (2.)? In the original test_moe_w4a8 test, we can add an additional parameter autotune and when autotune=True we can wrap the call with the autotune context.
Actionable comments posted: 1
```python
    if torch.cuda.get_device_capability()[0] != 9:
        pytest.skip("W4A8 is only supported on SM90")
    if top_k > num_experts:
```
Use flashinfer.utils capability helpers for SM90 skips.

The direct `torch.cuda.get_device_capability()` check should use the project's helper(s) for consistency and portability.

🔧 Proposed fix

```diff
-    if torch.cuda.get_device_capability()[0] != 9:
+    if get_compute_capability()[0] != 9:
         pytest.skip("W4A8 is only supported on SM90")
```

```diff
 from flashinfer import (
     autotune,
     fp4_quantize,
     mxfp4_dequantize,
     mxfp4_quantize,
     mxfp8_dequantize_host,
     mxfp8_quantize,
     mxfp4_dequantize_host,
 )
+from flashinfer.utils import get_compute_capability
```

As per coding guidelines, test files must use flashinfer.utils functions (`get_compute_capability`, `is_sm90a_supported`, etc.) to skip tests on unsupported GPU architectures.
I copied the entire moe w4a8 test as-is and just wrapped it in the autotuner, if you'd like the underlying code can be encapsulated in a function, up to you.
@ipnon, thanks for the context! You're absolutely right that copying the test as-is makes sense for consistency — I can see the entire file uses torch.cuda.get_device_capability() directly throughout (lines 808, 990, 1223, 1290, 1549, etc.).
My suggestion was based on the coding guideline about using flashinfer.utils helpers like get_compute_capability(), but since this pattern exists file-wide, changing just the new test would create inconsistency. Fixing all occurrences would be out of scope for this bug fix PR.
Feel free to keep the current approach for consistency with the rest of the file. As for refactoring the shared logic into a function — that's a good idea but probably better suited for a separate cleanup PR rather than mixing it with this bug fix.
This is my first PR here. I attempted to thoroughly review the problem and steelman any potential review issues; let me know what you think, thanks.
jimmyzho
left a comment
Looking good, just a small nit in the test. Thanks for the fix!
Could you adopt the suggestion in (2.)? In the original test_moe_w4a8 test, we can add an additional parameter autotune and when autotune=True we can wrap the call with the autotune context.
Address review feedback from jimmyzho: instead of a separate test_moe_w4a8_autotune function duplicating 170 lines, add a use_autotune parametrize to the existing test and conditionally wrap the kernel call with autotune(True). Co-Authored-By: Claude Opus 4.6 <[email protected]>
Done — folded in. Verified on H100 (SM90): 4/4 passed (bf16 + fp16, with and without autotune).
/bot run |
[FAILED] Pipeline #44270384: 15/20 passed |
Description
Problem
`cutlass_fused_moe` crashes with `RuntimeError: Assertion failed: quant_1 && quant_2` when called inside `autotune()` with W4A8 quantization. Other quant modes (FP8 block scaling, BF16+MXFP4) are unaffected.

Root Cause
`getProfilerWorkspaces()` in `cutlass_fused_moe_kernels.cuh` does not recognize `kUINT8` as an integer-quantized weight type. W4A8 uses `kUINT8` (packed INT4 pairs), but the three type-check sites only match `kINT8` and `kINT4`:

- `is_int_w_quant` — controls per-tensor scale buffer allocation
- `is_int_groupwise_w_quant` — controls groupwise scale buffer allocation
- the `dtype_bytes` ternary — controls scale factor byte width

With `kUINT8` unrecognized, all three evaluate false, so the profiler allocates zero-size quant buffers. When `prepareQuantParams()` later tries to use them, it asserts on the resulting nullptrs.

The consumer side (`prepareQuantParams`) already handles `kUINT8` correctly — only the allocator side was missing it. Notably, `is_wfp4a16_quant` in the same function does check for `kUINT8`, but only when activations are HALF/BF16, not FP8. So the workspace code was partially aware of `kUINT8`; the gap was specifically in the FP8 activation pairing that W4A8 uses.

Fix
Add `kUINT8` to all three sites in `getProfilerWorkspaces()`, matching the existing `kINT4` behavior (since `kUINT8` is the packed representation of INT4 weights).

This cascades correctly through the existing logic:
1. `is_int_groupwise_w_quant` becomes true (for the groupwise path)
2. `is_w4afp8_quant = is_int_groupwise_w_quant && is_fp8_act_quant` becomes true
3. `w4a8_alpha_size = num_experts_per_node * sizeof(float)` becomes non-zero
4. `ADD(w4a8_alpha)` registers the buffer in the workspace map
5. `GET_WS_PTR(float const*, w4a8_alpha)` in `prepareQuantParams()` returns a valid pointer instead of nullptr

Testing
Added a `test_moe_w4a8_autotune` regression test — identical to `test_moe_w4a8` but wrapped in `autotune(True)` with a reduced parameter set (autotuning is slow due to JIT compilation of 163 SM90 CUTLASS kernels).

Tested on H100 (SM90): 2 passed in 25:14.
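The boolean cascade described under Fix can be modeled in isolation. This is a Python sketch of the C++ logic: `DataType` and `w4a8_alpha_size` are hypothetical stand-ins for the real TensorRT-LLM code, but the chain of conditions mirrors the steps listed above.

```python
from enum import Enum, auto


class DataType(Enum):
    # Stand-in for nvinfer1::DataType, illustration only.
    kINT8 = auto()
    kINT4 = auto()
    kUINT8 = auto()
    kFP8 = auto()
    kHALF = auto()


SIZEOF_FLOAT = 4


def w4a8_alpha_size(w_type: DataType, act_type: DataType,
                    group_size: int, num_experts_per_node: int) -> int:
    """Model of the workspace sizing: w4a8_alpha is non-zero only when the
    weight type is recognized as groupwise int quant AND activations are FP8."""
    is_int_quant_type = w_type in (DataType.kINT8, DataType.kINT4,
                                   DataType.kUINT8)  # kUINT8 added by the fix
    is_int_groupwise_w_quant = is_int_quant_type and group_size > 0
    is_fp8_act_quant = act_type is DataType.kFP8
    is_w4afp8_quant = is_int_groupwise_w_quant and is_fp8_act_quant
    return num_experts_per_node * SIZEOF_FLOAT if is_w4afp8_quant else 0
```

With `kUINT8` in the recognized set, the W4A8 combination (packed-INT4 weights plus FP8 activations, groupwise) sizes the alpha buffer; without it, the same inputs fell through every check and produced the zero-size buffer that `prepareQuantParams()` asserted on.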
Related Issues
Fixes #2501
Pull Request Checklist
Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

Tests

- Tests have been added or updated as needed (`unittest`, etc.).
Root cause class: allocator-consumer mismatch from missing case in non-exhaustive type dispatch.
prepareQuantParams()was updated for W4A8 support butgetProfilerWorkspaces()was not — the two sites use independent boolean checks rather than a shared dispatch table, so adding a new weight type requires updating both manually.Summary by CodeRabbit
New Features
Tests