Skip to content

Commit 24e6691

Browse files
wenscarl authored and zhaozuy committed
Flashinfer_CUTLASS_MOE fuses quantization for TP (vllm-project#27223)
Signed-off-by: Shu Wang. <[email protected]>
1 parent 04707d2 commit 24e6691

File tree

4 files changed

+15
-32
lines changed

4 files changed

+15
-32
lines changed

vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def __init__(
5656
ep_size: int = 1,
5757
tp_rank: int = 0,
5858
tp_size: int = 1,
59+
use_dp: bool = False,
5960
):
6061
super().__init__(quant_config)
6162
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), (
@@ -67,6 +68,7 @@ def __init__(
6768
self.tp_rank = tp_rank
6869
self.tp_size = tp_size
6970
self.out_dtype = out_dtype
71+
self.use_dp = use_dp
7072

7173
@property
7274
def activation_formats(
@@ -117,7 +119,8 @@ def workspace_shapes(
117119
"""
118120
workspace1 = (M, K)
119121
workspace2 = (0,)
120-
output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" else K)
122+
# For TP, the quantization is fused with fused_moe call.
123+
output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" and self.use_dp else K)
121124
# The workspace is determined by `aq`, since it comes after any
122125
# potential communication op and is involved in the expert computation.
123126
return (workspace1, workspace2, output_shape)
@@ -214,6 +217,7 @@ def flashinfer_cutlass_moe_fp4(
214217
FlashInferExperts(
215218
out_dtype=hidden_states.dtype,
216219
quant_config=quant_config,
220+
use_dp=False,
217221
),
218222
)
219223

vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,8 @@ def prepare(
170170
self._apply_router_weight_on_input(
171171
a1, topk_weights, topk_ids, apply_router_weight_on_input
172172
)
173+
if not self.use_dp:
174+
return a1, None, None, topk_ids, topk_weights
173175

174176
a1q, a1q_scale = moe_kernel_quantize_input(
175177
a1,
@@ -179,14 +181,13 @@ def prepare(
179181
quant_config.block_shape,
180182
is_fp4_scale_swizzled=not self.use_dp,
181183
)
182-
if self.use_dp:
183-
topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
184-
[topk_weights, topk_ids, a1q, a1q_scale],
185-
dim=0,
186-
sizes=get_local_sizes(),
187-
)
188-
if quant_config.quant_dtype == "nvfp4":
189-
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
184+
topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
185+
[topk_weights, topk_ids, a1q, a1q_scale],
186+
dim=0,
187+
sizes=get_local_sizes(),
188+
)
189+
if quant_config.quant_dtype == "nvfp4":
190+
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
190191

191192
return a1q, a1q_scale, None, topk_ids, topk_weights
192193

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1769,29 +1769,6 @@ def apply(
17691769
expert_map=expert_map,
17701770
apply_router_weight_on_input=apply_router_weight_on_input,
17711771
)
1772-
elif (
1773-
self.allow_flashinfer
1774-
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
1775-
):
1776-
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
1777-
flashinfer_cutlass_moe_fp4,
1778-
)
1779-
1780-
assert self.moe_quant_config is not None
1781-
1782-
return flashinfer_cutlass_moe_fp4(
1783-
hidden_states=x,
1784-
w1=layer.w13_weight,
1785-
w2=layer.w2_weight,
1786-
topk_weights=topk_weights,
1787-
topk_ids=topk_ids,
1788-
quant_config=self.moe_quant_config,
1789-
inplace=False,
1790-
activation=activation,
1791-
global_num_experts=global_num_experts,
1792-
expert_map=expert_map,
1793-
apply_router_weight_on_input=apply_router_weight_on_input,
1794-
)
17951772
else:
17961773
# If no modular kernel is provided, use cutlass_moe_fp4 for TP case
17971774
# only (no EP).

vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def select_nvfp4_gemm_impl(
7979
ep_size=moe.moe_parallel_config.ep_size,
8080
tp_rank=moe.moe_parallel_config.tp_rank,
8181
tp_size=moe.moe_parallel_config.tp_size,
82+
use_dp=moe.moe_parallel_config.dp_size > 1,
8283
)
8384

8485
# native cutlass experts currently don't support DP; TP case won't call this

0 commit comments

Comments (0)