Merged
vllm/model_executor/layers/fused_moe/cutlass_moe.py (4 additions & 1 deletion)
@@ -322,7 +322,7 @@ def cutlass_moe_fp8(
     topk_ids: torch.Tensor,
     w1_scale: torch.Tensor,
     w2_scale: torch.Tensor,
-    per_act_token: bool,
+    per_act_token: Optional[bool] = None,
Contributor

medium

Consider adding a docstring to explain the purpose of the per_act_token parameter, especially now that it's optional. This will help users understand when and how to use this parameter.

    per_act_token: Optional[bool] = None,  """Whether to use a different scale for each token."""
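
Worth noting: a string literal placed inline in a parameter list, as in the suggestion above, is not a valid Python docstring (it is a syntax error). A minimal sketch of applying the suggestion through the function docstring instead — the signature is abbreviated to the parameters visible in this diff, and the docstring wording is illustrative:

    from typing import Optional

    import torch

    def cutlass_moe_fp8(
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        per_act_token: Optional[bool] = None,
        activation: str = "silu",
        a1_scale: Optional[torch.Tensor] = None,
        a2_scale: Optional[torch.Tensor] = None,
    ):
        """Abbreviated signature; only the parameters shown in this diff.

        - per_act_token (Optional[bool]): Whether to use a separate scale for
          each token. If None, it is inferred from a1_scale / a2_scale:
          per-token when the provided scale has more than one element.
        """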

     activation: str = "silu",
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
@@ -366,6 +366,9 @@ def cutlass_moe_fp8(
     Returns:
     - torch.Tensor: The fp16 output tensor after applying the MoE layer.
     """
+    if per_act_token is None:
+        per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
+            a2_scale.numel() != 1 if a2_scale is not None else False)
Comment on lines +369 to +371
Contributor

medium

The nested ternary expression can be difficult to read. Refactoring this into a more explicit if/elif/else block would improve code clarity and maintainability.

    if per_act_token is None:
        if a1_scale is not None:
            per_act_token = a1_scale.numel() != 1
        elif a2_scale is not None:
            per_act_token = a2_scale.numel() != 1
        else:
            per_act_token = False
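
For a quick sanity check that the suggested refactor matches the merged one-liner, a self-contained sketch; the helper name infer_per_act_token is hypothetical, introduced only for this demonstration:

    from typing import Optional

    import torch

    def infer_per_act_token(a1_scale: Optional[torch.Tensor],
                            a2_scale: Optional[torch.Tensor]) -> bool:
        # Mirrors the merged default: per-token scaling whenever a provided
        # activation scale carries more than one element.
        if a1_scale is not None:
            return a1_scale.numel() != 1
        if a2_scale is not None:
            return a2_scale.numel() != 1
        return False

    assert infer_per_act_token(torch.ones(8), None) is True   # per-token a1_scale
    assert infer_per_act_token(None, torch.ones(1)) is False  # per-tensor a2_scale
    assert infer_per_act_token(None, None) is False           # no scales given

Both forms resolve per_act_token the same way: the first activation scale that is provided decides, and the default falls back to per-tensor (False) when neither scale is given.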

     per_out_ch = w1_scale.numel() != w1_q.size(0)
 
     num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(