2 changes: 1 addition & 1 deletion vllm/model_executor/layers/fused_moe/layer.py

@@ -1178,7 +1178,7 @@ class FusedMoE(CustomOp):
         hidden_size: Input hidden state size of the transformer
         intermediate_size: Intermediate size of the experts
         params_dtype: Data type for the parameters.
-        reduce_results: Whether to all all_reduce on the output of the layer
+        reduce_results: Whether to all_reduce on the output of the layer
         renormalize: Whether to renormalize the logits in the fused_moe kernel
         quant_config: Quantization configure.
         enable_eplb: Whether to enable expert parallelism load balancer.
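The flag documented here is what gates the layer-level all-reduce. A minimal sketch of the pattern outside vLLM (plain torch.distributed, process-group setup assumed when running multi-rank): under tensor parallelism each rank holds a partial sum of the MoE output, and reduce_results decides whether the layer finishes the sum itself.

import torch
import torch.distributed as dist

def finish_layer_output(partial_out: torch.Tensor, reduce_results: bool) -> torch.Tensor:
    # Sum the per-rank partial expert outputs across the TP group; skipped
    # when the caller wants to defer or fuse the reduction elsewhere.
    if reduce_results and dist.is_initialized() and dist.get_world_size() > 1:
        dist.all_reduce(partial_out)
    return partial_out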
22 changes: 15 additions & 7 deletions vllm/model_executor/layers/fused_moe/shared_fused_moe.py

@@ -3,7 +3,10 @@
 
 import torch
 
-from vllm.distributed import tensor_model_parallel_all_reduce
+from vllm.distributed import (
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
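The newly imported helper is used later in forward(). A minimal sketch of the intended check, under the assumption (not stated in the diff) that the layer's own cached tp_size can differ from the size of the live tensor-parallel group once expert/data parallelism reshapes the MoE layout:

import torch
from vllm.distributed import (
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)

def maybe_reduce_shared(
    shared_out: torch.Tensor, reduce_results: bool, must_reduce: bool
) -> torch.Tensor:
    # Read the TP group size at call time rather than a value cached at
    # construction, then sum the shared-expert output across TP ranks.
    if reduce_results and must_reduce and get_tensor_model_parallel_world_size() > 1:
        shared_out = tensor_model_parallel_all_reduce(shared_out)
    return shared_out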


@@ -25,16 +28,13 @@ def __init__(
         super().__init__(**kwargs)
         self._shared_experts = shared_experts
 
-        # Disable shared expert overlap if EP is disabled or we are not using
+        # Disable shared expert overlap if we are not using
         # flashinfer + DP since there is nothing to be gained in this case.
         # Disabling the overlap optimization also prevents the shared experts
         # from being hidden from torch.compile.
         self.use_overlapped = (
             use_overlapped
-            and not (
-                self.use_ep
-                or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
-            )
+            and not (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
             and self._shared_experts is not None
         )
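Restated as a standalone predicate for readability (argument names are illustrative, not part of the vLLM API): after this change, overlap is disabled only for the flashinfer-cutlass + DP combination, and no longer whenever expert parallelism is enabled.

def should_overlap_shared_experts(
    use_overlapped: bool,
    use_flashinfer_cutlass_kernels: bool,
    dp_size: int,
    has_shared_experts: bool,
) -> bool:
    # Overlap only pays off outside the flashinfer-cutlass + data-parallel
    # path, and only when there actually are shared experts to overlap.
    return (
        use_overlapped
        and not (use_flashinfer_cutlass_kernels and dp_size > 1)
        and has_shared_experts
    )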

@@ -65,7 +65,7 @@ def forward(
                 # should have been created with reduce_results=False.
                 if (
                     self.reduce_results
-                    and self.tp_size > 1
+                    and get_tensor_model_parallel_world_size() > 1
                     and self.must_reduce_shared_expert_outputs()
                 ):
                     shared_out = tensor_model_parallel_all_reduce(shared_out)
@@ -81,4 +81,12 @@
                 hidden_states=hidden_states,
                 router_logits=router_logits,
             )
+            # ensure early TP reduction of shared expert outputs when required
+            if (
+                shared_out is not None
+                and self.reduce_results
+                and get_tensor_model_parallel_world_size() > 1
+                and self.must_reduce_shared_expert_outputs()
+            ):
+                shared_out = tensor_model_parallel_all_reduce(shared_out)
         return shared_out, fused_out
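Caller-side sketch (model code is assumed, not part of this diff): forward() returns a (shared_out, fused_out) tuple, and the early reduction added above means the model can combine the two without issuing its own TP all-reduce for the shared branch when the layer was built with reduce_results=True.

import torch

def combine_moe_outputs(
    shared_out: torch.Tensor | None, fused_out: torch.Tensor
) -> torch.Tensor:
    # Both tensors are already fully reduced across TP ranks at this point.
    return fused_out if shared_out is None else shared_out + fused_out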