-
Notifications
You must be signed in to change notification settings - Fork 973
fix: W4A8 autotune crash in cutlass_fused_moe profiler workspace #2564
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -23,6 +23,7 @@ | |||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| import flashinfer.fused_moe as fused_moe | ||||||||||||||||||||||||||||||||||||||||
| from flashinfer import ( | ||||||||||||||||||||||||||||||||||||||||
| autotune, | ||||||||||||||||||||||||||||||||||||||||
| fp4_quantize, | ||||||||||||||||||||||||||||||||||||||||
| mxfp4_dequantize, | ||||||||||||||||||||||||||||||||||||||||
| mxfp4_quantize, | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1632,5 +1633,178 @@ def interleave_weights(w: torch.Tensor, dim: int) -> torch.Tensor: | |||||||||||||||||||||||||||||||||||||||
| torch.testing.assert_close(ref_output, flash_output, rtol=1e-2, atol=1e-1) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # Reduced parameter set vs test_moe_w4a8: autotuning is slow, so we test | ||||||||||||||||||||||||||||||||||||||||
| # fewer combinations. The goal is to verify the profiler workspace allocation | ||||||||||||||||||||||||||||||||||||||||
| # doesn't crash, not to sweep all shapes. See: github.com/flashinfer-ai/flashinfer/issues/2501 | ||||||||||||||||||||||||||||||||||||||||
| @pytest.mark.parametrize("batch_size", [1, 16]) | ||||||||||||||||||||||||||||||||||||||||
| @pytest.mark.parametrize("hidden_size", [256]) | ||||||||||||||||||||||||||||||||||||||||
| @pytest.mark.parametrize("num_experts", [8]) | ||||||||||||||||||||||||||||||||||||||||
| @pytest.mark.parametrize("top_k", [2]) | ||||||||||||||||||||||||||||||||||||||||
| @pytest.mark.parametrize("intermediate_size", [512]) | ||||||||||||||||||||||||||||||||||||||||
| @pytest.mark.parametrize("dtype", [torch.bfloat16]) | ||||||||||||||||||||||||||||||||||||||||
| def test_moe_w4a8_autotune( | ||||||||||||||||||||||||||||||||||||||||
| batch_size: int, | ||||||||||||||||||||||||||||||||||||||||
| hidden_size: int, | ||||||||||||||||||||||||||||||||||||||||
| num_experts: int, | ||||||||||||||||||||||||||||||||||||||||
| top_k: int, | ||||||||||||||||||||||||||||||||||||||||
| intermediate_size: int, | ||||||||||||||||||||||||||||||||||||||||
| dtype: torch.dtype, | ||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||
| """Test MoE with W4A8 quantization and autotuning enabled (regression test for #2501). | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| Identical to test_moe_w4a8 except: | ||||||||||||||||||||||||||||||||||||||||
| - Kernel call is wrapped in autotune(True) to exercise the profiling path | ||||||||||||||||||||||||||||||||||||||||
| - Smaller parameter set since autotuning is slower | ||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||
| if torch.cuda.get_device_capability()[0] != 9: | ||||||||||||||||||||||||||||||||||||||||
| pytest.skip("W4A8 is only supported on SM90") | ||||||||||||||||||||||||||||||||||||||||
| if top_k > num_experts: | ||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use flashinfer.utils capability helpers for SM90 skips. 🔧 Proposed fix- if torch.cuda.get_device_capability()[0] != 9:
+ if get_compute_capability()[0] != 9:
pytest.skip("W4A8 is only supported on SM90") from flashinfer import (
autotune,
fp4_quantize,
mxfp4_dequantize,
mxfp4_quantize,
mxfp8_dequantize_host,
mxfp8_quantize,
mxfp4_dequantize_host,
)
+from flashinfer.utils import get_compute_capabilityAs per coding guidelines, Test files must use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, etc.) to skip tests on unsupported GPU architectures. 📝 Committable suggestion
Suggested change
Suggested change
🤖 Prompt for AI Agents
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I copied the entire moe w4a8 test as-is and just wrapped it in the autotuner, if you'd like the underlying code can be encapsulated in a function, up to you.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
My suggestion was based on the coding guideline about using Feel free to keep the current approach for consistency with the rest of the file. As for refactoring the shared logic into a function — that's a good idea but probably better suited for a separate cleanup PR rather than mixing it with this bug fix. 🧠 Learnings used |
||||||||||||||||||||||||||||||||||||||||
| pytest.skip("top_k must be <= num_experts") | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| torch.manual_seed(42) | ||||||||||||||||||||||||||||||||||||||||
| group_size = 128 | ||||||||||||||||||||||||||||||||||||||||
| e = num_experts | ||||||||||||||||||||||||||||||||||||||||
| m = batch_size | ||||||||||||||||||||||||||||||||||||||||
| n = intermediate_size | ||||||||||||||||||||||||||||||||||||||||
| k = hidden_size | ||||||||||||||||||||||||||||||||||||||||
| affine_coeff = 0.005 | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| x = torch.randn(m, k, dtype=dtype, device="cuda") | ||||||||||||||||||||||||||||||||||||||||
| router_logits = torch.randn(m, e, dtype=dtype, device="cuda") | ||||||||||||||||||||||||||||||||||||||||
| w1_weight = torch.randint(0, 256, (e, n, k // 2), dtype=torch.uint8, device="cuda") | ||||||||||||||||||||||||||||||||||||||||
| w2_weight = torch.randint(0, 256, (e, k, n // 2), dtype=torch.uint8, device="cuda") | ||||||||||||||||||||||||||||||||||||||||
| w3_weight = torch.randint(0, 256, (e, n, k // 2), dtype=torch.uint8, device="cuda") | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| w1_scale = ( | ||||||||||||||||||||||||||||||||||||||||
| torch.randn(e, n, k // group_size, dtype=dtype, device="cuda") * affine_coeff | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| w2_scale = ( | ||||||||||||||||||||||||||||||||||||||||
| torch.randn(e, k, n // group_size, dtype=dtype, device="cuda") * affine_coeff | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| w3_scale = ( | ||||||||||||||||||||||||||||||||||||||||
| torch.randn(e, n, k // group_size, dtype=dtype, device="cuda") * affine_coeff | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| w1_pre_quant_scale = torch.rand(e, k, dtype=dtype, device="cuda") * 0.1 + 0.95 | ||||||||||||||||||||||||||||||||||||||||
| w2_pre_quant_scale = torch.rand(e, n, dtype=dtype, device="cuda") * 0.1 + 0.95 | ||||||||||||||||||||||||||||||||||||||||
| w3_pre_quant_scale = torch.rand(e, k, dtype=dtype, device="cuda") * 0.1 + 0.95 | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| input_scale = torch.rand(e, 1, dtype=torch.float32, device="cuda") * 0.2 + 0.1 | ||||||||||||||||||||||||||||||||||||||||
| weight_scale_2 = torch.ones(e, 1, dtype=torch.float32, device="cuda") | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| fc1_weights = torch.cat([w3_weight, w1_weight], dim=1) | ||||||||||||||||||||||||||||||||||||||||
| fc2_weights = w2_weight | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def interleave_weights(w: torch.Tensor, dim: int) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||
| interleave_factor = 4 if dim % 512 == 0 else (2 if dim % 256 == 0 else 1) | ||||||||||||||||||||||||||||||||||||||||
| s = w.shape | ||||||||||||||||||||||||||||||||||||||||
| w_interleaved = ( | ||||||||||||||||||||||||||||||||||||||||
| w.reshape(s[0], s[1], s[2] // interleave_factor, interleave_factor) | ||||||||||||||||||||||||||||||||||||||||
| .permute(0, 2, 1, 3) | ||||||||||||||||||||||||||||||||||||||||
| .reshape(s[0], s[2] // interleave_factor, s[1] * interleave_factor) | ||||||||||||||||||||||||||||||||||||||||
| .contiguous() | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| return w_interleaved | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| w3_w1_scales = torch.cat([w3_scale, w1_scale], dim=1) | ||||||||||||||||||||||||||||||||||||||||
| w3_w1_scales_int = interleave_weights(w3_w1_scales, k) | ||||||||||||||||||||||||||||||||||||||||
| w2_scales_int = interleave_weights(w2_scale, n) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| w3_w1_pre_quant_max = torch.max(w1_pre_quant_scale, w3_pre_quant_scale) | ||||||||||||||||||||||||||||||||||||||||
| w3_w1_input_scale_max = input_scale.max() | ||||||||||||||||||||||||||||||||||||||||
| fc31_act_scale = (w3_w1_pre_quant_max / w3_w1_input_scale_max).to(dtype) | ||||||||||||||||||||||||||||||||||||||||
| fc2_act_scale = (w2_pre_quant_scale / input_scale).to(dtype).unsqueeze(-1) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| fc31_alpha = (weight_scale_2.squeeze(-1) * w3_w1_input_scale_max).float() | ||||||||||||||||||||||||||||||||||||||||
| fc2_alpha = (weight_scale_2.squeeze(-1) * input_scale.squeeze(-1)).float() | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| zero_1 = torch.empty(0, dtype=dtype, device="cuda") | ||||||||||||||||||||||||||||||||||||||||
| zero_2 = torch.empty(0, dtype=dtype, device="cuda") | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| sm = ( | ||||||||||||||||||||||||||||||||||||||||
| torch.cuda.get_device_capability()[0] * 10 | ||||||||||||||||||||||||||||||||||||||||
| + torch.cuda.get_device_capability()[1] | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| if sm >= 90: | ||||||||||||||||||||||||||||||||||||||||
| w3_w1_scales_out = w3_w1_scales_int.to(torch.bfloat16).view(dtype) | ||||||||||||||||||||||||||||||||||||||||
| w2_scales_out = w2_scales_int.to(torch.bfloat16).view(dtype) | ||||||||||||||||||||||||||||||||||||||||
| fc31_act_out = fc31_act_scale.to(torch.bfloat16).view(dtype) | ||||||||||||||||||||||||||||||||||||||||
| fc2_act_out = fc2_act_scale.to(torch.bfloat16).view(dtype) | ||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||
| w3_w1_scales_out = w3_w1_scales_int.to(dtype) | ||||||||||||||||||||||||||||||||||||||||
| w2_scales_out = w2_scales_int.to(dtype) | ||||||||||||||||||||||||||||||||||||||||
| fc31_act_out = fc31_act_scale | ||||||||||||||||||||||||||||||||||||||||
| fc2_act_out = fc2_act_scale | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| quant_scales = ( | ||||||||||||||||||||||||||||||||||||||||
| w3_w1_scales_out, | ||||||||||||||||||||||||||||||||||||||||
| w2_scales_out, | ||||||||||||||||||||||||||||||||||||||||
| fc31_act_out, | ||||||||||||||||||||||||||||||||||||||||
| fc2_act_out, | ||||||||||||||||||||||||||||||||||||||||
| zero_1, | ||||||||||||||||||||||||||||||||||||||||
| zero_2, | ||||||||||||||||||||||||||||||||||||||||
| fc31_alpha, | ||||||||||||||||||||||||||||||||||||||||
| fc2_alpha, | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| routing_weights, selected_experts = compute_routing(router_logits, top_k) | ||||||||||||||||||||||||||||||||||||||||
| selected_experts_int32 = selected_experts.to(torch.int32) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| flash_output = torch.zeros_like(x) | ||||||||||||||||||||||||||||||||||||||||
| with autotune(True): | ||||||||||||||||||||||||||||||||||||||||
| _ = fused_moe.cutlass_fused_moe( | ||||||||||||||||||||||||||||||||||||||||
| x, | ||||||||||||||||||||||||||||||||||||||||
| selected_experts_int32, | ||||||||||||||||||||||||||||||||||||||||
| routing_weights, | ||||||||||||||||||||||||||||||||||||||||
| fc1_weights.view(torch.uint8), | ||||||||||||||||||||||||||||||||||||||||
| fc2_weights.view(torch.uint8), | ||||||||||||||||||||||||||||||||||||||||
| dtype, | ||||||||||||||||||||||||||||||||||||||||
| quant_scales=quant_scales, | ||||||||||||||||||||||||||||||||||||||||
| use_w4_group_scaling=True, | ||||||||||||||||||||||||||||||||||||||||
| output=flash_output, | ||||||||||||||||||||||||||||||||||||||||
| use_packed_weights=True, | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| w31_weight_list = [] | ||||||||||||||||||||||||||||||||||||||||
| w2_weight_list = [] | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| for e_idx in range(num_experts): | ||||||||||||||||||||||||||||||||||||||||
| w1_w = w1_weight[e_idx] | ||||||||||||||||||||||||||||||||||||||||
| w3_w = w3_weight[e_idx] | ||||||||||||||||||||||||||||||||||||||||
| w2_w = w2_weight[e_idx] | ||||||||||||||||||||||||||||||||||||||||
| w1_s = w1_scale[e_idx] | ||||||||||||||||||||||||||||||||||||||||
| w3_s = w3_scale[e_idx] | ||||||||||||||||||||||||||||||||||||||||
| w2_s = w2_scale[e_idx] | ||||||||||||||||||||||||||||||||||||||||
| ws2 = weight_scale_2[e_idx] | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| w1_dequant = dequantize_int4_to_dtype(w1_w, w1_s, group_size, dtype, ws2) | ||||||||||||||||||||||||||||||||||||||||
| w3_dequant = dequantize_int4_to_dtype(w3_w, w3_s, group_size, dtype, ws2) | ||||||||||||||||||||||||||||||||||||||||
| w2_dequant = dequantize_int4_to_dtype(w2_w, w2_s, group_size, dtype, ws2) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| w31 = torch.cat([w3_dequant, w1_dequant], dim=0) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| w31_weight_list.append(w31) | ||||||||||||||||||||||||||||||||||||||||
| w2_weight_list.append(w2_dequant) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| w31_weight_dequant = torch.stack(w31_weight_list, dim=0) | ||||||||||||||||||||||||||||||||||||||||
| w2_weight_dequant = torch.stack(w2_weight_list, dim=0) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| ref_output = torch_moe_w4a8( | ||||||||||||||||||||||||||||||||||||||||
| num_experts, | ||||||||||||||||||||||||||||||||||||||||
| x, | ||||||||||||||||||||||||||||||||||||||||
| w31_weight_dequant, | ||||||||||||||||||||||||||||||||||||||||
| w2_weight_dequant, | ||||||||||||||||||||||||||||||||||||||||
| selected_experts, | ||||||||||||||||||||||||||||||||||||||||
| routing_weights, | ||||||||||||||||||||||||||||||||||||||||
| fc1_input_scale=input_scale.squeeze(-1), | ||||||||||||||||||||||||||||||||||||||||
| fc2_input_scale=input_scale.squeeze(-1), | ||||||||||||||||||||||||||||||||||||||||
| fc1_pre_quant_scale=torch.max(w1_pre_quant_scale, w3_pre_quant_scale), | ||||||||||||||||||||||||||||||||||||||||
| fc2_pre_quant_scale=w2_pre_quant_scale, | ||||||||||||||||||||||||||||||||||||||||
| fc1_weight_scale_2=weight_scale_2.squeeze(-1), | ||||||||||||||||||||||||||||||||||||||||
| fc2_weight_scale_2=weight_scale_2.squeeze(-1), | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| torch.testing.assert_close(ref_output, flash_output, rtol=1e-2, atol=1e-1) | ||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This new test function
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again, just let me know your preference and I'll update.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you adopt the suggestion in (2.)? In the original |
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| if __name__ == "__main__": | ||||||||||||||||||||||||||||||||||||||||
| pytest.main([__file__, "-v"]) | ||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To improve readability and avoid repeating the same condition, you could introduce a boolean variable for the integer quantization type check.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm fine with implementing this if you want but I didn't want to override the established style too much.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sounds reasonable
simplifying conditions and adding the immediate var names for readability is encouraged (tho we don't strictly impose the style). thx