diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
index b7820319682b..85ce77fb1f7f 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
@@ -56,6 +56,7 @@ def __init__(
         ep_size: int = 1,
         tp_rank: int = 0,
         tp_size: int = 1,
+        use_dp: bool = False,
     ):
         super().__init__(quant_config)
         assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), (
@@ -67,6 +68,7 @@ def __init__(
         self.tp_rank = tp_rank
         self.tp_size = tp_size
         self.out_dtype = out_dtype
+        self.use_dp = use_dp
 
     @property
     def activation_formats(
@@ -117,7 +119,8 @@ def workspace_shapes(
         """
         workspace1 = (M, K)
         workspace2 = (0,)
-        output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" else K)
+        # For TP, the quantization is fused with fused_moe call.
+        output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" and self.use_dp else K)
         # The workspace is determined by `aq`, since it comes after any
         # potential communication op and is involved in the expert computation.
         return (workspace1, workspace2, output_shape)
@@ -214,6 +217,7 @@ def flashinfer_cutlass_moe_fp4(
         FlashInferExperts(
             out_dtype=hidden_states.dtype,
             quant_config=quant_config,
+            use_dp=False,
         ),
     )
 
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
index 20e2f6c85186..051abbcb7949 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
@@ -170,6 +170,8 @@ def prepare(
         self._apply_router_weight_on_input(
             a1, topk_weights, topk_ids, apply_router_weight_on_input
         )
+        if not self.use_dp:
+            return a1, None, None, topk_ids, topk_weights
 
         a1q, a1q_scale = moe_kernel_quantize_input(
             a1,
@@ -179,14 +181,13 @@ def prepare(
             quant_config.block_shape,
             is_fp4_scale_swizzled=not self.use_dp,
         )
-        if self.use_dp:
-            topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
-                [topk_weights, topk_ids, a1q, a1q_scale],
-                dim=0,
-                sizes=get_local_sizes(),
-            )
-            if quant_config.quant_dtype == "nvfp4":
-                a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
+        topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
+            [topk_weights, topk_ids, a1q, a1q_scale],
+            dim=0,
+            sizes=get_local_sizes(),
+        )
+        if quant_config.quant_dtype == "nvfp4":
+            a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
 
         return a1q, a1q_scale, None, topk_ids, topk_weights
 
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 0eeeaa3ce457..37b682984fc3 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -1769,29 +1769,6 @@ def apply(
                 expert_map=expert_map,
                 apply_router_weight_on_input=apply_router_weight_on_input,
             )
-        elif (
-            self.allow_flashinfer
-            and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
-        ):
-            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
-                flashinfer_cutlass_moe_fp4,
-            )
-
-            assert self.moe_quant_config is not None
-
-            return flashinfer_cutlass_moe_fp4(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                quant_config=self.moe_quant_config,
-                inplace=False,
-                activation=activation,
-                global_num_experts=global_num_experts,
-                expert_map=expert_map,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-            )
         else:
             # If no modular kernel is provided, use cutlass_moe_fp4 for TP case
             # only (no EP).
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
index b3a4cb2de139..fdf330329e20 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
@@ -79,6 +79,7 @@ def select_nvfp4_gemm_impl(
         ep_size=moe.moe_parallel_config.ep_size,
         tp_rank=moe.moe_parallel_config.tp_rank,
         tp_size=moe.moe_parallel_config.tp_size,
+        use_dp=moe.moe_parallel_config.dp_size > 1,
     )
 
     # native cutlass experts currently don't support DP; TP case won't call this
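
Below is a minimal, self-contained sketch (not vLLM's actual implementation) of the control flow this patch gives prepare(): with use_dp=False the activations are returned unquantized so quantization can happen inside the fused MoE call, while with use_dp=True they are quantized first and then gathered across DP ranks. The helpers fake_quantize and fake_all_gather are hypothetical stand-ins for moe_kernel_quantize_input and get_dp_group().all_gatherv.

import torch


def fake_quantize(a1: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical stand-in for moe_kernel_quantize_input: per-row scale + cast.
    scale = a1.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6)
    return (a1 / scale).to(torch.float16), scale


def fake_all_gather(tensors: list[torch.Tensor], dp_size: int) -> list[torch.Tensor]:
    # Hypothetical stand-in for get_dp_group().all_gatherv: replicate this rank's shard.
    return [torch.cat([t] * dp_size, dim=0) for t in tensors]


def prepare_sketch(a1, topk_ids, topk_weights, use_dp: bool, dp_size: int = 2):
    if not use_dp:
        # TP-only path added by this patch: hand back unquantized activations;
        # the fused MoE call performs the quantization itself.
        return a1, None, topk_ids, topk_weights

    # DP path: quantize locally, then gather activations and routing metadata.
    a1q, a1q_scale = fake_quantize(a1)
    topk_weights, topk_ids, a1q, a1q_scale = fake_all_gather(
        [topk_weights, topk_ids, a1q, a1q_scale], dp_size
    )
    return a1q, a1q_scale, topk_ids, topk_weights


if __name__ == "__main__":
    a1 = torch.randn(4, 8)
    topk_ids = torch.randint(0, 8, (4, 2))
    topk_weights = torch.rand(4, 2)
    print(prepare_sketch(a1, topk_ids, topk_weights, use_dp=False)[0].shape)  # torch.Size([4, 8])
    print(prepare_sketch(a1, topk_ids, topk_weights, use_dp=True)[0].shape)   # torch.Size([8, 8])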