
Commit a4a7418

yewentao256 authored and devpatelio committed
[Feature] Enable TP + EP shared_experts overlap with router, 3.7% E2E performance improvement (vllm-project#28164)
Signed-off-by: yewentao256 <[email protected]>
1 parent 4e88b99 commit a4a7418

File tree: 2 files changed (+16 −8 lines)


vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 1 addition & 1 deletion
@@ -1178,7 +1178,7 @@ class FusedMoE(CustomOp):
         hidden_size: Input hidden state size of the transformer
         intermediate_size: Intermediate size of the experts
         params_dtype: Data type for the parameters.
-        reduce_results: Whether to all all_reduce on the output of the layer
+        reduce_results: Whether to all_reduce on the output of the layer
         renormalize: Whether to renormalize the logits in the fused_moe kernel
         quant_config: Quantization configure.
         enable_eplb: Whether to enable expert parallelism load balancer.
vllm/model_executor/layers/fused_moe/shared_fused_moe.py

Lines changed: 15 additions & 7 deletions
@@ -3,7 +3,10 @@
 
 import torch
 
-from vllm.distributed import tensor_model_parallel_all_reduce
+from vllm.distributed import (
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 
 
@@ -25,16 +28,13 @@ def __init__(
         super().__init__(**kwargs)
         self._shared_experts = shared_experts
 
-        # Disable shared expert overlap if EP is disabled or we are not using
+        # Disable shared expert overlap if we are not using
         # flashinfer + DP since there is nothing to be gained in this case.
         # Disabling the overlap optimization also prevents the shared experts
         # from being hidden from torch.compile.
         self.use_overlapped = (
             use_overlapped
-            and not (
-                self.use_ep
-                or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
-            )
+            and not (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
             and self._shared_experts is not None
         )
 
@@ -65,7 +65,7 @@ def forward(
                 # should have been created with reduce_results=False.
                 if (
                     self.reduce_results
-                    and self.tp_size > 1
+                    and get_tensor_model_parallel_world_size() > 1
                     and self.must_reduce_shared_expert_outputs()
                 ):
                     shared_out = tensor_model_parallel_all_reduce(shared_out)
@@ -81,4 +81,12 @@ def forward(
                 hidden_states=hidden_states,
                 router_logits=router_logits,
             )
+            # ensure early TP reduction of shared expert outputs when required
+            if (
+                shared_out is not None
+                and self.reduce_results
+                and get_tensor_model_parallel_world_size() > 1
+                and self.must_reduce_shared_expert_outputs()
+            ):
+                shared_out = tensor_model_parallel_all_reduce(shared_out)
         return shared_out, fused_out
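For context, the effect of the new guard can be read as: after either the non-overlapped or the overlapped forward path, the shared-expert output is all-reduced across the tensor-parallel group only when reduction is requested, the live TP world size is greater than 1, and the layer reports that shared-expert outputs must be reduced. The snippet below is a minimal, self-contained sketch of that gating logic under those assumptions; maybe_reduce_shared_out, must_reduce, and all_reduce_fn are hypothetical stand-ins, not vllm APIs.

from typing import Callable, Optional

import torch


def maybe_reduce_shared_out(
    shared_out: Optional[torch.Tensor],
    reduce_results: bool,
    tp_world_size: int,
    must_reduce: bool,
    all_reduce_fn: Callable[[torch.Tensor], torch.Tensor],
) -> Optional[torch.Tensor]:
    # Mirror the guard added after the overlapped path: reduce only when
    # there is an output, reduction is requested, TP spans more than one
    # rank, and the layer says shared-expert outputs must be reduced.
    if (
        shared_out is not None
        and reduce_results
        and tp_world_size > 1
        and must_reduce
    ):
        shared_out = all_reduce_fn(shared_out)
    return shared_out


# Single-process usage example: with tp_world_size == 1 the tensor is
# returned untouched, so no collective is issued.
out = maybe_reduce_shared_out(
    shared_out=torch.ones(4),
    reduce_results=True,
    tp_world_size=1,
    must_reduce=True,
    all_reduce_fn=lambda t: t,  # stand-in for tensor_model_parallel_all_reduce
)

In the actual layer, get_tensor_model_parallel_world_size() plays the role of tp_world_size and tensor_model_parallel_all_reduce the role of all_reduce_fn, as shown in the diff above.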
