diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py
index f985f9ac7ca6..707068b2bbdc 100644
--- a/tests/kernels/moe/test_flashinfer.py
+++ b/tests/kernels/moe/test_flashinfer.py
@@ -6,7 +6,10 @@
 import torch
 
 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
-from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
+from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEQuantConfig,
+    fp8_w8a8_moe_quant_config,
+)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
@@ -22,10 +25,10 @@
 from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
 
 if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_capability(
-    100
+    90
 ):
     pytest.skip(
-        "Requires flashinfer_cutlass_fused_moe and nvfp4 support",
+        "Supported for sm >= 90",
         allow_module_level=True,
     )
 
@@ -131,6 +134,8 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
     topk: int,
     monkeypatch,
 ):
+    if not current_platform.has_device_capability(100):
+        pytest.skip("Test is only supported for sm >= 100")
     current_platform.seed_everything(7)
     monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
     with set_current_vllm_config(vllm_config):
@@ -184,9 +189,6 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
     torch.testing.assert_close(output, flashinfer_output, atol=5.5e-2, rtol=1e-2)
 
 
-@pytest.mark.skip(
-    "Requires flashinfer version that contains https://github.com/flashinfer-ai/flashinfer/pull/1472"
-)
 @pytest.mark.parametrize("m,n,k", MNK_FACTORS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
@@ -216,9 +218,13 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
 
     quant_config = fp8_w8a8_moe_quant_config(
         w1_scale=td.w13_weight_scale,
+        g1_alphas=(td.w13_weight_scale * td.a1_scale).squeeze(),
         w2_scale=td.w2_weight_scale,
+        g2_alphas=(td.w2_weight_scale * td.a2_scale).squeeze(),
         a1_scale=td.a1_scale,
+        a1_gscale=td.a1_scale,
         a2_scale=td.a2_scale,
+        a2_gscale=1.0 / td.a2_scale,
         per_act_token_quant=False,
     )
 
@@ -238,6 +244,12 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
     td.layer.dp_size = 1
 
+    def get_fused_moe_quant_config(n: torch.nn.Module) -> FusedMoEQuantConfig:
+        return quant_config
+
+    td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config
+    td.layer.quant_method = td.layer
+
     flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8(
         td.hidden_states,
         td.layer,
diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index 239405332980..cbc3caafcf2f 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -463,6 +463,10 @@ def fp8_w8a8_moe_quant_config(
     per_act_token_quant: bool = False,
     per_out_ch_quant: bool = False,
     block_shape: list[int] | None = None,
+    a1_gscale: torch.Tensor | None = None,
+    a2_gscale: torch.Tensor | None = None,
+    g1_alphas: torch.Tensor | None = None,
+    g2_alphas: torch.Tensor | None = None,
 ) -> FusedMoEQuantConfig:
     """
     Construct a quant config for fp8 activations and fp8 weights.
@@ -470,9 +474,13 @@ def fp8_w8a8_moe_quant_config(
     return FusedMoEQuantConfig.make(
         torch.float8_e4m3fn,
         w1_scale=w1_scale,
+        g1_alphas=g1_alphas,
         w2_scale=w2_scale,
+        g2_alphas=g2_alphas,
         a1_scale=a1_scale,
+        a1_gscale=a1_gscale,
         a2_scale=a2_scale,
+        a2_gscale=a2_gscale,
         per_act_token_quant=per_act_token_quant,
         per_out_ch_quant=per_out_ch_quant,
         block_shape=block_shape,
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
index 051abbcb7949..97ee20ae9a11 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
@@ -170,7 +170,7 @@ def prepare(
         self._apply_router_weight_on_input(
             a1, topk_weights, topk_ids, apply_router_weight_on_input
         )
-        if not self.use_dp:
+        if not self.use_dp and quant_config.quant_dtype == "nvfp4":
             return a1, None, None, topk_ids, topk_weights
 
         a1q, a1q_scale = moe_kernel_quantize_input(
@@ -181,11 +181,13 @@ def prepare(
             quant_config.block_shape,
             is_fp4_scale_swizzled=not self.use_dp,
         )
-        topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
-            [topk_weights, topk_ids, a1q, a1q_scale],
-            dim=0,
-            sizes=get_local_sizes(),
-        )
+
+        if self.use_dp:
+            topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
+                [topk_weights, topk_ids, a1q, a1q_scale],
+                dim=0,
+                sizes=get_local_sizes(),
+            )
 
         if quant_config.quant_dtype == "nvfp4":
             a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 37b682984fc3..282274268571 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -570,9 +570,13 @@ def get_fused_moe_quant_config(
 
         return fp8_w8a8_moe_quant_config(
             w1_scale=layer.w13_weight_scale,
+            g1_alphas=(layer.w13_weight_scale * layer.w13_input_scale).squeeze(),
             w2_scale=layer.w2_weight_scale,
+            g2_alphas=(layer.w2_weight_scale * layer.w2_input_scale).squeeze(),
             a1_scale=layer.w13_input_scale,
+            a1_gscale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
+            a2_gscale=1.0 / layer.w2_input_scale,
             per_act_token_quant=False,
         )
 
@@ -1159,8 +1163,8 @@ def __init__(
         moe: FusedMoEConfig,
         layer: torch.nn.Module,
     ) -> None:
-        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (  # noqa: E501
-            detect_nvfp4_moe_support,
+        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (
+            detect_nvfp4_moe_support,  # noqa: E501
         )
 
         super().__init__(moe)
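
Reviewer note (not part of the patch): the new keyword arguments threaded through fp8_w8a8_moe_quant_config above are all derived from the existing per-tensor fp8 scales. A minimal sketch of that derivation, assuming the tensor names used in the diff; the helper function itself is hypothetical and does not exist in vLLM:

import torch


def flashinfer_fp8_scale_kwargs(
    w13_weight_scale: torch.Tensor,
    w2_weight_scale: torch.Tensor,
    w13_input_scale: torch.Tensor,
    w2_input_scale: torch.Tensor,
) -> dict[str, torch.Tensor]:
    # Hypothetical helper mirroring the formulas in the diff above:
    # g*_alphas combine weight and activation scales into per-expert dequant
    # factors, a1_gscale reuses the first GEMM's input scale, and a2_gscale is
    # the reciprocal of the intermediate activation scale.
    return {
        "g1_alphas": (w13_weight_scale * w13_input_scale).squeeze(),
        "g2_alphas": (w2_weight_scale * w2_input_scale).squeeze(),
        "a1_gscale": w13_input_scale,
        "a2_gscale": 1.0 / w2_input_scale,
    }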