Merged
Changes from 3 commits
@@ -56,6 +56,7 @@ def __init__(
ep_size: int = 1,
tp_rank: int = 0,
tp_size: int = 1,
+ use_dp: bool = False,
):
super().__init__(quant_config)
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), (
@@ -67,6 +68,7 @@ def __init__(
self.tp_rank = tp_rank
self.tp_size = tp_size
self.out_dtype = out_dtype
+ self.use_dp = use_dp

@property
def activation_formats(
@@ -117,7 +119,8 @@ def workspace_shapes(
"""
workspace1 = (M, K)
workspace2 = (0,)
- output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" else K)
+ # For TP, the quantization is fused with fused_moe call.
+ output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" and self.use_dp else K)
# The workspace is determined by `aq`, since it comes after any
# potential communication op and is involved in the expert computation.
return (workspace1, workspace2, output_shape)
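To spell out the shape reasoning behind the new output_shape line, here is a minimal sketch of my reading of it. `expected_output_cols` is an illustrative name, not a vLLM function, and the factor of 2 is an assumption based on NVFP4 packing two 4-bit values per byte.

```python
# Illustrative only, not vLLM code. My reading: on the DP path the activations
# reaching the experts ("aq") are already NVFP4-quantized, so their last dim K
# counts packed bytes (two FP4 values per byte) and the unquantized output
# needs K * 2 columns; on the TP path the quantization is fused into the
# FlashInfer call, so K is already the full hidden size.
def expected_output_cols(K, quant_dtype, use_dp):
    if quant_dtype == "nvfp4" and use_dp:
        return K * 2  # K counts packed FP4 bytes; the bf16 output is unpacked
    return K  # activations were never packed here, K == hidden size


assert expected_output_cols(K=3584, quant_dtype="nvfp4", use_dp=True) == 7168
assert expected_output_cols(K=7168, quant_dtype="nvfp4", use_dp=False) == 7168
```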
@@ -214,6 +217,7 @@ def flashinfer_cutlass_moe_fp4(
FlashInferExperts(
out_dtype=hidden_states.dtype,
quant_config=quant_config,
+ use_dp=False,
Collaborator: Why are we always setting use_dp=False? Doesn't flashinfer_cutlass_moe_fp4 also support DP?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

flashinfer_cutlass_moe_fp4(dead code) is removed in this PR. Previously flashinfer_cutlass_moe_fp4 is meant for TP case only. If DP, the fused expert is assemble elsewhere.

Collaborator: How is TP + FlashInfer CUTLASS handled now? Could we remove this method altogether, or do we see any cases where it would be used?

Contributor (author): https://github.com/vllm-project/vllm/pull/27223/files#diff-5bb9585da825481e1ae1534657a703f846325c7dd72cfddb0c41f878db33d78aR82 differentiates TP vs. DP, but in either case a fused MoE expert is assembled. I agree that flashinfer_cutlass_moe_fp4 should be removed, but compressed-tensors TP could still depend on it, so its removal is out of scope for this PR.

),
)

@@ -170,6 +170,8 @@ def prepare(
self._apply_router_weight_on_input(
a1, topk_weights, topk_ids, apply_router_weight_on_input
)
+ if not self.use_dp:
+ return a1, None, None, topk_ids, topk_weights

a1q, a1q_scale = moe_kernel_quantize_input(
a1,
@@ -179,14 +181,13 @@
quant_config.block_shape,
is_fp4_scale_swizzled=not self.use_dp,
)
- if self.use_dp:
- topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
- [topk_weights, topk_ids, a1q, a1q_scale],
- dim=0,
- sizes=get_local_sizes(),
- )
- if quant_config.quant_dtype == "nvfp4":
- a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
+ topk_weights, topk_ids, a1q, a1q_scale = get_dp_group().all_gatherv(
+ [topk_weights, topk_ids, a1q, a1q_scale],
+ dim=0,
+ sizes=get_local_sizes(),
+ )
+ if quant_config.quant_dtype == "nvfp4":
+ a1q_scale = nvfp4_block_scale_interleave(a1q_scale)

return a1q, a1q_scale, None, topk_ids, topk_weights
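As a toy illustration of the resulting control flow: pure-TP ranks hand back the raw activations and let the FlashInfer call fuse the quantization, while DP ranks quantize locally and then gather every rank's tokens. The helpers below are trivial stand-ins, not the vLLM APIs used in the diff (moe_kernel_quantize_input, get_dp_group().all_gatherv, nvfp4_block_scale_interleave).

```python
# Toy stand-ins for the vLLM helpers referenced in the diff above.
def fake_quantize(a):
    # Stand-in for moe_kernel_quantize_input: scale to a small int range.
    scale = max(abs(x) for x in a) or 1.0
    return [round(x / scale * 7) for x in a], scale


def fake_all_gather(chunks):
    # Stand-in for get_dp_group().all_gatherv along dim 0.
    return [x for chunk in chunks for x in chunk]


def prepare(a, use_dp, peer_inputs):
    if not use_dp:
        # TP-only path: skip quantization here; it is fused into the MoE kernel.
        return a, None
    # DP path: quantize this rank's tokens, then gather tokens from all ranks.
    aq, scale = fake_quantize(a)
    peer_q = [fake_quantize(p)[0] for p in peer_inputs]
    return fake_all_gather([aq] + peer_q), scale


print(prepare([0.5, -1.0], use_dp=False, peer_inputs=[]))            # TP: raw activations
print(prepare([0.5, -1.0], use_dp=True, peer_inputs=[[2.0, 4.0]]))   # DP: gathered, quantized
```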

23 changes: 0 additions & 23 deletions vllm/model_executor/layers/quantization/modelopt.py
@@ -1769,29 +1769,6 @@ def apply(
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
- elif (
- self.allow_flashinfer
- and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
- ):
Contributor: Will this break PP mode? When running DeepSeek with PP, it will go to the last clause, which uses cutlass_moe_fp4 and will break on SM120, while on SGLang flashinfer_cutlass works with SM120 PP.

Contributor: Well, #27123 added this back and solved my issue.

- from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
- flashinfer_cutlass_moe_fp4,
- )
-
- assert self.moe_quant_config is not None
-
- return flashinfer_cutlass_moe_fp4(
Member: Can you update for compressed-tensors too?

Collaborator: @wenscarl @mgoin is my understanding correct that, going forward, directly calling into FlashInfer is going to be deprecated in favour of modular kernels, for all cases (i.e., regardless of EP/TP/DP choices)?

Collaborator:

> @wenscarl @mgoin is my understanding correct that, going forward, directly calling into FlashInfer is going to be deprecated in favour of modular kernels, for all cases (i.e., regardless of EP/TP/DP choices)?

There's no plan to force all MoE kernels to be called via modular kernels.

Collaborator: But by deleting this elif clause (and, per @mgoin's suggestion, applying this change to compressed-tensors), doesn't it force the FlashInfer CUTLASS implementation to go through the modular kernels?

I'm just trying to understand whether this is the plan for all cases that use FlashInfer, regardless of distributed strategies or of whether self.flashinfer_moe_backend is FlashinferMoeBackend.TENSORRT_LLM or FlashinferMoeBackend.CUTLASS.

Collaborator: I don't see any reason to use the modular kernels for cases that aren't using some kind of all2all communication. In this particular case I think @wenscarl and @leejnau figured out that this was dead code because the CUTLASS case always created a modular kernel. I'm not sure if the same holds true for compressed_tensors.

Contributor (author): @bnellnm this PR is updated with an additional quant_dtype: nvfp4_skip_quantization.

Contributor (author):

> I'm just trying to understand if this is the plan for all cases that use FlashInfer

I vote for that, since FlashInfer CUTLASS MoE is at least a better option than the normal cutlass_moe; the TRTLLM MoE can sometimes even win.

- hidden_states=x,
- w1=layer.w13_weight,
- w2=layer.w2_weight,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- quant_config=self.moe_quant_config,
- inplace=False,
- activation=activation,
- global_num_experts=global_num_experts,
- expert_map=expert_map,
- apply_router_weight_on_input=apply_router_weight_on_input,
- )
else:
# If no modular kernel is provided, use cutlass_moe_fp4 for TP case
# only (no EP).
@@ -79,6 +79,7 @@ def select_nvfp4_gemm_impl(
ep_size=moe.moe_parallel_config.ep_size,
tp_rank=moe.moe_parallel_config.tp_rank,
tp_size=moe.moe_parallel_config.tp_size,
+ use_dp=moe.moe_parallel_config.dp_size > 1,
)

# native cutlass experts currently don't support DP; TP case won't call this
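A tiny sketch of what the new argument encodes, with a made-up config class standing in for moe.moe_parallel_config (only dp_size matters for the choice):

```python
from dataclasses import dataclass


# Hypothetical stand-in for moe.moe_parallel_config.
@dataclass
class FakeParallelConfig:
    dp_size: int
    tp_size: int


def uses_dp_prepare(cfg):
    # Mirrors the line above: the DP prepare path (quantize + all_gather) is
    # taken only when more than one data-parallel rank exists; pure TP keeps
    # the quantization fused inside the FlashInfer CUTLASS call.
    return cfg.dp_size > 1


assert uses_dp_prepare(FakeParallelConfig(dp_size=2, tp_size=1))
assert not uses_dp_prepare(FakeParallelConfig(dp_size=1, tp_size=4))
```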