diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 272ad3956537..2f88a63665c5 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -35,7 +35,7 @@
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
-from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx,
+from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx,
                         round_up)
 
 if current_platform.is_cuda_alike():
@@ -786,6 +786,7 @@ def __init__(
         enable_eplb: bool = False,
         num_redundant_experts: int = 0,
         has_bias: bool = False,
+        is_sequence_parallel=False,
     ):
         super().__init__()
         if params_dtype is None:
@@ -797,6 +798,10 @@
         dp_size_ = (dp_size
                     if dp_size is not None else get_dp_group().world_size)
 
+        self.is_sequence_parallel = is_sequence_parallel
+        if self.is_sequence_parallel:
+            self.sp_size = tp_size_
+
         vllm_config = get_current_vllm_config()
         self.moe_parallel_config: FusedMoEParallelConfig = (
             FusedMoEParallelConfig.make(
@@ -1699,14 +1704,22 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
 
         ctx = get_forward_context()
         # flashinfer_cutlass_kernels can handle: optional DP + TP/EP
-        max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
+        max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu
         moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
+
+        # If the input to the MoE is sequence parallel then divide by sp_size
+        # to find the maximum number of tokens for any individual dispatcher.
+        if self.is_sequence_parallel:
+            max_tokens_across_dispatchers = cdiv(max_tokens_across_dispatchers,
+                                                 self.sp_size)
+
         num_tokens = full_hidden_states.size(0)
         for chunk_idx, chunk_start_ in enumerate(
-                range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank)):
+                range(0, max_tokens_across_dispatchers,
+                      moe_dp_chunk_size_per_rank)):
             chunk_start = chunk_start_
             chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
-                            max_tokens_across_dp)
+                            max_tokens_across_dispatchers)
             # clamp start and end
             chunk_start = min(chunk_start, num_tokens - 1)
             chunk_end = min(chunk_end, num_tokens)
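
The hunk above only changes the loop bound: when the MoE input is sequence parallel, each dispatcher holds at most ceil(max_tokens_across_dp / sp_size) tokens, so the chunk loop can iterate over that smaller bound. A minimal, self-contained sketch of the arithmetic (illustrative only: chunk_bounds and the example numbers are invented, and cdiv simply mirrors vllm.utils.cdiv):

    def cdiv(a: int, b: int) -> int:
        # Ceiling division, as in vllm.utils.cdiv.
        return -(-a // b)

    def chunk_bounds(max_tokens_across_dp: int, chunk_size: int,
                     num_tokens: int, is_sequence_parallel: bool, sp_size: int):
        # With sequence parallelism each dispatcher is bounded by its own
        # shard, not the full DP-wide maximum.
        max_tokens = (cdiv(max_tokens_across_dp, sp_size)
                      if is_sequence_parallel else max_tokens_across_dp)
        for start in range(0, max_tokens, chunk_size):
            end = min(start + chunk_size, max_tokens)
            # Clamp to the tokens this rank actually has, as the loop above does.
            yield min(start, num_tokens - 1), min(end, num_tokens)

    # e.g. sp_size=4, 1000 max tokens across DP, chunk size 256, 90 local tokens:
    print(list(chunk_bounds(1000, 256, 90, True, 4)))   # [(0, 90)]
    print(list(chunk_bounds(1000, 256, 90, False, 4)))  # [(0, 90), (89, 90), (89, 90), (89, 90)]

The degenerate repeated (89, 90) chunks in the non-sequence-parallel case correspond to iterations the surrounding code runs with skip_result_store=True only so that all ranks make the same number of passes; the sequence-parallel bound avoids most of them.
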
diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py
index 0c9c83cf6100..5e8447a7f48f 100644
--- a/vllm/model_executor/models/deepseek_eagle.py
+++ b/vllm/model_executor/models/deepseek_eagle.py
@@ -37,8 +37,6 @@ def __init__(
         super().__init__()
         self.config = vllm_config. \
             speculative_config.draft_model_config.hf_config
-        model_config = vllm_config.model_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
 
         self.vocab_size = self.config.vocab_size
@@ -51,11 +49,8 @@ def __init__(
 
         self.layers = nn.ModuleList([
             DeepseekV2DecoderLayer(
-                self.config,
+                vllm_config,
                 prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
-                model_config=model_config,
-                cache_config=cache_config,
-                quant_config=quant_config,
             ) for i in range(self.config.num_hidden_layers)
         ])
 
diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index 0ad001be71c1..8fbf16d206a8 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -7,7 +7,7 @@
 import torch.nn as nn
 from transformers import PretrainedConfig
 
-from vllm.config import CacheConfig, ModelConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -43,23 +43,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 class DeepSeekMultiTokenPredictorLayer(nn.Module):
 
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        prefix: str,
-        model_config: ModelConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
+    def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
         super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+
         self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.eh_proj = nn.Linear(config.hidden_size * 2,
                                  config.hidden_size,
                                  bias=False)
         self.shared_head = SharedHead(config=config, quant_config=quant_config)
-        self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config,
-                                                cache_config, quant_config)
+        self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix)
 
     def forward(
         self,
@@ -95,13 +91,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # to map the exact layer index from weights
         self.layers = torch.nn.ModuleDict({
             str(idx):
-            DeepSeekMultiTokenPredictorLayer(
-                config,
-                f"{prefix}.layers.{idx}",
-                model_config=vllm_config.model_config,
-                cache_config=vllm_config.cache_config,
-                quant_config=vllm_config.quant_config,
-            )
+            DeepSeekMultiTokenPredictorLayer(vllm_config,
+                                             f"{prefix}.layers.{idx}")
             for idx in range(self.mtp_start_layer_idx,
                              self.mtp_start_layer_idx + self.num_mtp_layers)
         })
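
The model files above converge on the same constructor convention: a layer receives the whole VllmConfig plus a prefix and derives the sub-configs itself, which is why the explicit model_config/cache_config/quant_config arguments could be deleted at every call site. A minimal sketch of the pattern (MyDraftLayer is an invented stand-in, not vLLM code):

    import torch.nn as nn

    from vllm.config import VllmConfig

    class MyDraftLayer(nn.Module):
        def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
            super().__init__()
            # Previously threaded through as separate constructor arguments.
            config = vllm_config.model_config.hf_config
            self.quant_config = vllm_config.quant_config
            self.hidden_size = config.hidden_size
            self.prefix = prefix
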
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index d65dcfebaeff..e4a21febc5bd 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -32,12 +32,14 @@
 from torch import nn
 from transformers import DeepseekV2Config, DeepseekV3Config
 
+import vllm.envs as envs
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
-                         get_current_vllm_config)
+from vllm.config import CacheConfig, ParallelConfig, VllmConfig
 from vllm.distributed import (get_ep_group, get_pp_group,
-                              get_tensor_model_parallel_world_size)
+                              get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -55,7 +57,9 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
+from vllm.utils import cdiv, direct_register_custom_op
 
 from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
 from .utils import (PPMissingLayer, is_pp_missing_parameter,
@@ -72,19 +76,27 @@ def __init__(
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        is_sequence_parallel=False,
         prefix: str = "",
     ) -> None:
         super().__init__()
+
+        # If is_sequence_parallel, the input and output tensors are sharded
+        # across the ranks within the tp_group. In this case the weights are
+        # replicated and no collective ops are needed.
+        # Otherwise we use standard TP with an allreduce at the end.
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size, [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            disable_tp=is_sequence_parallel,
             prefix=f"{prefix}.gate_up_proj")
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
                                            quant_config=quant_config,
                                            reduce_results=reduce_results,
+                                           disable_tp=is_sequence_parallel,
                                            prefix=f"{prefix}.down_proj")
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
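
The comment added above is the heart of the change: with disable_tp the MLP weights are replicated, each rank produces final outputs for its own token shard, and no collective is needed inside the MLP, whereas standard tensor parallelism shards the weights and must all-reduce the partial down_proj outputs. A toy single-process PyTorch check of that equivalence (invented sizes, and a plain ReLU MLP instead of the real gated SiluAndMul; not vLLM code):

    import torch

    torch.manual_seed(0)
    tp_size, seq_len, hidden, inter = 2, 8, 16, 32
    x = torch.randn(seq_len, hidden)
    w_up = torch.randn(inter, hidden)    # full up-projection weight
    w_down = torch.randn(hidden, inter)  # full down-projection weight

    # Standard TP: every rank sees all tokens, the weights are sharded along
    # `inter`, and the partial down_proj outputs must be summed (all_reduce).
    partials = []
    for rank in range(tp_size):
        w_up_shard = w_up.chunk(tp_size, dim=0)[rank]
        w_down_shard = w_down.chunk(tp_size, dim=1)[rank]
        partials.append(torch.relu(x @ w_up_shard.t()) @ w_down_shard.t())
    tp_out = sum(partials)  # stands in for the all_reduce

    # Sequence parallel: weights replicated, each rank sees seq_len / tp_size
    # tokens and already produces final values for them; the only collective
    # left is an all_gather along the token axis afterwards.
    sp_out = torch.cat([
        torch.relu(chunk @ w_up.t()) @ w_down.t()
        for chunk in x.chunk(tp_size, dim=0)
    ])

    assert torch.allclose(tp_out, sp_out, atol=1e-5)
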
@@ -98,17 +110,58 @@ def forward(self, x):
         return x
 
 
+# Chunk x along the num_tokens axis for sequence parallelism
+# NOTE: This is wrapped in a torch custom op to work around the following issue:
+# The output tensor can have a sequence length 0 at small input sequence lengths
+# even though we explicitly pad to avoid this.
+def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
+    tp_size = get_tensor_model_parallel_world_size()
+    tp_rank = get_tensor_model_parallel_rank()
+
+    # all_gather needs the sequence length to be divisible by tp_size
+    seq_len = x.size(0)
+    remainder = seq_len % tp_size
+    if remainder != 0:
+        pad_len = tp_size - remainder
+        x = nn.functional.pad(x, (0, 0, 0, pad_len))
+
+    chunk = x.shape[0] // tp_size
+    start = tp_rank * chunk
+    return torch.narrow(x, 0, start, chunk)
+
+
+def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
+    tp_size = get_tensor_model_parallel_world_size()
+    seq_len = cdiv(x.size(0), tp_size)
+    shape = list(x.shape)
+    shape[0] = seq_len
+    out = torch.empty(shape, dtype=x.dtype, device=x.device)
+    return out
+
+
+direct_register_custom_op(
+    op_name="sequence_parallel_chunk",
+    op_func=sequence_parallel_chunk,
+    mutates_args=[],
+    fake_impl=sequence_parallel_chunk_fake,
+    dispatch_key=current_platform.dispatch_key,
+    tags=(torch.Tag.needs_fixed_stride_order, ),
+)
+
+
 class DeepseekV2MoE(nn.Module):
 
     def __init__(
         self,
         config: Union[DeepseekV2Config, DeepseekV3Config],
+        parallel_config: ParallelConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
-        enable_eplb: bool = False,
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+
         self.routed_scaling_factor = config.routed_scaling_factor
 
         self.ep_group = get_ep_group().device_group
@@ -117,6 +170,21 @@ def __init__(
         self.n_routed_experts: int = config.n_routed_experts
         self.n_shared_experts: int = config.n_shared_experts
 
+        # The all_reduce at the end of attention (during o_proj) means that
+        # inputs are replicated across each rank of the tensor parallel group.
+        # If using expert-parallelism with DeepEP All2All ops, replicated
+        # tokens results in useless duplicate computation and communication.
+        #
+        # In this case, ensure the input to the experts is sequence parallel
+        # to avoid the excess work.
+        #
+        # Not needed for pplx-kernels as it can handle duplicate input tokens.
+        self.is_sequence_parallel = (envs.VLLM_ALL2ALL_BACKEND
+                                     in ("deepep_high_throughput",
+                                         "deepep_low_latency")
+                                     and parallel_config.enable_expert_parallel
+                                     and self.tp_size > 1)
+
         if config.hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                              "Only silu is supported for now.")
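
The custom op registered above carries a fake (meta) implementation so that torch.compile can trace output shapes without running the op, and the cdiv there has to agree with the pad-then-narrow arithmetic of the real implementation. A single-process sketch of that agreement (pure PyTorch; the helper names and sizes are invented for illustration):

    import torch
    import torch.nn.functional as F

    def cdiv(a: int, b: int) -> int:
        return -(-a // b)

    def real_chunk(x: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
        # Pad the token axis to a multiple of tp_size, then narrow to this
        # rank's slice, as sequence_parallel_chunk does.
        remainder = x.size(0) % tp_size
        if remainder != 0:
            x = F.pad(x, (0, 0, 0, tp_size - remainder))
        chunk = x.size(0) // tp_size
        return torch.narrow(x, 0, tp_rank * chunk, chunk)

    def fake_chunk_shape(x: torch.Tensor, tp_size: int):
        # Mirrors sequence_parallel_chunk_fake: ceil(seq_len / tp_size) tokens.
        return (cdiv(x.size(0), tp_size), *x.shape[1:])

    tp_size = 4
    for seq_len in (1, 3, 4, 10, 17):
        x = torch.randn(seq_len, 8)
        for rank in range(tp_size):
            assert real_chunk(x, rank, tp_size).shape == fake_chunk_shape(x, tp_size)

The padding rows introduced here are dropped again after the all_gather in DeepseekV2MoE.forward below, which slices the result back to num_tokens.
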
@@ -133,9 +201,8 @@ def __init__(
             self.gate.e_score_correction_bias = None
 
         # Load balancing settings.
-        vllm_config = get_current_vllm_config()
-        eplb_config = vllm_config.parallel_config.eplb_config
-        self.enable_eplb = enable_eplb
+        eplb_config = parallel_config.eplb_config
+        self.enable_eplb = parallel_config.enable_eplb
 
         self.n_redundant_experts = eplb_config.num_redundant_experts
         self.n_logical_experts = self.n_routed_experts
@@ -166,7 +233,9 @@
                 routed_scaling_factor=1.0,
                 e_score_correction_bias=self.gate.e_score_correction_bias,
                 enable_eplb=self.enable_eplb,
-                num_redundant_experts=self.n_redundant_experts)
+                num_redundant_experts=self.n_redundant_experts,
+                is_sequence_parallel=self.is_sequence_parallel,
+            )
             self.shared_experts = None
         else:
             intermediate_size = (config.moe_intermediate_size *
@@ -177,6 +246,7 @@
                 intermediate_size=intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
+                is_sequence_parallel=self.is_sequence_parallel,
                 reduce_results=False,
                 prefix=f"{prefix}.shared_experts",
             )
@@ -199,11 +269,22 @@
                 routed_scaling_factor=1.0,
                 e_score_correction_bias=self.gate.e_score_correction_bias,
                 enable_eplb=self.enable_eplb,
-                num_redundant_experts=self.n_redundant_experts)
+                num_redundant_experts=self.n_redundant_experts,
+                is_sequence_parallel=self.is_sequence_parallel,
+            )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
+
+        # Chunk the hidden states so they aren't replicated across TP ranks.
+        # This avoids duplicate computation in self.experts.
+        # TODO: We can replace the all_reduce at the end of attn with a
+        # reduce_scatter instead of chunking here.
+        if self.is_sequence_parallel:
+            hidden_states = torch.ops.vllm.sequence_parallel_chunk(
+                hidden_states)
+
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
@@ -228,7 +309,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             assert shared_output is not None
             final_hidden_states += shared_output
 
-        if self.tp_size > 1:
+        if self.is_sequence_parallel:
+            final_hidden_states = tensor_model_parallel_all_gather(
+                final_hidden_states, 0)
+            final_hidden_states = final_hidden_states[:num_tokens]
+        elif self.tp_size > 1:
             final_hidden_states = (
                 self.experts.maybe_all_reduce_tensor_model_parallel(
                     final_hidden_states))
@@ -532,16 +617,15 @@ def forward(
 
 class DeepseekV2DecoderLayer(nn.Module):
 
-    def __init__(
-        self,
-        config: Union[DeepseekV2Config, DeepseekV3Config],
-        prefix: str,
-        model_config: ModelConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        enable_eplb: bool = False,
-    ) -> None:
+    def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
         super().__init__()
+
+        config = vllm_config.model_config.hf_config
+        model_config = vllm_config.model_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        parallel_config = vllm_config.parallel_config
+
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
@@ -578,9 +662,9 @@ def __init__(
                 and layer_idx % config.moe_layer_freq == 0):
             self.mlp = DeepseekV2MoE(
                 config=config,
+                parallel_config=parallel_config,
                 quant_config=quant_config,
                 prefix=f"{prefix}.mlp",
-                enable_eplb=enable_eplb,
             )
         else:
             self.mlp = DeepseekV2MLP(
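
The forward path above is the round trip that makes the earlier chunking safe: each rank runs the gate and experts only on its own shard, then tensor_model_parallel_all_gather along dim 0 re-assembles the token axis and the [:num_tokens] slice drops any padding. A single-process sketch (torch.cat stands in for the all_gather, the token count is chosen divisible by tp_size so no padding is needed, and the elementwise "experts" are an invented stand-in):

    import torch

    tp_size, num_tokens, hidden = 4, 12, 8
    hidden_states = torch.randn(num_tokens, hidden)
    fake_experts = torch.tanh  # per-token stand-in for the fused MoE

    # Each rank computes outputs only for its own chunk of tokens...
    per_rank_out = [fake_experts(chunk)
                    for chunk in hidden_states.chunk(tp_size, dim=0)]

    # ...and the all_gather + slice restores one output row per input token.
    final = torch.cat(per_rank_out, dim=0)[:num_tokens]

    assert torch.equal(final, fake_experts(hidden_states))
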
@@ -650,10 +734,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
 
         config = vllm_config.model_config.hf_config
-        model_config = vllm_config.model_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
-        enable_eplb = vllm_config.parallel_config.enable_eplb
 
         self.config = config
         self.vocab_size = config.vocab_size
@@ -669,14 +750,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: DeepseekV2DecoderLayer(
-                config,
-                prefix,
-                model_config=model_config,
-                cache_config=cache_config,
-                quant_config=quant_config,
-                enable_eplb=enable_eplb,
-            ),
+            lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix),
             prefix=f"{prefix}.layers")
 
         if get_pp_group().is_last_rank:
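
For reference, the new path is only taken when the gating condition in DeepseekV2MoE holds: a DeepEP all2all backend, expert parallelism enabled, and a tensor-parallel size greater than one. A hedged usage sketch under those assumptions (the engine arguments and model name are ordinary vLLM usage shown for illustration and are not introduced by this diff; DeepEP itself must be installed):

    import os

    # One of the two DeepEP backends named in the diff.
    os.environ["VLLM_ALL2ALL_BACKEND"] = "deepep_high_throughput"

    from vllm import LLM

    llm = LLM(
        model="deepseek-ai/DeepSeek-V2-Lite",  # illustrative DeepSeek MoE checkpoint
        tensor_parallel_size=2,                # tp_size > 1
        enable_expert_parallel=True,           # parallel_config.enable_expert_parallel
    )
    print(llm.generate(["Sequence parallelism avoids duplicate expert work."]))
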