 
 import torch
 
-from vllm.distributed import tensor_model_parallel_all_reduce
+from vllm.distributed import (
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 
 
@@ -25,16 +28,13 @@ def __init__(
         super().__init__(**kwargs)
         self._shared_experts = shared_experts
 
-        # Disable shared expert overlap if EP is disabled or we are not using
+        # Disable shared expert overlap if we are not using
         # flashinfer + DP since there is nothing to be gained in this case.
         # Disabling the overlap optimization also prevents the shared experts
         # from being hidden from torch.compile.
         self.use_overlapped = (
             use_overlapped
-            and not (
-                self.use_ep
-                or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
-            )
+            and not (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
             and self._shared_experts is not None
         )
 
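For readability, the boolean assigned to self.use_overlapped above can be restated on its own: the EP check is dropped, leaving the flashinfer-cutlass-with-DP condition and the presence of shared experts. A minimal sketch assuming only the attribute names visible in this hunk; the function name is made up for illustration:

def should_overlap(
    use_overlapped: bool,
    use_flashinfer_cutlass_kernels: bool,
    dp_size: int,
    shared_experts: object | None,
) -> bool:
    # Mirrors self.use_overlapped after this change: requested overlap,
    # minus the flashinfer cutlass + data-parallel case, and only when
    # shared experts actually exist.
    flashinfer_dp = use_flashinfer_cutlass_kernels and dp_size > 1
    return use_overlapped and not flashinfer_dp and shared_experts is not None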
@@ -65,7 +65,7 @@ def forward(
                 # should have been created with reduce_results=False.
                 if (
                     self.reduce_results
-                    and self.tp_size > 1
+                    and get_tensor_model_parallel_world_size() > 1
                     and self.must_reduce_shared_expert_outputs()
                 ):
                     shared_out = tensor_model_parallel_all_reduce(shared_out)
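This hunk swaps the cached self.tp_size for a live query of the tensor-parallel group, so the reduction decision always reflects the TP world size that tensor_model_parallel_all_reduce will actually run over. A minimal standalone sketch of the same pattern, assuming the distributed state is already initialized; the helper name is made up for illustration:

import torch

from vllm.distributed import (
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)


def maybe_reduce_shared_out(
    shared_out: torch.Tensor,
    reduce_results: bool,
    must_reduce: bool,
) -> torch.Tensor:
    # Restates the condition in the hunk above: all-reduce across the TP
    # group only when reduction is requested, the group has more than one
    # rank, and the layer reports the shared expert output must be reduced.
    if reduce_results and get_tensor_model_parallel_world_size() > 1 and must_reduce:
        shared_out = tensor_model_parallel_all_reduce(shared_out)
    return shared_out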
@@ -81,4 +81,12 @@ def forward(
                 hidden_states=hidden_states,
                 router_logits=router_logits,
             )
+            # ensure early TP reduction of shared expert outputs when required
+            if (
+                shared_out is not None
+                and self.reduce_results
+                and get_tensor_model_parallel_world_size() > 1
+                and self.must_reduce_shared_expert_outputs()
+            ):
+                shared_out = tensor_model_parallel_all_reduce(shared_out)
         return shared_out, fused_out
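For context on how the returned tuple is consumed: callers of this layer typically add the shared expert output back onto the routed expert output. An illustrative caller-side sketch, not taken from this diff; self.experts and the surrounding module are assumptions:

# Inside a hypothetical MoE block's forward(); self.experts is assumed to be
# this SharedFusedMoE layer.
shared_out, fused_out = self.experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
)
final_hidden_states = fused_out
if shared_out is not None:
    # Combine shared and routed expert contributions; with the reductions
    # above, shared_out is already consistent across the TP group whenever
    # the layer requires it.
    final_hidden_states = final_hidden_states + shared_out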