Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
9cf5b21
Simplified version of TP Attn + EP MoE perf bug fix
tlrmchlsmth Aug 31, 2025
46e4ad7
turn off chunking
tlrmchlsmth Sep 1, 2025
90a2a32
simplify tp padding
tlrmchlsmth Sep 2, 2025
85a1ff0
replicated linear shared expert
tlrmchlsmth Sep 2, 2025
08bace6
fixup moe chunking
tlrmchlsmth Sep 2, 2025
8174f70
cleanup and plumbing
tlrmchlsmth Sep 3, 2025
02fbb25
cleanup and comments
tlrmchlsmth Sep 3, 2025
76372f2
review comment
tlrmchlsmth Sep 3, 2025
3b6b3c7
Merge branch 'main' into tp_attn_fix_simple
tlrmchlsmth Sep 3, 2025
249eb5f
comments
tlrmchlsmth Sep 3, 2025
3887a51
fixes for DS decoder layer constructor changes
tlrmchlsmth Sep 3, 2025
31e0c81
fixup
tlrmchlsmth Sep 3, 2025
c8fb93c
Merge branch 'main' into tp_attn_fix_simple
tlrmchlsmth Sep 3, 2025
c574cab
Merge branch 'main' into tp_attn_fix_simple
tlrmchlsmth Sep 3, 2025
090ae53
update chunking
tlrmchlsmth Sep 3, 2025
e9ee2c5
fixup
tlrmchlsmth Sep 3, 2025
56ad765
assert, revert
tlrmchlsmth Sep 3, 2025
4134e22
wrap sp chunking in a custom op
tlrmchlsmth Sep 4, 2025
8b9dbc7
fixup
tlrmchlsmth Sep 4, 2025
900b951
fixup
tlrmchlsmth Sep 4, 2025
f3754a5
fixup
tlrmchlsmth Sep 4, 2025
898a1c4
rm contiguous
tlrmchlsmth Sep 4, 2025
e6a5908
comment + rename for moe chunking changes
tlrmchlsmth Sep 4, 2025
f6455a9
Merge branch 'main' into tp_attn_fix_simple
tlrmchlsmth Sep 4, 2025
20a3451
fixup comments
tlrmchlsmth Sep 4, 2025
ba075dd
fixup
tlrmchlsmth Sep 5, 2025
cebbc89
Merge branch 'main' into tp_attn_fix_simple
tlrmchlsmth Sep 5, 2025
5d8aaa3
Merge branch 'main' into tp_attn_fix_simple
tlrmchlsmth Sep 5, 2025
a7999c7
Merge branch 'main' into tp_attn_fix_simple
tlrmchlsmth Sep 7, 2025
d92eae0
handle change from #23024
tlrmchlsmth Sep 7, 2025
247ce5b
Merge branch 'main' into tp_attn_fix_simple
tlrmchlsmth Sep 8, 2025
cdd4a39
fixup
tlrmchlsmth Sep 8, 2025
f6e5905
Merge branch 'main' into tp_attn_fix_simple
tlrmchlsmth Sep 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx,
from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx,
round_up)

if current_platform.is_cuda_alike():
Expand Down Expand Up @@ -772,6 +772,7 @@ def __init__(
enable_eplb: bool = False,
num_redundant_experts: int = 0,
has_bias: bool = False,
is_sequence_parallel=False,
):
super().__init__()
if params_dtype is None:
Expand All @@ -783,6 +784,10 @@ def __init__(
dp_size_ = (dp_size
if dp_size is not None else get_dp_group().world_size)

self.is_sequence_parallel = is_sequence_parallel
if self.is_sequence_parallel:
self.sp_size = tp_size_

vllm_config = get_current_vllm_config()
self.moe_parallel_config: FusedMoEParallelConfig = (
FusedMoEParallelConfig.make(
Expand Down Expand Up @@ -1643,6 +1648,9 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
# flashinfer_cutlass_kernels can handle: optional DP + TP/EP
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
if self.is_sequence_parallel:
max_tokens_across_dp = cdiv(max_tokens_across_dp, self.sp_size)

num_tokens = full_hidden_states.size(0)
for chunk_idx, chunk_start_ in enumerate(
range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank)):
Expand Down
109 changes: 85 additions & 24 deletions vllm/model_executor/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,14 @@
from torch import nn
from transformers import DeepseekV2Config, DeepseekV3Config

import vllm.envs as envs
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
get_current_vllm_config)
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
from vllm.distributed import (get_ep_group, get_pp_group,
get_tensor_model_parallel_world_size)
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
Expand Down Expand Up @@ -71,20 +73,34 @@ def __init__(
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
is_sequence_parallel=False,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj")
if is_sequence_parallel:
self.gate_up_proj = MergedReplicatedLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = ReplicatedLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj")

else:
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj")
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
Expand All @@ -99,15 +115,35 @@ def forward(self, x):

class DeepseekV2MoE(nn.Module):

# Chunk x along the num_tokens axis for sequence parallelism
def sp_chunk(self, x: torch.Tensor) -> torch.Tensor:
    """Return this tensor-parallel rank's slice of ``x`` along dim 0.

    ``x`` is zero-padded at the end of dim 0 up to the next multiple of
    ``self.tp_size`` so that it splits into ``self.tp_size`` equally
    sized chunks; the chunk at index ``self.tp_rank`` is returned.

    Args:
        x: Tensor whose first dimension is chunked. Trailing dimensions
            are left untouched.

    Returns:
        A contiguous tensor holding rows
        ``[tp_rank * chunk, (tp_rank + 1) * chunk)`` of the padded input,
        where ``chunk = ceil(x.size(0) / tp_size)``. Rows that come from
        the padding are all zeros.
    """
    # all_gather needs the sequence length to be divisible by tp_size
    pad_len = -x.size(0) % self.tp_size
    if pad_len:
        # F.pad takes (left, right) pairs starting from the LAST dim, so
        # the pair for dim 0 goes at the end of the tuple.
        pad_spec = (0, 0) * (x.dim() - 1) + (0, pad_len)
        x = torch.nn.functional.pad(x, pad_spec)

    chunk = x.size(0) // self.tp_size
    return x.narrow(0, self.tp_rank * chunk, chunk).contiguous()

def __init__(
self,
config: Union[DeepseekV2Config, DeepseekV3Config],
parallel_config: ParallelConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()

self.routed_scaling_factor = config.routed_scaling_factor

self.ep_group = get_ep_group().device_group
Expand All @@ -116,6 +152,15 @@ def __init__(
self.n_routed_experts: int = config.n_routed_experts
self.n_shared_experts: int = config.n_shared_experts

# If using expert parallel, ensure the input to the experts is
# SP to avoid duplicate work.
# Not needed for pplx-kernels as it can handle duplicate input tokens.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@abcdabcd987 and @nandor could you double-check me here: Can pplx handle replicated input tokens in the TP attn + EP MoE case?

self.is_sequence_parallel = (envs.VLLM_ALL2ALL_BACKEND
Copy link
Collaborator

@robertgshaw2-redhat robertgshaw2-redhat Sep 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should call this `use_sequence_parallel_mlp`, since we use sequence parallelism for just the MLP layer here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I like this because the MLP being sequence parallel is kind of a side effect. And we need to pass it into the fused_moe layer for the chunking.

I'm not a fan of the `sequence_parallel` name though, so I'm definitely open to suggestions.

in ("deepep_high_throughput",
"deepep_low_latency")
and parallel_config.enable_expert_parallel
and self.tp_size > 1)

if config.hidden_act != "silu":
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
"Only silu is supported for now.")
Expand All @@ -132,9 +177,8 @@ def __init__(
self.gate.e_score_correction_bias = None

# Load balancing settings.
vllm_config = get_current_vllm_config()
eplb_config = vllm_config.parallel_config.eplb_config
self.enable_eplb = enable_eplb
eplb_config = parallel_config.eplb_config
self.enable_eplb = parallel_config.enable_eplb

self.n_redundant_experts = eplb_config.num_redundant_experts
self.n_logical_experts = self.n_routed_experts
Expand Down Expand Up @@ -162,7 +206,9 @@ def __init__(
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts)
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
)

if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
Expand All @@ -172,6 +218,7 @@ def __init__(
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
is_sequence_parallel=self.is_sequence_parallel,
reduce_results=self.experts.must_reduce_shared_expert_outputs(
),
prefix=f"{prefix}.shared_experts",
Expand All @@ -180,8 +227,17 @@ def __init__(
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)

# Chunk the hidden states so they aren't replicated across TP ranks.
# This avoids duplicate computation in self.experts.
# TODO: We can replace the all_reduce at the end of attn with a
# reduce_scatter instead of chunking here.
if self.is_sequence_parallel:
hidden_states = self.sp_chunk(hidden_states)

if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)

# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)

Expand All @@ -194,6 +250,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# See DeepseekV2DecoderLayer for more details.
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)

if shared_output is not None:
if hidden_states.dtype != torch.float16:
final_hidden_states = final_hidden_states + shared_output
Expand All @@ -203,7 +260,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)

if self.tp_size > 1:
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0).contiguous()
final_hidden_states = final_hidden_states[:num_tokens]
elif self.tp_size > 1:
final_hidden_states = (
self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states))
Expand Down Expand Up @@ -541,10 +602,10 @@ def __init__(
self,
config: Union[DeepseekV2Config, DeepseekV3Config],
prefix: str,
parallel_config: ParallelConfig,
model_config: ModelConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
enable_eplb: bool = False,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
Expand Down Expand Up @@ -583,9 +644,9 @@ def __init__(
and layer_idx % config.moe_layer_freq == 0):
self.mlp = DeepseekV2MoE(
config=config,
parallel_config=parallel_config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
enable_eplb=enable_eplb,
)
else:
self.mlp = DeepseekV2MLP(
Expand Down Expand Up @@ -658,7 +719,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
enable_eplb = vllm_config.parallel_config.enable_eplb
parallel_config = vllm_config.parallel_config
self.config = config

self.vocab_size = config.vocab_size
Expand All @@ -677,10 +738,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
lambda prefix: DeepseekV2DecoderLayer(
config,
prefix,
parallel_config=parallel_config,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
enable_eplb=enable_eplb,
),
prefix=f"{prefix}.layers")

Expand Down