2 changes: 1 addition & 1 deletion vllm/model_executor/layers/fused_moe/layer.py
@@ -1178,7 +1178,7 @@ class FusedMoE(CustomOp):
         hidden_size: Input hidden state size of the transformer
         intermediate_size: Intermediate size of the experts
         params_dtype: Data type for the parameters.
-        reduce_results: Whether to all all_reduce on the output of the layer
+        reduce_results: Whether to all_reduce on the output of the layer
         renormalize: Whether to renormalize the logits in the fused_moe kernel
         quant_config: Quantization configure.
         enable_eplb: Whether to enable expert parallelism load balancer.
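Reviewer note: `reduce_results` in the docstring above refers to all-reducing (summing) the layer's output across the tensor-parallel group. A minimal sketch of that semantics using plain `torch`; the helper name and list-of-partials framing are illustrative, not the vLLM API:

```python
import torch


def reduce_layer_output(
    partial_outputs: list[torch.Tensor], reduce_results: bool
) -> list[torch.Tensor]:
    """Simulate an all_reduce across TP ranks: every rank gets the sum."""
    if not reduce_results:
        # Each rank keeps only its partial output.
        return partial_outputs
    total = torch.stack(partial_outputs).sum(dim=0)
    return [total.clone() for _ in partial_outputs]


# Two TP ranks, each holding a partial MoE output for the same tokens.
partials = [torch.ones(2, 4), 2 * torch.ones(2, 4)]
reduced = reduce_layer_output(partials, reduce_results=True)
assert torch.equal(reduced[0], 3 * torch.ones(2, 4))
```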
15 changes: 10 additions & 5 deletions vllm/model_executor/layers/fused_moe/shared_fused_moe.py
@@ -25,16 +25,13 @@ def __init__(
         super().__init__(**kwargs)
         self._shared_experts = shared_experts

-        # Disable shared expert overlap if EP is disabled or we are not using
+        # Disable shared expert overlap if we are not using
         # flashinfer + DP since there is nothing to be gained in this case.
         # Disabling the overlap optimization also prevents the shared experts
         # from being hidden from torch.compile.
         self.use_overlapped = (
             use_overlapped
-            and not (
-                self.use_ep
-                or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
-            )
+            and not (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
             and self._shared_experts is not None
         )
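A standalone sketch of how the relaxed condition now resolves; the helper and its plain-value arguments are illustrative stand-ins for the attributes used above:

```python
def resolve_use_overlapped(
    use_overlapped: bool,
    use_flashinfer_cutlass_kernels: bool,
    dp_size: int,
    has_shared_experts: bool,
) -> bool:
    # Overlap is now disabled only when flashinfer cutlass kernels are combined
    # with data parallelism; EP on its own no longer turns it off.
    return (
        use_overlapped
        and not (use_flashinfer_cutlass_kernels and dp_size > 1)
        and has_shared_experts
    )


# EP-only deployment: overlap stays enabled under the new logic.
assert resolve_use_overlapped(True, False, 1, True) is True
# flashinfer cutlass kernels + DP: overlap is still disabled.
assert resolve_use_overlapped(True, True, 2, True) is False
```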

@@ -81,4 +78,12 @@ def forward(
             hidden_states=hidden_states,
             router_logits=router_logits,
         )
+        # ensure early TP reduction of shared expert outputs when required
+        if (
+            shared_out is not None
+            and self.reduce_results
+            and self.tp_size > 1
+            and self.must_reduce_shared_expert_outputs()
+        ):
+            shared_out = tensor_model_parallel_all_reduce(shared_out)
         return shared_out, fused_out
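A sketch of when the early reduction fires; the predicate mirrors the new `if` condition, with the attributes and the result of `must_reduce_shared_expert_outputs()` passed in as plain values (the helper itself is hypothetical):

```python
def should_reduce_shared_out(
    has_shared_out: bool,
    reduce_results: bool,
    tp_size: int,
    must_reduce_shared_expert_outputs: bool,
) -> bool:
    return (
        has_shared_out
        and reduce_results
        and tp_size > 1
        and must_reduce_shared_expert_outputs
    )


# Single-rank TP: shared_out is returned as-is, no collective needed.
assert should_reduce_shared_out(True, True, 1, True) is False
# Multi-rank TP with reduction required: tensor_model_parallel_all_reduce is
# applied to shared_out before it is returned alongside fused_out.
assert should_reduce_shared_out(True, True, 2, True) is True
```

When the predicate holds, the shared-expert output is reduced across the TP group right here rather than relying on a later reduction, per the comment added in the diff.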