-
-
Notifications
You must be signed in to change notification settings - Fork 11.6k
[Bugfix][Wide EP] Fix redundant work when using DeepEP, TP Attn, and EP MoE #24134
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
9cf5b21
46e4ad7
90a2a32
85a1ff0
08bace6
8174f70
02fbb25
76372f2
3b6b3c7
249eb5f
3887a51
31e0c81
c8fb93c
c574cab
090ae53
e9ee2c5
56ad765
4134e22
8b9dbc7
900b951
f3754a5
898a1c4
e6a5908
f6455a9
20a3451
ba075dd
cebbc89
5d8aaa3
a7999c7
d92eae0
247ce5b
cdd4a39
f6e5905
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,12 +32,14 @@ | |
| from torch import nn | ||
| from transformers import DeepseekV2Config, DeepseekV3Config | ||
|
|
||
| import vllm.envs as envs | ||
| from vllm.attention import Attention | ||
| from vllm.compilation.decorators import support_torch_compile | ||
| from vllm.config import (CacheConfig, ModelConfig, VllmConfig, | ||
| get_current_vllm_config) | ||
| from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig | ||
| from vllm.distributed import (get_ep_group, get_pp_group, | ||
| get_tensor_model_parallel_world_size) | ||
| get_tensor_model_parallel_rank, | ||
| get_tensor_model_parallel_world_size, | ||
| tensor_model_parallel_all_gather) | ||
| from vllm.model_executor.layers.activation import SiluAndMul | ||
| from vllm.model_executor.layers.fused_moe import FusedMoE | ||
| from vllm.model_executor.layers.layernorm import RMSNorm | ||
|
|
@@ -71,20 +73,34 @@ def __init__( | |
| hidden_act: str, | ||
| quant_config: Optional[QuantizationConfig] = None, | ||
| reduce_results: bool = True, | ||
| is_sequence_parallel=False, | ||
| prefix: str = "", | ||
| ) -> None: | ||
| super().__init__() | ||
| self.gate_up_proj = MergedColumnParallelLinear( | ||
| hidden_size, [intermediate_size] * 2, | ||
| bias=False, | ||
| quant_config=quant_config, | ||
| prefix=f"{prefix}.gate_up_proj") | ||
| self.down_proj = RowParallelLinear(intermediate_size, | ||
| hidden_size, | ||
| bias=False, | ||
| quant_config=quant_config, | ||
| reduce_results=reduce_results, | ||
| prefix=f"{prefix}.down_proj") | ||
| if is_sequence_parallel: | ||
| self.gate_up_proj = MergedReplicatedLinear( | ||
| hidden_size, [intermediate_size] * 2, | ||
| bias=False, | ||
| quant_config=quant_config, | ||
| prefix=f"{prefix}.gate_up_proj") | ||
| self.down_proj = ReplicatedLinear(intermediate_size, | ||
| hidden_size, | ||
| bias=False, | ||
| quant_config=quant_config, | ||
| prefix=f"{prefix}.down_proj") | ||
|
|
||
| else: | ||
| self.gate_up_proj = MergedColumnParallelLinear( | ||
| hidden_size, [intermediate_size] * 2, | ||
| bias=False, | ||
| quant_config=quant_config, | ||
| prefix=f"{prefix}.gate_up_proj") | ||
| self.down_proj = RowParallelLinear(intermediate_size, | ||
| hidden_size, | ||
| bias=False, | ||
| quant_config=quant_config, | ||
| reduce_results=reduce_results, | ||
| prefix=f"{prefix}.down_proj") | ||
| if hidden_act != "silu": | ||
| raise ValueError(f"Unsupported activation: {hidden_act}. " | ||
| "Only silu is supported for now.") | ||
|
|
@@ -102,12 +118,14 @@ class DeepseekV2MoE(nn.Module): | |
| def __init__( | ||
| self, | ||
| config: Union[DeepseekV2Config, DeepseekV3Config], | ||
| parallel_config: ParallelConfig, | ||
| quant_config: Optional[QuantizationConfig] = None, | ||
| prefix: str = "", | ||
| enable_eplb: bool = False, | ||
| ): | ||
| super().__init__() | ||
| self.tp_size = get_tensor_model_parallel_world_size() | ||
| self.tp_rank = get_tensor_model_parallel_rank() | ||
|
|
||
| self.routed_scaling_factor = config.routed_scaling_factor | ||
|
|
||
| self.ep_group = get_ep_group().device_group | ||
|
|
@@ -116,6 +134,15 @@ def __init__( | |
| self.n_routed_experts: int = config.n_routed_experts | ||
| self.n_shared_experts: int = config.n_shared_experts | ||
|
|
||
| # If using expert parallel, ensure the input to the experts is | ||
| # SP to avoid duplicate work. | ||
| # Not needed for pplx-kernels as it can handle duplicate input tokens. | ||
|
||
| self.is_sequence_parallel = (envs.VLLM_ALL2ALL_BACKEND | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we should call this
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not sure I like this because the MLP being sequence parallel is kind of a side effect. And we need to pass it into the fused_moe layer for the chunking. I'm not a fan of the |
||
| in ("deepep_high_throughput", | ||
| "deepep_low_latency") | ||
| and parallel_config.enable_expert_parallel | ||
| and self.tp_size > 1) | ||
|
|
||
| if config.hidden_act != "silu": | ||
| raise ValueError(f"Unsupported activation: {config.hidden_act}. " | ||
| "Only silu is supported for now.") | ||
|
|
@@ -132,9 +159,8 @@ def __init__( | |
| self.gate.e_score_correction_bias = None | ||
|
|
||
| # Load balancing settings. | ||
| vllm_config = get_current_vllm_config() | ||
| eplb_config = vllm_config.parallel_config.eplb_config | ||
| self.enable_eplb = enable_eplb | ||
| eplb_config = parallel_config.eplb_config | ||
| self.enable_eplb = parallel_config.enable_eplb | ||
|
|
||
| self.n_redundant_experts = eplb_config.num_redundant_experts | ||
| self.n_logical_experts = self.n_routed_experts | ||
|
|
@@ -162,7 +188,9 @@ def __init__( | |
| scoring_func=config.scoring_func, | ||
| e_score_correction_bias=self.gate.e_score_correction_bias, | ||
| enable_eplb=self.enable_eplb, | ||
| num_redundant_experts=self.n_redundant_experts) | ||
| num_redundant_experts=self.n_redundant_experts, | ||
| is_sequence_parallel=self.is_sequence_parallel, | ||
| ) | ||
|
|
||
| if config.n_shared_experts is not None: | ||
| intermediate_size = (config.moe_intermediate_size * | ||
|
|
@@ -172,16 +200,44 @@ def __init__( | |
| intermediate_size=intermediate_size, | ||
| hidden_act=config.hidden_act, | ||
| quant_config=quant_config, | ||
| is_sequence_parallel=self.is_sequence_parallel, | ||
| reduce_results=self.experts.must_reduce_shared_expert_outputs( | ||
| ), | ||
| prefix=f"{prefix}.shared_experts", | ||
| ) | ||
|
|
||
| # Chunk x along the num_tokens axis for sequence parallelism | ||
| def sequence_parallel_chunk(self, x: torch.Tensor): | ||
| seq_len = x.size(0) | ||
|
|
||
| # all_gather needs the sequence length to be divisible by tp_size | ||
| remainder = seq_len % self.tp_size | ||
| if remainder != 0: | ||
| pad_len = self.tp_size - remainder | ||
| pad_shape = list(x.shape) | ||
| pad_shape[0] = pad_len | ||
| pad = x.new_zeros(pad_shape) | ||
| x = torch.cat([x, pad], dim=0) | ||
| seq_len = x.size(0) | ||
|
|
||
| chunk = seq_len // self.tp_size | ||
| start = self.tp_rank * chunk | ||
| return x.narrow(0, start, chunk).contiguous() | ||
|
|
||
| def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | ||
| num_tokens, hidden_dim = hidden_states.shape | ||
| hidden_states = hidden_states.view(-1, hidden_dim) | ||
|
|
||
| # Chunk the hidden states so they aren't replicated across TP ranks. | ||
tlrmchlsmth marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # This avoids duplicate computation in self.experts. | ||
| # TODO: We can replace the all_reduce at the end of attn with a | ||
| # reduce_scatter instead of chunking here. | ||
| if self.is_sequence_parallel: | ||
| hidden_states = self.sequence_parallel_chunk(hidden_states) | ||
|
|
||
| if self.n_shared_experts is not None: | ||
| shared_output = self.shared_experts(hidden_states) | ||
|
|
||
tlrmchlsmth marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| # router_logits: (num_tokens, n_experts) | ||
| router_logits, _ = self.gate(hidden_states) | ||
|
|
||
|
|
@@ -194,6 +250,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | |
| # See DeepseekV2DecoderLayer for more details. | ||
| final_hidden_states = self.experts(hidden_states=hidden_states, | ||
| router_logits=router_logits) | ||
|
|
||
| if shared_output is not None: | ||
| if hidden_states.dtype != torch.float16: | ||
| final_hidden_states = final_hidden_states + shared_output | ||
|
|
@@ -203,7 +260,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | |
| final_hidden_states = final_hidden_states + shared_output \ | ||
| * (1. / self.routed_scaling_factor) | ||
|
|
||
| if self.tp_size > 1: | ||
| if self.is_sequence_parallel: | ||
| final_hidden_states = tensor_model_parallel_all_gather( | ||
| final_hidden_states, 0).contiguous() | ||
| final_hidden_states = final_hidden_states[:num_tokens] | ||
| elif self.tp_size > 1: | ||
| final_hidden_states = ( | ||
| self.experts.maybe_all_reduce_tensor_model_parallel( | ||
| final_hidden_states)) | ||
|
|
@@ -541,10 +602,10 @@ def __init__( | |
| self, | ||
| config: Union[DeepseekV2Config, DeepseekV3Config], | ||
| prefix: str, | ||
| parallel_config: ParallelConfig, | ||
| model_config: ModelConfig, | ||
| cache_config: Optional[CacheConfig] = None, | ||
| quant_config: Optional[QuantizationConfig] = None, | ||
| enable_eplb: bool = False, | ||
| ) -> None: | ||
| super().__init__() | ||
| self.hidden_size = config.hidden_size | ||
|
|
@@ -583,9 +644,9 @@ def __init__( | |
| and layer_idx % config.moe_layer_freq == 0): | ||
| self.mlp = DeepseekV2MoE( | ||
| config=config, | ||
| parallel_config=parallel_config, | ||
| quant_config=quant_config, | ||
| prefix=f"{prefix}.mlp", | ||
| enable_eplb=enable_eplb, | ||
| ) | ||
| else: | ||
| self.mlp = DeepseekV2MLP( | ||
|
|
@@ -658,7 +719,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): | |
| model_config = vllm_config.model_config | ||
| cache_config = vllm_config.cache_config | ||
| quant_config = vllm_config.quant_config | ||
| enable_eplb = vllm_config.parallel_config.enable_eplb | ||
| parallel_config = vllm_config.parallel_config | ||
| self.config = config | ||
|
|
||
| self.vocab_size = config.vocab_size | ||
|
|
@@ -677,10 +738,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): | |
| lambda prefix: DeepseekV2DecoderLayer( | ||
| config, | ||
| prefix, | ||
| parallel_config=parallel_config, | ||
| model_config=model_config, | ||
| cache_config=cache_config, | ||
| quant_config=quant_config, | ||
| enable_eplb=enable_eplb, | ||
| ), | ||
| prefix=f"{prefix}.layers") | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.