From 8fe6f82a206abb3db33a26328d28ff856f99a826 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Wed, 14 May 2025 16:29:17 -0700 Subject: [PATCH 01/57] [Feature] Core EPLB algorithm Signed-off-by: Bowen Wang --- vllm/distributed/eplb/rebalance.py | 231 +++++++++++++++++++++++++++++ 1 file changed, 231 insertions(+) create mode 100644 vllm/distributed/eplb/rebalance.py diff --git a/vllm/distributed/eplb/rebalance.py b/vllm/distributed/eplb/rebalance.py new file mode 100644 index 000000000000..6a15d09b7d74 --- /dev/null +++ b/vllm/distributed/eplb/rebalance.py @@ -0,0 +1,231 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Expert parallelism load balancer (EPLB) for vLLM. + +This module implements the core rearrangement algorithm. + +The rearrangement algorithm is adapted from +[DeepSeek EPLB](https://github.com/deepseek-ai/eplb). +""" + +from typing import Tuple + +import torch + + +def balanced_packing(weight: torch.Tensor, + num_packs: int) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Pack n weighted objects to m packs, such that each bin contains exactly + n/m objects and the weights of all packs are as balanced as possible. + + Parameters: + weight: [X, n], the weight of each item + num_packs: number of packs + + Returns: + pack_index: [X, n], the pack index of each item + rank_in_pack: [X, n], the rank of the item in the pack + """ + num_layers, num_groups = weight.shape + assert num_groups % num_packs == 0 + groups_per_pack = num_groups // num_packs + + if groups_per_pack == 1: + pack_index = torch.arange(weight.size(-1), + dtype=torch.int64, + device=weight.device).expand(weight.shape) + rank_in_pack = torch.zeros_like(weight, dtype=torch.int64) + return pack_index, rank_in_pack + + indices = weight.float().sort(-1, descending=True).indices.cpu() + pack_index = torch.full_like(weight, + fill_value=-1, + dtype=torch.int64, + device="cpu") + rank_in_pack = torch.full_like(pack_index, fill_value=-1) + for i in range(num_layers): + pack_weights = [0] * num_packs + pack_items = [0] * num_packs + for group in indices[i]: + pack = min( + (i + for i in range(num_packs) if pack_items[i] < groups_per_pack), + key=pack_weights.__getitem__, + ) + assert pack_items[pack] < groups_per_pack + pack_index[i, group] = pack + rank_in_pack[i, group] = pack_items[pack] + pack_weights[pack] += weight[i, group] + pack_items[pack] += 1 + return pack_index, rank_in_pack + + +def replicate_experts( + weight: torch.Tensor, + num_phy: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Replicate `num_log` experts to `num_phy` replicas, such that the maximum + load of all replicas is minimized. 
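+    The replication is greedy: each additional physical slot is assigned to
+    the logical expert with the highest current per-replica load, i.e. the
+    argmax of `weight / logcnt`.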
+ + Parameters: + weight: [X, num_log] + num_phy: total number of experts after replication + + Returns: + phy2log: [X, num_phy], logical expert id of each physical expert + rank: [X, num_phy], the replica rank + logcnt: [X, num_log], number of replicas for each logical expert + """ + n, num_log = weight.shape + num_redundant = num_phy - num_log + assert num_redundant >= 0 + device = weight.device + phy2log = torch.arange(num_phy, dtype=torch.int64, + device=device).repeat(n, 1) + rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device) + logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device) + arangen = torch.arange(n, dtype=torch.int64, device=device) + for i in range(num_log, num_phy): + redundant_indices = (weight / logcnt).max(dim=-1).indices + phy2log[:, i] = redundant_indices + rank[:, i] = logcnt[arangen, redundant_indices] + logcnt[arangen, redundant_indices] += 1 + return phy2log, rank, logcnt + + +def rebalance_experts_hierarchical( + weight: torch.Tensor, + num_physical_experts: int, + num_groups: int, + num_nodes: int, + num_gpus: int, +): + """ + Parameters: + weight: [num_moe_layers, num_logical_experts] + num_physical_experts: number of physical experts after replication + num_groups: number of expert groups + num_nodes: number of server nodes, where the intra-node network + (e.g, NVLink) is faster + num_gpus: number of GPUs, must be a multiple of `num_nodes` + + Returns: + physical_to_logical_map: [num_moe_layers, num_physical_experts] + logical_to_physical_map: [num_moe_layers, num_logical_experts, X] + logical_count: [num_moe_layers, num_logical_experts] + """ + num_layers, num_logical_experts = weight.shape + assert num_logical_experts % num_groups == 0 + group_size = num_logical_experts // num_groups + assert num_groups % num_nodes == 0 + groups_per_node = num_groups // num_nodes + assert num_gpus % num_nodes == 0 + assert num_physical_experts % num_gpus == 0 + phy_experts_per_gpu = num_physical_experts // num_gpus + + def inverse(perm: torch.Tensor) -> torch.Tensor: + inv = torch.empty_like(perm) + inv.scatter_( + 1, + perm, + torch.arange(perm.size(1), dtype=torch.int64, + device=perm.device).expand(perm.shape), + ) + return inv + + # Step 1: pack groups to nodes + tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1) + group_pack_index, group_rank_in_pack = balanced_packing( + tokens_per_group, num_nodes) + log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) * + group_size).unsqueeze(-1) + + torch.arange(group_size, + dtype=torch.int64, + device=group_pack_index.device)).flatten(-2) + mlog2log = inverse(log2mlog) + + # Step 2: construct redundant experts within nodes + # [num_layers * num_nodes, num_logical_experts // num_nodes] + tokens_per_mlog = weight.gather(-1, mlog2log).view( + -1, num_logical_experts // num_nodes) + phy2mlog, phyrank, mlogcnt = replicate_experts( + tokens_per_mlog, num_physical_experts // num_nodes) + + # Step 3: pack physical_experts to GPUs + # [num_layers * num_nodes, num_physical_experts // num_nodes] + tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog) + pack_index, rank_in_pack = balanced_packing(tokens_per_phy, + num_gpus // num_nodes) + phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack + pphy2phy = inverse(phy2pphy) + + pphy2mlog = phy2mlog.gather( + -1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes] + pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + torch.arange( + 0, + num_logical_experts, + num_logical_experts // num_nodes, + 
device=group_pack_index.device, + ).view(1, -1, 1)).flatten(-2) + pphy2log = mlog2log.gather(-1, pphy2mlog) + pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) + logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) + return pphy2log, pphyrank, logcnt + + +def rebalance_experts( + weight: torch.Tensor, + num_replicas: int, + num_groups: int, + num_nodes: int, + num_gpus: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Entry point for expert-parallelism load balancer. + + Parameters: + weight: [layers, num_logical_experts], the load statistics for all + logical experts + num_replicas: number of physical experts, must be a multiple of + `num_gpus` + num_groups: number of expert groups + num_nodes: number of server nodes, where the intra-node network + (e.g, NVLink) is faster + num_gpus: number of GPUs, must be a multiple of `num_nodes` + + Returns: + physical_to_logical_map: [layers, num_replicas], the expert index of + each replica + logical_to_physical_map: [layers, num_logical_experts, X], the replica + indices for each expert + expert_count: [layers, num_logical_experts], number of physical + replicasfor each logical expert + """ + num_layers, num_logical_experts = weight.shape + weight = weight.float().cpu() + if num_groups % num_nodes == 0: + # use hierarchical load-balance policy + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( + weight, num_replicas, num_groups, num_nodes, num_gpus) + else: + # use global load-balance policy + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( + weight, num_replicas, 1, 1, num_gpus) + maxlogcnt = logcnt.max().item() + log2phy: torch.Tensor = torch.full( + (num_layers, num_logical_experts, maxlogcnt), + -1, + dtype=torch.int64, + device=logcnt.device, + ) + log2phy.view(num_layers, -1).scatter_( + -1, + phy2log * maxlogcnt + phyrank, + torch.arange(num_replicas, dtype=torch.int64, + device=log2phy.device).expand(num_layers, -1), + ) + return phy2log, log2phy, logcnt + + +__all__ = ["rebalance_experts"] From bdda8dc94c4a3f5a9f4fca0b36149cc95d563fbe Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Fri, 16 May 2025 13:02:06 -0700 Subject: [PATCH 02/57] [Feature] Register expert weights for DeepSeek MoE Signed-off-by: Bowen Wang --- vllm/config.py | 2 + vllm/engine/arg_utils.py | 4 ++ vllm/model_executor/models/deepseek.py | 55 +++++++++++++++++++++--- vllm/model_executor/models/interfaces.py | 53 ++++++++++++++++++++++- 4 files changed, 107 insertions(+), 7 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index c5d61405c839..6c6fa4d2a42e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1690,6 +1690,8 @@ class ParallelConfig: """Port of the data parallel master.""" enable_expert_parallel: bool = False """Use expert parallelism instead of tensor parallelism for MoE layers.""" + num_extra_experts: int = 0 + """Number of redundant experts to use for expert parallelism.""" max_parallel_loading_workers: Optional[int] = None """Maximum number of parallel loading workers when loading model diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 240142a1c5d1..fffbb7a66e92 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -287,6 +287,7 @@ class EngineArgs: data_parallel_address: Optional[str] = None data_parallel_rpc_port: Optional[int] = None enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel + num_extra_experts: int = ParallelConfig.num_extra_experts max_parallel_loading_workers: Optional[ int] = ParallelConfig.max_parallel_loading_workers 
block_size: Optional[BlockSize] = CacheConfig.block_size @@ -617,6 +618,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parallel_group.add_argument( "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]) + parallel_group.add_argument("--num-extra-experts", + **parallel_kwargs["num_extra_experts"]) parallel_group.add_argument( "--max-parallel-loading-workers", **parallel_kwargs["max_parallel_loading_workers"]) @@ -1062,6 +1065,7 @@ def create_engine_config( data_parallel_master_ip=data_parallel_address, data_parallel_rpc_port=data_parallel_rpc_port, enable_expert_parallel=self.enable_expert_parallel, + num_extra_experts=self.num_extra_experts, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, ray_workers_use_nsight=self.ray_workers_use_nsight, diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index c6421143dd68..c9d381edc422 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -22,17 +22,19 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Deepseek model.""" -from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union +import typing +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) +from vllm.distributed.parallel_state import get_ep_group from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm @@ -49,7 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP +from .interfaces import IsMixtureOfExperts, SupportsPP from .utils import (AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -101,13 +103,27 @@ def __init__( self.config = config self.rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() + self.ep_group = get_ep_group() + self.ep_rank = self.ep_group.rank + self.ep_size = self.ep_group.world_size self.n_routed_experts = config.n_routed_experts + self.n_shared_experts = config.n_shared_experts self.top_k = config.num_experts_per_tok if self.tp_size > self.n_routed_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " f"the number of experts {self.n_routed_experts}.") + # Load balancing settings. + # Currently, `n_redundancy_expers` equals to `n_extra_experts`. 
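+        # "Logical" experts are the experts defined by the model config;
+        # "physical" experts are the slots actually placed across EP ranks,
+        # including the redundant replicas added for load balancing.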
+ vllm_config = get_current_vllm_config() + self.n_extra_experts = vllm_config.parallel_config.num_extra_experts + self.n_physical_experts = self.n_routed_experts + self.n_extra_experts + self.n_logical_experts = self.n_routed_experts + self.n_redundancy_expers = (self.n_physical_experts - + self.n_logical_experts) + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + self.experts = nn.ModuleList([ DeepseekMLP(hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, @@ -134,6 +150,24 @@ def __init__( reduce_results=False, ) + def get_weights(self) -> List[torch.Tensor]: + ret: List[torch.Tensor] = [] + for weight in [self.gate_proj_weight, self.down_proj_weight]: + weight = typing.cast( + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], weight) + if isinstance(weight, torch.Tensor): + assert weight.is_contiguous() + ret.append(weight.view(self.n_local_physical_experts, -1)) + else: + # FP8 weights + assert weight[0].element_size() == 1 + assert weight[0].is_contiguous() + assert weight[1].is_contiguous() + ret.append(weight[0].view(torch.int8).view( + self.n_local_physical_experts, -1)) + ret.append(weight[1].view(self.n_local_physical_experts, -1)) + return ret + def pack_params(self): w1 = [] w2 = [] @@ -436,7 +470,7 @@ def load_weights(self, weights: Iterable[Tuple[str, return loaded_params -class DeepseekForCausalLM(nn.Module, SupportsPP): +class DeepseekForCausalLM(nn.Module, SupportsPP, IsMixtureOfExperts): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -454,6 +488,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + self.expert_weights = [] def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -481,4 +516,14 @@ def compute_logits( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) \ No newline at end of file + loaded_weights = loader.load_weights(weights) + + # Register the expert weights. + for layer in self.model.layers: + assert isinstance(layer, DeepseekDecoderLayer) + if isinstance(layer.mlp, DeepseekMoE): + self.expert_weights.append(layer.mlp.get_weights()) + + # TODO(bowen): Add support for MTP layers + + return loaded_weights diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 7fea9647ead9..42af98809147 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional, - Protocol, Type, Union, overload, runtime_checkable) +from collections.abc import Iterable, Sequence +from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, + MutableSequence, Optional, Protocol, Type, Union, overload, + runtime_checkable) import torch from torch import Tensor @@ -423,6 +425,53 @@ def is_hybrid( return isinstance(model, IsHybrid) +@runtime_checkable +class IsMixtureOfExperts(Protocol): + """ + Check if the model is a mixture of experts (MoE) model. + """ + + is_mixture_of_experts: ClassVar[Literal[True]] = True + """ + A flag that indicates this model is a mixture of experts (MoE) model. + Used for expert parallel load balancing (EPLB) now. 
+ """ + + expert_weights: MutableSequence[Iterable[Tensor]] + """ + Expert weights saved in this rank. + + The first dimension is the layer, and the second dimension is different + parameters in the layer, e.g. up/down projection weights. + """ + + +@runtime_checkable +class _IsMixtureOfExpertsType(Protocol): + is_mixture_of_experts: ClassVar[Literal[True]] + expert_weights: Sequence[Iterable[Tensor]] + + +@overload +def is_mixture_of_experts(model: object) -> TypeIs[IsMixtureOfExperts]: + ... + + +@overload +def is_mixture_of_experts( + model: Type[object]) -> TypeIs[Type[IsMixtureOfExperts]]: + ... + + +def is_mixture_of_experts( + model: Union[Type[object], object] +) -> Union[TypeIs[Type[IsMixtureOfExperts]], TypeIs[IsMixtureOfExperts]]: + if isinstance(model, type): + return isinstance(model, _IsMixtureOfExpertsType) + + return isinstance(model, IsMixtureOfExperts) + + @runtime_checkable class HasNoOps(Protocol): has_noops: ClassVar[Literal[True]] = True From 43d52ac080afc1d292a1ca58dd0d2ef44c18931b Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Fri, 16 May 2025 13:08:43 -0700 Subject: [PATCH 03/57] [Chore] Rename EPLB rebalance algo module name Signed-off-by: Bowen Wang --- vllm/distributed/eplb/{rebalance.py => rebalance_algo.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename vllm/distributed/eplb/{rebalance.py => rebalance_algo.py} (100%) diff --git a/vllm/distributed/eplb/rebalance.py b/vllm/distributed/eplb/rebalance_algo.py similarity index 100% rename from vllm/distributed/eplb/rebalance.py rename to vllm/distributed/eplb/rebalance_algo.py From 58bf9fdef9830ae4880ba51185d5340df262fdfd Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Fri, 16 May 2025 15:08:06 -0700 Subject: [PATCH 04/57] [Feature] Store EPLB states in model runner Signed-off-by: Bowen Wang --- vllm/config.py | 21 ++++ vllm/distributed/eplb/__init__.py | 4 + vllm/distributed/eplb/states.py | 148 +++++++++++++++++++++++ vllm/engine/arg_utils.py | 10 ++ vllm/model_executor/models/deepseek.py | 22 +++- vllm/model_executor/models/interfaces.py | 39 +++--- vllm/v1/worker/gpu_model_runner.py | 24 ++++ 7 files changed, 241 insertions(+), 27 deletions(-) create mode 100644 vllm/distributed/eplb/__init__.py create mode 100644 vllm/distributed/eplb/states.py diff --git a/vllm/config.py b/vllm/config.py index 6c6fa4d2a42e..802ee2a36361 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1688,10 +1688,17 @@ class ParallelConfig: """Port for data parallel messaging.""" data_parallel_master_port: int = 29500 """Port of the data parallel master.""" + enable_expert_parallel: bool = False """Use expert parallelism instead of tensor parallelism for MoE layers.""" + enable_eplb: bool = False + """Enable expert parallelism load balancing for MoE layers.""" num_extra_experts: int = 0 """Number of redundant experts to use for expert parallelism.""" + eplb_window_size: int = 1000 + """Window size for expert load recording.""" + eplb_step_interval: int = 3000 + """Interval for rearranging experts in expert parallelism.""" max_parallel_loading_workers: Optional[int] = None """Maximum number of parallel loading workers when loading model @@ -1836,6 +1843,20 @@ def __post_init__(self) -> None: f"{current_platform.device_type.upper()} backend only " "supports Ray for distributed inference.") + if self.enable_eplb: + if not current_platform.is_cuda(): + raise ValueError( + "Expert parallelism load balancing is only supported on " + "CUDA devices now.") + if self.num_extra_experts < 0: + raise ValueError( + "num_extra_experts 
must be non-negative, but got " + f"{self.num_extra_experts}.") + else: + if self.num_extra_experts != 0: + raise ValueError("num_extra_experts should be used with EPLB." + f"{self.num_extra_experts}.") + if self.distributed_executor_backend is None and self.world_size > 1: # We use multiprocessing by default if world_size fits on the # current node and we aren't in a ray placement group. diff --git a/vllm/distributed/eplb/__init__.py b/vllm/distributed/eplb/__init__.py new file mode 100644 index 000000000000..be6006b4ecb8 --- /dev/null +++ b/vllm/distributed/eplb/__init__.py @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 + +from .rebalance_algo import * +from .states import * diff --git a/vllm/distributed/eplb/states.py b/vllm/distributed/eplb/states.py new file mode 100644 index 000000000000..a0564aaafd63 --- /dev/null +++ b/vllm/distributed/eplb/states.py @@ -0,0 +1,148 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Expert parallelism load balancer (EPLB) metrics and states. +""" + +from dataclasses import dataclass + +import torch + +from vllm.config import ParallelConfig +from vllm.model_executor.models.interfaces import IsMixtureOfExperts + + +@dataclass +class EplbState: + """EPLB metrics.""" + + physical_to_logical_map: torch.Tensor + """ + Mapping from physical experts to logical experts. + + Shape: (num_moe_layers, num_physical_experts) + """ + logical_to_physical_map: torch.Tensor + """ + Mapping from logical experts to physical experts. + + This is a sparse matrix, where -1 indicates no mapping. + + Shape: (num_moe_layers, num_logical_experts) + """ + logical_replica_count: torch.Tensor + """ + Number of replicas for each logical expert. + + Shape: (num_moe_layers, num_logical_experts) + """ + + expert_load_pass: torch.Tensor + """ + Expert load during this forward pass. + We use the token count each expert processes as the load. + + Shape: (num_moe_layers, num_local_physical_experts) + """ + expert_load_window: torch.Tensor + """ + A sliding window of expert load. + + Shape: (window_size, num_moe_layers, num_local_physical_experts) + """ + expert_load_window_step: int = 0 + """Current step in the sliding window.""" + + expert_rearrangement_step: int = 0 + """ + Steps after last rearrangement. + Will trigger a rearrangement if it exceeds the threshold. + """ + + @staticmethod + def build_initial_global_physical_to_logical_map( + num_routed_experts: int, + num_redundant_experts: int, + ) -> list[int]: + """ + Build an initial expert arrangement using the following structure: + [original routed experts, redundant experts] + """ + global_physical_to_logical_map = list(range(num_routed_experts)) + global_physical_to_logical_map += [ + i % num_routed_experts for i in range(num_redundant_experts) + ] + return global_physical_to_logical_map + + @classmethod + def build( + cls, + model: IsMixtureOfExperts, + device: torch.device, + parallel_config: ParallelConfig, + ) -> "EplbState": + """ + Build the initial EPLB state. 
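+
+        A minimal usage sketch, mirroring how the model runner builds it:
+            eplb_state = EplbState.build(model, device, parallel_config)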
+ """ + physical_to_logical_map = ( + cls.build_initial_global_physical_to_logical_map( + model.num_routed_experts, + model.num_redundant_experts, + )) + physical_to_logical_map = torch.tensor( + physical_to_logical_map, + device=device, + ) + logical_to_physical_map = torch.full( + (model.num_logical_experts, model.num_redundant_experts + 1), + -1, + device=device, + ) + logical_replica_count = torch.zeros( + (model.num_logical_experts, ), + device=device, + ) + + for i in range(model.num_physical_experts): + logical_idx = physical_to_logical_map[i] + logical_to_physical_map[logical_idx, + logical_replica_count[logical_idx]] = i + logical_replica_count[logical_idx] += 1 + + # Duplicate initial mapping for all layers + physical_to_logical_map = physical_to_logical_map.unsqueeze(0).expand( + model.num_moe_layers, + -1, + ).contiguous() + logical_to_physical_map = logical_to_physical_map.unsqueeze(0).expand( + model.num_moe_layers, + -1, + -1, + ).contiguous() + logical_replica_count = logical_replica_count.unsqueeze(0).expand( + model.num_moe_layers, + -1, + ).contiguous() + + expert_load_pass = torch.zeros( + (model.num_moe_layers, model.num_local_physical_experts), + device=device, + ) + expert_load_window = torch.zeros( + (parallel_config.eplb_window_size, model.num_moe_layers, + model.num_local_physical_experts), + device=device, + ) + + # Set the initial progress of rearrangement to 3/4 + eplb_step_interval = parallel_config.eplb_step_interval + expert_rearrangement_step = max( + 0, eplb_step_interval - eplb_step_interval // 4) + + return EplbState( + physical_to_logical_map, + logical_to_physical_map, + logical_replica_count, + expert_load_pass, + expert_load_window, + expert_rearrangement_step=expert_rearrangement_step, + ) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index fffbb7a66e92..00bf39cc6c13 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -287,7 +287,10 @@ class EngineArgs: data_parallel_address: Optional[str] = None data_parallel_rpc_port: Optional[int] = None enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel + enable_eplb: bool = ParallelConfig.enable_eplb num_extra_experts: int = ParallelConfig.num_extra_experts + eplb_window_size: int = ParallelConfig.eplb_window_size + eplb_step_interval: int = ParallelConfig.eplb_step_interval max_parallel_loading_workers: Optional[ int] = ParallelConfig.max_parallel_loading_workers block_size: Optional[BlockSize] = CacheConfig.block_size @@ -618,8 +621,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parallel_group.add_argument( "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]) + parallel_group.add_argument("--enable-eplb", + **parallel_kwargs["enable_eplb"]) parallel_group.add_argument("--num-extra-experts", **parallel_kwargs["num_extra_experts"]) + parallel_group.add_argument("--eplb-window-size", + **parallel_kwargs["eplb_window_size"]) + parallel_group.add_argument("--eplb-step-interval", + **parallel_kwargs["eplb_step_interval"]) parallel_group.add_argument( "--max-parallel-loading-workers", **parallel_kwargs["max_parallel_loading_workers"]) @@ -1065,6 +1074,7 @@ def create_engine_config( data_parallel_master_ip=data_parallel_address, data_parallel_rpc_port=data_parallel_rpc_port, enable_expert_parallel=self.enable_expert_parallel, + enable_eplb=self.enable_eplb, num_extra_experts=self.num_extra_experts, max_parallel_loading_workers=self.max_parallel_loading_workers, 
disable_custom_all_reduce=self.disable_custom_all_reduce, diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index c9d381edc422..4c7279f8f275 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -106,8 +106,8 @@ def __init__( self.ep_group = get_ep_group() self.ep_rank = self.ep_group.rank self.ep_size = self.ep_group.world_size - self.n_routed_experts = config.n_routed_experts - self.n_shared_experts = config.n_shared_experts + self.n_routed_experts: int = config.n_routed_experts + self.n_shared_experts: int = config.n_shared_experts self.top_k = config.num_experts_per_tok if self.tp_size > self.n_routed_experts: raise ValueError( @@ -115,12 +115,12 @@ def __init__( f"the number of experts {self.n_routed_experts}.") # Load balancing settings. - # Currently, `n_redundancy_expers` equals to `n_extra_experts`. + # Currently, `n_redundant_experts` equals to `n_extra_experts`. vllm_config = get_current_vllm_config() self.n_extra_experts = vllm_config.parallel_config.num_extra_experts self.n_physical_experts = self.n_routed_experts + self.n_extra_experts self.n_logical_experts = self.n_routed_experts - self.n_redundancy_expers = (self.n_physical_experts - + self.n_redundant_experts = (self.n_physical_experts - self.n_logical_experts) self.n_local_physical_experts = self.n_physical_experts // self.ep_size @@ -490,6 +490,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model.make_empty_intermediate_tensors) self.expert_weights = [] + # Set MoE hyperparameters + # TODO(bowen): Add support for MTP layers + self.num_moe_layers = (config.num_hidden_layers - + config.first_k_dense_replace) + + example_moe = typing.cast( + DeepseekMoE, self.model.layers[config.num_hidden_layers - 1].mlp) + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 42af98809147..d62bbe8fc0cc 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from collections.abc import Iterable, Sequence +from collections.abc import Iterable from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, MutableSequence, Optional, Protocol, Type, Union, overload, runtime_checkable) @@ -431,12 +431,6 @@ class IsMixtureOfExperts(Protocol): Check if the model is a mixture of experts (MoE) model. """ - is_mixture_of_experts: ClassVar[Literal[True]] = True - """ - A flag that indicates this model is a mixture of experts (MoE) model. - Used for expert parallel load balancing (EPLB) now. - """ - expert_weights: MutableSequence[Iterable[Tensor]] """ Expert weights saved in this rank. @@ -445,30 +439,29 @@ class IsMixtureOfExperts(Protocol): parameters in the layer, e.g. up/down projection weights. 
""" + num_moe_layers: int + """Number of MoE layers in this model.""" -@runtime_checkable -class _IsMixtureOfExpertsType(Protocol): - is_mixture_of_experts: ClassVar[Literal[True]] - expert_weights: Sequence[Iterable[Tensor]] + num_logical_experts: int + """Number of logical experts in this model.""" + num_physical_experts: int + """Number of physical experts in this model.""" -@overload -def is_mixture_of_experts(model: object) -> TypeIs[IsMixtureOfExperts]: - ... + num_local_physical_experts: int + """Number of local physical experts in this model.""" + num_routed_experts: int + """Number of routed experts in this model.""" -@overload -def is_mixture_of_experts( - model: Type[object]) -> TypeIs[Type[IsMixtureOfExperts]]: - ... + num_shared_experts: int + """Number of shared experts in this model.""" + num_redundant_experts: int + """Number of redundant experts in this model.""" -def is_mixture_of_experts( - model: Union[Type[object], object] -) -> Union[TypeIs[Type[IsMixtureOfExperts]], TypeIs[IsMixtureOfExperts]]: - if isinstance(model, type): - return isinstance(model, _IsMixtureOfExpertsType) +def is_mixture_of_experts(model: object) -> TypeIs[IsMixtureOfExperts]: return isinstance(model, IsMixtureOfExperts) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1b16f273a6de..9c8c166b6ddb 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -16,6 +16,7 @@ from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.config import (CompilationLevel, VllmConfig, get_layers_from_vllm_config) +from vllm.distributed.eplb.states import EplbState from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 @@ -25,6 +26,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import get_model +from vllm.model_executor.models.interfaces import is_mixture_of_experts from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.utils import group_mm_inputs_by_modality @@ -67,6 +69,17 @@ class GPUModelRunner(LoRAModelRunnerMixin): + enable_eplb: bool = False + """ + Whether the expert parallelism load balancer is enabled. + """ + eplb_state: EplbState + """ + State of the expert parallelism load balancer. + + Will be lazily initialized when the model is loaded. 
+ """ + def __init__( self, vllm_config: VllmConfig, @@ -1460,6 +1473,17 @@ def load_model(self) -> None: time_after_load - time_before_load) prepare_communication_buffer_for_model(self.model) + if is_mixture_of_experts( + self.model) and self.parallel_config.enable_eplb: + self.enable_eplb = True + logger.info("EPLB is enabled for model %s.", + self.model_config.model) + self.eplb_state = EplbState.build( + self.model, + self.device, + self.parallel_config, + ) + def _get_prompt_logprobs_dict( self, hidden_states: torch.Tensor, From 52b141f22e1e02b011bc15bac16b90ccf145932a Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Fri, 16 May 2025 16:39:59 -0700 Subject: [PATCH 05/57] [Feature] EPLB rearrangement execution Signed-off-by: Bowen Wang --- vllm/distributed/eplb/rebalance_execute.py | 280 +++++++++++++++++++++ 1 file changed, 280 insertions(+) create mode 100644 vllm/distributed/eplb/rebalance_execute.py diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py new file mode 100644 index 000000000000..57fbf2b287f2 --- /dev/null +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -0,0 +1,280 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +The actual execution of the rearrangement. + +This involves the exchange of expert weights between GPUs. +""" + +from collections.abc import Iterable, Sequence +from functools import partial +from typing import Dict, List, MutableSequence, Tuple + +import torch +from torch.distributed import (P2POp, ProcessGroup, batch_isend_irecv, + get_global_rank) + + +def idx_local_to_global( + local_idx: int, + local_cnt: int, + ep_rank: int, +) -> int: + """ + Convert a local expert index to a global expert index. + """ + return ep_rank * local_cnt + local_idx + + +def idx_global_to_local( + global_idx: int, + local_cnt: int, + ep_rank: int, +) -> int: + """ + Convert a global expert index to a local expert index. + """ + return global_idx - ep_rank * local_cnt + + +def global_idx_to_rank( + global_idx: int, + local_cnt: int, +) -> int: + """ + Convert a global expert index to a rank index. + """ + return global_idx // local_cnt + + +def get_ep_ranks_with_expert( + idx: int, + num_local_experts: int, + old_indices: Sequence[int], + new_indices: Sequence[int], +) -> Tuple[MutableSequence[int], MutableSequence[int]]: + """ + Get the ranks of the experts that need to be exchanged. + + Args: + idx: The index of the expert. + num_local_experts: The number of local experts. + old_indices: The old indices of the experts. + new_indices: The new indices of the experts. + + Returns: + A tuple of two lists: + - The ranks of the experts that need to be sent. + - The ranks of the experts that need to be received. + """ + global2rank = partial( + global_idx_to_rank, + local_cnt=num_local_experts, + ) + + ranks_to_send: List[int] = [] + ranks_to_recv: List[int] = [] + + for i, e in enumerate(old_indices): + if e == idx: + rank = global2rank(i) + if not ranks_to_send or ranks_to_send[-1] != rank: + ranks_to_send.append(rank) + + for i, e in enumerate(new_indices): + if e == idx: + rank = global2rank(i) + if not ranks_to_recv or ranks_to_recv[-1] != rank: + ranks_to_recv.append(rank) + + # Remove those ranks that can get this expert locally. 
+ ranks_to_send_set = set(ranks_to_send) + ranks_to_recv_actual = [ + rank for rank in ranks_to_recv if rank not in ranks_to_send_set + ] + + return ranks_to_send, ranks_to_recv_actual + + +def shuffle_layer( + num_local_experts: int, + ep_rank: int, + old_indices: Sequence[int], + new_indices: Sequence[int], + expert_weights: Iterable[torch.Tensor], + expert_weights_buffer: Sequence[torch.Tensor], + ep_group: ProcessGroup, +) -> None: + """ + Perform expert weights rearrangement of one layer. + """ + local2global = partial( + idx_local_to_global, + local_cnt=num_local_experts, + ep_rank=ep_rank, + ) + + # 0. Do nothing for experts that did not change. + is_unchanged = [ + old_indices[local2global(i)] == new_indices[local2global(i)] + for i in range(num_local_experts) + ] + + # 1. Perform weight copy inside the local rank. + is_received_locally = is_unchanged[:] + for src in range(num_local_experts): + src_global = local2global(src) + for dst in range(num_local_experts): + dst_global = local2global(dst) + if is_received_locally[dst]: + continue + if old_indices[src_global] == new_indices[dst_global]: + is_received_locally[dst] = True + for weight, buffer in zip(expert_weights, + expert_weights_buffer): + buffer[dst].copy_(weight[src]) + + p2p_ops: List[P2POp] = [] + + # 2. Initiate sending of weights. + experts_send_loc: Dict[int, int] = {} + for src in range(num_local_experts): + expert = old_indices[local2global(src)] + if expert in experts_send_loc: + continue + experts_send_loc[expert] = src + + # We need to sort here to match send/recv + for expert, src in sorted(experts_send_loc.items()): + ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert( + expert, + num_local_experts, + old_indices, + new_indices, + ) + + # Calculate the ranks to send by this rank + num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send) + sender_pos = ranks_to_send.index(ep_rank) + recv_begin = sender_pos * num_dst_per_sender + recv_end = recv_begin + num_dst_per_sender + recv_ranks = ranks_to_recv[recv_begin:recv_end] + + # Tackle remainders + remainder_start = len(ranks_to_send) * num_dst_per_sender + recver_pos = remainder_start + sender_pos + if recver_pos < len(ranks_to_recv): + recv_ranks.append(ranks_to_recv[recver_pos]) + + for dst in recv_ranks: + dst_global = get_global_rank(ep_group, dst) + p2p_ops += [ + P2POp(torch.distributed.isend, weight[dst], dst_global) + for weight in expert_weights + ] + + # 3. Initiate receiving of weights. 
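+    # This mirrors step 2: both sides iterate experts in sorted order and
+    # derive the same sender/receiver assignment, so every irecv posted here
+    # is matched by exactly one isend on the corresponding source rank.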
+ experts_recv_loc: Dict[int, int] = {} + for dst in range(num_local_experts): + if is_received_locally[dst]: + continue + expert = new_indices[local2global(dst)] + if expert in experts_recv_loc: + continue + experts_recv_loc[expert] = dst + + # We need to sort here to match send/recv + for expert, dst in sorted(experts_recv_loc.items()): + ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert( + expert, + num_local_experts, + old_indices, + new_indices, + ) + + # Calculate the rank to recv by this rank + num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send) + recver_pos = ranks_to_recv.index(ep_rank) + remainder_start = len(ranks_to_send) * num_dst_per_sender + if recver_pos < remainder_start: + src = ranks_to_send[recver_pos // num_dst_per_sender] + else: + src = ranks_to_send[recver_pos - remainder_start] + + src_global = get_global_rank(ep_group, src) + p2p_ops += [ + P2POp( + torch.distributed.irecv, + weight[dst], + src_global, + ) for weight in expert_weights_buffer + ] + + if p2p_ops: + reqs = batch_isend_irecv(p2p_ops) + for req in reqs: + req.wait() + + # 4. Copy the weights from the buffer back to the original weights. + for dst in range(num_local_experts): + if is_unchanged[dst]: + continue + if is_received_locally[dst]: + for weight, buffer in zip(expert_weights, expert_weights_buffer): + weight[dst].copy_(buffer[dst]) + else: + expert = new_indices[local2global(dst)] + src = experts_recv_loc[expert] + for weight, buffer in zip(expert_weights, expert_weights_buffer): + weight[dst].copy_(buffer[src]) + + +def rearrange_expert_weights_inplace( + old_global_expert_indices: torch.Tensor, + new_global_expert_indices: torch.Tensor, + expert_weights: Sequence[Iterable[torch.Tensor]], + ep_group: ProcessGroup, +) -> None: + """ + Rearranges the expert weights in place according to the new expert indices. + + The value of the indices arguments are logical indices of the experts, + while keys are physical. + + Args: + old_global_expert_indices: Shape (num_moe_layers, num_physical_experts). + new_global_expert_indices: Shape (num_moe_layers, num_physical_experts). + expert_weights: A sequence of shape (num_moe_layers)(weight_count) + of tensors of shape (num_local_physical_experts, hidden_size_i). + For example, a linear layer may have up and down projection, + so weight_count = 2. Each weight's hidden size can be different. + ep_group: The device process group for expert parallelism. + """ + num_moe_layers, num_physical_experts = old_global_expert_indices.shape + assert len(expert_weights) == num_moe_layers + + num_local_physical_experts = next(iter(expert_weights[0])).shape[0] + assert new_global_expert_indices.shape == (num_moe_layers, + num_local_physical_experts) + + ep_rank = ep_group.rank() + ep_size = ep_group.size() + assert num_physical_experts == ep_size * num_local_physical_experts + + # A buffer to hold the expert weights in one layer during the exchange. + # NOTE: Currently we assume the same weights across different layers + # have the same shape. 
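+    # The buffer is allocated once and reused for every layer, so the extra
+    # memory cost is bounded by a single layer's expert weights.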
+ expert_weights_buffer = [torch.empty_like(w) for w in expert_weights[0]] + + for layer in range(num_moe_layers): + shuffle_layer( + num_local_physical_experts, + ep_rank, + old_global_expert_indices[layer].tolist(), + new_global_expert_indices[layer].tolist(), + expert_weights[layer], + expert_weights_buffer, + ep_group, + ) + + +__all__ = ["rearrange_expert_weights_inplace"] From 98312d38da55e7cfdc1501c6366c005ea9c94a8f Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Mon, 19 May 2025 00:14:42 -0700 Subject: [PATCH 06/57] [Feature] Add expert load metrics collection during forward WIP, design choices not finalized. Signed-off-by: Bowen Wang --- vllm/forward_context.py | 9 +++- .../layers/fused_moe/fused_moe.py | 53 +++++++++++-------- vllm/model_executor/models/deepseek.py | 41 +++++++++++--- vllm/v1/worker/gpu_model_runner.py | 8 ++- 4 files changed, 77 insertions(+), 34 deletions(-) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index eb1e1f5694bb..6609b0178e6d 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -45,6 +45,9 @@ class ForwardContext: virtual_engine: int # set dynamically for each forward pass # set dynamically for each forward pass dp_metadata: Optional[DPMetadata] = None + # Used for EPLB to collect expert load metrics. + # TODO(bowen): see if we can find a better accommodation for this + expert_load_pass: Optional[torch.Tensor] = None _forward_context: Optional[ForwardContext] = None @@ -62,7 +65,8 @@ def get_forward_context() -> ForwardContext: def set_forward_context(attn_metadata: Any, vllm_config: VllmConfig, virtual_engine: int = 0, - num_tokens: int = 0): + num_tokens: int = 0, + expert_load_pass: Optional[torch.Tensor] = None): """A context manager that stores the current forward context, can be attention metadata, etc. Here we can inject common logic for every model forward pass. @@ -100,7 +104,8 @@ def set_forward_context(attn_metadata: Any, static_forward_context, virtual_engine=virtual_engine, attn_metadata=attn_metadata, - dp_metadata=dp_metadata) + dp_metadata=dp_metadata, + expert_load_pass=expert_load_pass) try: yield diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 8c28cedbcd77..25013a0591e0 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -3,7 +3,7 @@ import functools import json import os -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch @@ -1453,7 +1453,8 @@ def fused_moe( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[list[int]] = None, -) -> torch.Tensor: + return_topk_ids: bool = False, +) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: """ This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. @@ -1499,6 +1500,7 @@ def fused_moe( a2. - block_shape: (Optional[list[int]]): Optional block size for block-wise quantization. + - return_topk_ids (bool): If True, return the top-k expert IDs Returns: - torch.Tensor: The output tensor after applying the MoE layer. 
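+    - (torch.Tensor, torch.Tensor): When `return_topk_ids` is True, the
+      output tensor together with the selected top-k expert IDs.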
@@ -1516,24 +1518,29 @@ def fused_moe( topk_weights, topk_ids = custom_routing_function( hidden_states, gating_output, topk, renormalize) - return fused_experts(hidden_states, - w1, - w2, - topk_weights, - topk_ids, - inplace=inplace, - activation=activation, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - per_channel_quant=per_channel_quant, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=block_shape) + result = fused_experts(hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace=inplace, + activation=activation, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_shape) + + if return_topk_ids: + return result, topk_ids + else: + return result diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 4c7279f8f275..3bb153617707 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -35,6 +35,7 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.distributed.parallel_state import get_ep_group +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm @@ -95,11 +96,13 @@ class DeepseekMoE(nn.Module): def __init__( self, + layer_idx: int, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): super().__init__() + self.layer_idx = layer_idx self.config = config self.rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() @@ -124,12 +127,18 @@ def __init__( self.n_logical_experts) self.n_local_physical_experts = self.n_physical_experts // self.ep_size + self.physical_expert_start = (self.ep_rank * + self.n_local_physical_experts) + self.physical_expert_end = (self.physical_expert_start + + self.n_local_physical_experts) + self.experts = nn.ModuleList([ DeepseekMLP(hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, reduce_results=False) + # TODO(bowen): Work together with EP for idx in range(self.n_routed_experts) ]) self.pack_params() @@ -194,13 +203,28 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = fused_moe(hidden_states, - self.w1, - self.w2, - router_logits, - self.top_k, - renormalize=self.config.norm_topk_prob, - inplace=True) + final_hidden_states, topk_ids = fused_moe( + hidden_states, + self.w1, + self.w2, + router_logits, + self.top_k, + renormalize=self.config.norm_topk_prob, + inplace=True, + return_topk_ids=True) + + # Collect expert load statistics + forward_context = get_forward_context() + if forward_context.expert_load_pass is not None: + expert_load_pass = forward_context.expert_load_pass + local_mask = (topk_ids 
>= self.physical_expert_start) & ( + topk_ids < self.physical_expert_end) + if local_mask.any(): + local_indices = topk_ids[ + local_mask] - self.physical_expert_start + counts = torch.bincount( + local_indices, minlength=self.n_local_physical_experts) + expert_load_pass[self.layer_idx] += counts if self.config.n_shared_experts is not None: final_hidden_states = final_hidden_states + shared_output @@ -321,7 +345,8 @@ def __init__( if (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0): - self.mlp = DeepseekMoE(config=config, + self.mlp = DeepseekMoE(layer_idx=layer_idx, + config=config, quant_config=quant_config, prefix=f"{prefix}.mlp") else: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9c8c166b6ddb..a1665b27fac8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1645,9 +1645,15 @@ def _dummy_run( for k, v in self.intermediate_tensors.items() }) + if self.parallel_config.enable_eplb: + expert_load_pass = self.eplb_state.expert_load_pass + else: + expert_load_pass = None + with set_forward_context(attn_metadata, self.vllm_config, - num_tokens=num_tokens): + num_tokens=num_tokens, + expert_load_pass=expert_load_pass): outputs = model( input_ids=input_ids, positions=positions, From 22a963d36cc83fa8986311c780a8b94b6968a322 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Mon, 19 May 2025 01:18:05 -0700 Subject: [PATCH 07/57] [Feature] Rearrange experts after a preset step interval Signed-off-by: Bowen Wang --- vllm/distributed/eplb/states.py | 111 ++++++++++++++++++++++- vllm/model_executor/models/deepseek.py | 1 + vllm/model_executor/models/interfaces.py | 3 + vllm/v1/worker/gpu_model_runner.py | 6 ++ 4 files changed, 120 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/eplb/states.py b/vllm/distributed/eplb/states.py index a0564aaafd63..9650797b2b2f 100644 --- a/vllm/distributed/eplb/states.py +++ b/vllm/distributed/eplb/states.py @@ -3,13 +3,22 @@ Expert parallelism load balancer (EPLB) metrics and states. """ +import time from dataclasses import dataclass import torch +from torch.distributed import all_reduce from vllm.config import ParallelConfig +from vllm.distributed.parallel_state import get_ep_group +from vllm.logger import init_logger from vllm.model_executor.models.interfaces import IsMixtureOfExperts +from .rebalance import rebalance_experts +from .rebalance_execute import rearrange_expert_weights_inplace + +logger = init_logger(__name__) + @dataclass class EplbState: @@ -51,12 +60,16 @@ class EplbState: """ expert_load_window_step: int = 0 """Current step in the sliding window.""" + expert_load_window_size: int = 0 + """Size of the expert load sliding window.""" expert_rearrangement_step: int = 0 """ Steps after last rearrangement. Will trigger a rearrangement if it exceeds the threshold. 
""" + expert_rearrangement_step_interval: int = 0 + """Interval for expert rearrangement steps.""" @staticmethod def build_initial_global_physical_to_logical_map( @@ -127,8 +140,9 @@ def build( (model.num_moe_layers, model.num_local_physical_experts), device=device, ) + expert_load_window_size = parallel_config.eplb_window_size expert_load_window = torch.zeros( - (parallel_config.eplb_window_size, model.num_moe_layers, + (expert_load_window_size, model.num_moe_layers, model.num_local_physical_experts), device=device, ) @@ -144,5 +158,100 @@ def build( logical_replica_count, expert_load_pass, expert_load_window, + expert_load_window_size=expert_load_window_size, expert_rearrangement_step=expert_rearrangement_step, + expert_rearrangement_step_interval=eplb_step_interval, + ) + + def step(self, model: IsMixtureOfExperts) -> None: + """ + Step the EPLB state. + """ + + # Update the expert load sliding window + self.expert_load_window[self.expert_load_window_step] = ( + self.expert_load_pass.clone()) + self.expert_load_window_step += 1 + if self.expert_load_window_step >= self.expert_load_window_size: + self.expert_load_window_step = 0 + self.expert_load_pass.zero_() + + # Step the expert rearrangement step + self.expert_rearrangement_step += 1 + if (self.expert_rearrangement_step + >= self.expert_rearrangement_step_interval): + self.expert_rearrangement_step = 0 + self.rearrange(model) + + def rearrange(self, model: IsMixtureOfExperts) -> None: + """ + Rearrange the experts according to the current load. + """ + + ep_group = get_ep_group() + ep_rank = ep_group.rank() + is_main_rank = ep_rank == 0 + if is_main_rank: + torch.cuda.synchronize() + time_start = time.perf_counter() + logger.info("Rearranging experts...") + + window_size, num_moe_layers, num_local_physical_experts = ( + self.expert_load_window.shape) + num_physical_experts = model.num_physical_experts + + local_expert_start = ep_rank * num_local_physical_experts + local_expert_end = local_expert_start + num_local_physical_experts + local_physical_to_logical_map = self.physical_to_logical_map[:, + local_expert_start: + local_expert_end] + device = local_physical_to_logical_map.device + + # Perform all-reduce to get the expert load across all ranks + expert_load_window = self.expert_load_window + global_expert_load_window = torch.zeros(window_size, + num_moe_layers, + num_physical_experts, + device=device) + global_expert_load_window.scatter_add_( + -1, + local_physical_to_logical_map.expand_as(expert_load_window), + expert_load_window, ) + all_reduce(global_expert_load_window, group=ep_group) + + # TODO(bowen): Treat differently for prefill and decode nodes + num_replicas = num_physical_experts + num_groups = model.num_expert_groups + # TODO(bowen): Remove magic numbers + num_nodes = (ep_group.size() + 7) // 8 + num_gpus = ep_group.size() + + # Get new expert mappings + ( + new_physical_to_logical_map, + new_logical_to_physical_map, + new_logical_replica_count, + ) = (rebalance_experts( + global_expert_load_window, + num_replicas, + num_groups, + num_nodes, + num_gpus, + )) + + # Update expert weights + rearrange_expert_weights_inplace( + self.physical_to_logical_map, + new_physical_to_logical_map, + model.expert_weights, + ep_group, + ) + + if is_main_rank: + torch.cuda.synchronize() + time_end = time.perf_counter() + logger.info( + "Rearranged experts in %.2f seconds.", + time_end - time_start, + ) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 3bb153617707..98f76a5af58b 100644 --- 
a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -519,6 +519,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # TODO(bowen): Add support for MTP layers self.num_moe_layers = (config.num_hidden_layers - config.first_k_dense_replace) + self.num_expert_groups = config.n_groups example_moe = typing.cast( DeepseekMoE, self.model.layers[config.num_hidden_layers - 1].mlp) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index d62bbe8fc0cc..0022a6b7ab30 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -442,6 +442,9 @@ class IsMixtureOfExperts(Protocol): num_moe_layers: int """Number of MoE layers in this model.""" + num_expert_groups: int + """Number of expert groups in this model.""" + num_logical_experts: int """Number of logical experts in this model.""" diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a1665b27fac8..fe605512f679 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1354,6 +1354,12 @@ def execute_model( if has_kv_transfer_group(): get_kv_transfer_group().clear_connector_metadata() + # EPLB step + if self.parallel_config.enable_eplb: + self.eplb_state.step(self.model) + + # TODO(bowen): Log balancedness + return ModelRunnerOutput( req_ids=self.input_batch.req_ids, req_id_to_index=self.input_batch.req_id_to_index, From 43ac672339808a96b4b85787d1c674017cc9d93c Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Tue, 20 May 2025 18:02:19 +0000 Subject: [PATCH 08/57] [Feature] Use unified `FusedMoE` in DeepSeek-V3/R1 Signed-off-by: Bowen Wang --- vllm/distributed/eplb/states.py | 10 +-- vllm/model_executor/models/deepseek.py | 88 +++++++++--------------- vllm/model_executor/models/interfaces.py | 6 +- 3 files changed, 39 insertions(+), 65 deletions(-) diff --git a/vllm/distributed/eplb/states.py b/vllm/distributed/eplb/states.py index 9650797b2b2f..8e455d81a730 100644 --- a/vllm/distributed/eplb/states.py +++ b/vllm/distributed/eplb/states.py @@ -12,9 +12,9 @@ from vllm.config import ParallelConfig from vllm.distributed.parallel_state import get_ep_group from vllm.logger import init_logger -from vllm.model_executor.models.interfaces import IsMixtureOfExperts +from vllm.model_executor.models.interfaces import MixtureOfExperts -from .rebalance import rebalance_experts +from .rebalance_algo import rebalance_experts from .rebalance_execute import rearrange_expert_weights_inplace logger = init_logger(__name__) @@ -89,7 +89,7 @@ def build_initial_global_physical_to_logical_map( @classmethod def build( cls, - model: IsMixtureOfExperts, + model: MixtureOfExperts, device: torch.device, parallel_config: ParallelConfig, ) -> "EplbState": @@ -163,7 +163,7 @@ def build( expert_rearrangement_step_interval=eplb_step_interval, ) - def step(self, model: IsMixtureOfExperts) -> None: + def step(self, model: MixtureOfExperts) -> None: """ Step the EPLB state. """ @@ -183,7 +183,7 @@ def step(self, model: IsMixtureOfExperts) -> None: self.expert_rearrangement_step = 0 self.rearrange(model) - def rearrange(self, model: IsMixtureOfExperts) -> None: + def rearrange(self, model: MixtureOfExperts) -> None: """ Rearrange the experts according to the current load. 
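+
+        The per-rank load window is aggregated across the EP group via
+        all-reduce, a new placement is computed with `rebalance_experts`,
+        and the expert weights are exchanged in place afterwards.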
""" diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index ddddb60089e8..964588ad91e6 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -36,9 +36,8 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.distributed.parallel_state import get_ep_group -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, @@ -53,7 +52,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import IsMixtureOfExperts, SupportsPP +from .interfaces import MixtureOfExperts, SupportsPP from .utils import (AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -133,22 +132,33 @@ def __init__( self.physical_expert_end = (self.physical_expert_start + self.n_local_physical_experts) - self.experts = nn.ModuleList([ - DeepseekMLP(hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - reduce_results=False) - # TODO(bowen): Work together with EP - for idx in range(self.n_routed_experts) - ]) - self.pack_params() - self.gate = ReplicatedLinear(config.hidden_size, self.n_routed_experts, bias=False, quant_config=None) + if config.topk_method == "noaux_tc": + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts)) + else: + self.gate.e_score_correction_bias = typing.cast(Any, None) + + self.experts = FusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_groups, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias, + ) + if config.n_shared_experts is not None: intermediate_size = (config.moe_intermediate_size * config.n_shared_experts) @@ -157,7 +167,9 @@ def __init__( intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=False, + reduce_results=self.experts.must_reduce_shared_expert_outputs( + ), + prefix=f"{prefix}.shared_experts", ) def get_weights(self) -> list[torch.Tensor]: @@ -178,25 +190,6 @@ def get_weights(self) -> list[torch.Tensor]: ret.append(weight[1].view(self.n_local_physical_experts, -1)) return ret - def pack_params(self): - w1 = [] - w2 = [] - for expert in self.experts: - w1.append(expert.gate_up_proj.weight) - w2.append(expert.down_proj.weight) - self.w1 = torch._utils._flatten_dense_tensors(w1) - w1s = torch._utils._unflatten_dense_tensors(self.w1, w1) - for data, param in zip(w1s, w1): - param.data = data - self.w1 = self.w1.view(len(w1), *w1s[0].shape) - - self.w2 = torch._utils._flatten_dense_tensors(w2) - w2s = torch._utils._unflatten_dense_tensors(self.w2, w2) - for data, param in zip(w2s, w2): - param.data = data - - self.w2 = self.w2.view(len(w2), *w2s[0].shape) - def 
forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) @@ -204,28 +197,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states, topk_ids = fused_moe( - hidden_states, - self.w1, - self.w2, - router_logits, - self.top_k, - renormalize=self.config.norm_topk_prob, - inplace=True, - return_topk_ids=True) - - # Collect expert load statistics - forward_context = get_forward_context() - if forward_context.expert_load_pass is not None: - expert_load_pass = forward_context.expert_load_pass - local_mask = (topk_ids >= self.physical_expert_start) & ( - topk_ids < self.physical_expert_end) - if local_mask.any(): - local_indices = topk_ids[ - local_mask] - self.physical_expert_start - counts = torch.bincount( - local_indices, minlength=self.n_local_physical_experts) - expert_load_pass[self.layer_idx] += counts + final_hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits) * self.routed_scaling_factor if self.config.n_shared_experts is not None: final_hidden_states = final_hidden_states + shared_output @@ -496,7 +470,7 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params -class DeepseekForCausalLM(nn.Module, SupportsPP, IsMixtureOfExperts): +class DeepseekForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 0f9f22b38a94..afa813ee89cc 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -425,7 +425,7 @@ def is_hybrid( @runtime_checkable -class IsMixtureOfExperts(Protocol): +class MixtureOfExperts(Protocol): """ Check if the model is a mixture of experts (MoE) model. """ @@ -463,8 +463,8 @@ class IsMixtureOfExperts(Protocol): """Number of redundant experts in this model.""" -def is_mixture_of_experts(model: object) -> TypeIs[IsMixtureOfExperts]: - return isinstance(model, IsMixtureOfExperts) +def is_mixture_of_experts(model: object) -> TypeIs[MixtureOfExperts]: + return isinstance(model, MixtureOfExperts) @runtime_checkable From f7ba1624b58d62e18559bf29858b3d1dcb57a04d Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Tue, 20 May 2025 18:11:58 +0000 Subject: [PATCH 09/57] [Bugfix] Copy expert mappings after rearrangement Signed-off-by: Bowen Wang --- vllm/distributed/eplb/states.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/eplb/states.py b/vllm/distributed/eplb/states.py index 8e455d81a730..11b63815d205 100644 --- a/vllm/distributed/eplb/states.py +++ b/vllm/distributed/eplb/states.py @@ -188,8 +188,10 @@ def rearrange(self, model: MixtureOfExperts) -> None: Rearrange the experts according to the current load. 
""" - ep_group = get_ep_group() + ep_group = get_ep_group().device_group ep_rank = ep_group.rank() + + time_start = None is_main_rank = ep_rank == 0 if is_main_rank: torch.cuda.synchronize() @@ -248,7 +250,12 @@ def rearrange(self, model: MixtureOfExperts) -> None: ep_group, ) + self.physical_to_logical_map.copy_(new_physical_to_logical_map) + self.logical_to_physical_map.copy_(new_logical_to_physical_map) + self.logical_replica_count.copy_(new_logical_replica_count) + if is_main_rank: + assert time_start is not None torch.cuda.synchronize() time_end = time.perf_counter() logger.info( From ba3d60fe4de679b9ebe31d9fa03574b23d6197ed Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Fri, 23 May 2025 20:19:09 +0000 Subject: [PATCH 10/57] [Chore] Move implementations to `deepseek_v2.py` Signed-off-by: Bowen Wang --- vllm/model_executor/models/deepseek.py | 143 +++++++--------------- vllm/model_executor/models/deepseek_v2.py | 81 +++++++++++- 2 files changed, 119 insertions(+), 105 deletions(-) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 964588ad91e6..88d1ca9f7b83 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -22,7 +22,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Deepseek model.""" -import typing from collections.abc import Iterable from typing import Any, Optional, Union @@ -31,13 +30,12 @@ from transformers import PretrainedConfig from vllm.attention import Attention -from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config +from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import get_ep_group from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe.layer import FusedMoE +from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, @@ -52,7 +50,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import MixtureOfExperts, SupportsPP +from .interfaces import SupportsPP from .utils import (AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -96,69 +94,36 @@ class DeepseekMoE(nn.Module): def __init__( self, - layer_idx: int, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): super().__init__() - self.layer_idx = layer_idx self.config = config self.rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() - self.ep_group = get_ep_group() - self.ep_rank = self.ep_group.rank - self.ep_size = self.ep_group.world_size - self.n_routed_experts: int = config.n_routed_experts - self.n_shared_experts: int = config.n_shared_experts + self.n_routed_experts = config.n_routed_experts self.top_k = config.num_experts_per_tok if self.tp_size > self.n_routed_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " f"the number of experts {self.n_routed_experts}.") - # Load balancing settings. - # Currently, `n_redundant_experts` equals to `n_extra_experts`. 
- vllm_config = get_current_vllm_config() - self.n_extra_experts = vllm_config.parallel_config.num_extra_experts - self.n_physical_experts = self.n_routed_experts + self.n_extra_experts - self.n_logical_experts = self.n_routed_experts - self.n_redundant_experts = (self.n_physical_experts - - self.n_logical_experts) - self.n_local_physical_experts = self.n_physical_experts // self.ep_size - - self.physical_expert_start = (self.ep_rank * - self.n_local_physical_experts) - self.physical_expert_end = (self.physical_expert_start + - self.n_local_physical_experts) + self.experts = nn.ModuleList([ + DeepseekMLP(hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False) + for idx in range(self.n_routed_experts) + ]) + self.pack_params() self.gate = ReplicatedLinear(config.hidden_size, self.n_routed_experts, bias=False, quant_config=None) - if config.topk_method == "noaux_tc": - self.gate.e_score_correction_bias = nn.Parameter( - torch.empty(config.n_routed_experts)) - else: - self.gate.e_score_correction_bias = typing.cast(Any, None) - - self.experts = FusedMoE( - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_groups, - topk_group=config.topk_group, - prefix=f"{prefix}.experts", - scoring_func=config.scoring_func, - e_score_correction_bias=self.gate.e_score_correction_bias, - ) - if config.n_shared_experts is not None: intermediate_size = (config.moe_intermediate_size * config.n_shared_experts) @@ -167,28 +132,27 @@ def __init__( intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=self.experts.must_reduce_shared_expert_outputs( - ), - prefix=f"{prefix}.shared_experts", + reduce_results=False, ) - def get_weights(self) -> list[torch.Tensor]: - ret: list[torch.Tensor] = [] - for weight in [self.gate_proj_weight, self.down_proj_weight]: - weight = typing.cast( - Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], weight) - if isinstance(weight, torch.Tensor): - assert weight.is_contiguous() - ret.append(weight.view(self.n_local_physical_experts, -1)) - else: - # FP8 weights - assert weight[0].element_size() == 1 - assert weight[0].is_contiguous() - assert weight[1].is_contiguous() - ret.append(weight[0].view(torch.int8).view( - self.n_local_physical_experts, -1)) - ret.append(weight[1].view(self.n_local_physical_experts, -1)) - return ret + def pack_params(self): + w1 = [] + w2 = [] + for expert in self.experts: + w1.append(expert.gate_up_proj.weight) + w2.append(expert.down_proj.weight) + self.w1 = torch._utils._flatten_dense_tensors(w1) + w1s = torch._utils._unflatten_dense_tensors(self.w1, w1) + for data, param in zip(w1s, w1): + param.data = data + self.w1 = self.w1.view(len(w1), *w1s[0].shape) + + self.w2 = torch._utils._flatten_dense_tensors(w2) + w2s = torch._utils._unflatten_dense_tensors(self.w2, w2) + for data, param in zip(w2s, w2): + param.data = data + + self.w2 = self.w2.view(len(w2), *w2s[0].shape) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape @@ -197,9 +161,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, 
n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits) * self.routed_scaling_factor + final_hidden_states = fused_moe(hidden_states, + self.w1, + self.w2, + router_logits, + self.top_k, + renormalize=self.config.norm_topk_prob, + inplace=True) if self.config.n_shared_experts is not None: final_hidden_states = final_hidden_states + shared_output @@ -320,8 +288,7 @@ def __init__( if (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0): - self.mlp = DeepseekMoE(layer_idx=layer_idx, - config=config, + self.mlp = DeepseekMoE(config=config, quant_config=quant_config, prefix=f"{prefix}.mlp") else: @@ -470,7 +437,7 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params -class DeepseekForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): +class DeepseekForCausalLM(nn.Module, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -488,22 +455,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - self.expert_weights = [] - - # Set MoE hyperparameters - # TODO(bowen): Add support for MTP layers - self.num_moe_layers = (config.num_hidden_layers - - config.first_k_dense_replace) - self.num_expert_groups = config.n_groups - - example_moe = typing.cast( - DeepseekMoE, self.model.layers[config.num_hidden_layers - 1].mlp) - self.num_logical_experts = example_moe.n_logical_experts - self.num_physical_experts = example_moe.n_physical_experts - self.num_local_physical_experts = example_moe.n_local_physical_experts - self.num_routed_experts = example_moe.n_routed_experts - self.num_shared_experts = example_moe.n_shared_experts - self.num_redundant_experts = example_moe.n_redundant_experts def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -531,14 +482,4 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - loaded_weights = loader.load_weights(weights) - - # Register the expert weights. - for layer in self.model.layers: - assert isinstance(layer, DeepseekDecoderLayer) - if isinstance(layer.mlp, DeepseekMoE): - self.expert_weights.append(layer.mlp.get_weights()) - - # TODO(bowen): Add support for MTP layers - - return loaded_weights + return loader.load_weights(weights) \ No newline at end of file diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index b78c193c1345..e67d486a1d92 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -22,6 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only DeepseekV2/DeepseekV3 model.""" +import typing from collections.abc import Iterable from typing import Any, Optional, Union @@ -31,7 +32,8 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.config import (CacheConfig, ModelConfig, VllmConfig, + get_current_vllm_config) from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE @@ -42,6 +44,7 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -50,7 +53,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP +from .interfaces import MixtureOfExperts, SupportsPP from .utils import (PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -95,14 +98,21 @@ class DeepseekV2MoE(nn.Module): def __init__( self, + layer_idx: int, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): super().__init__() + self.layer_idx = layer_idx self.tp_size = get_tensor_model_parallel_world_size() self.routed_scaling_factor = config.routed_scaling_factor - self.n_shared_experts = config.n_shared_experts + + self.ep_group = get_pp_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + self.n_routed_experts: int = config.n_routed_experts + self.n_shared_experts: int = config.n_shared_experts if config.hidden_act != "silu": raise ValueError(f"Unsupported activation: {config.hidden_act}. " @@ -119,6 +129,35 @@ def __init__( else: self.gate.e_score_correction_bias = None + # Load balancing settings. + # Currently, `n_redundant_experts` equals to `n_extra_experts`. + vllm_config = get_current_vllm_config() + parallel_config = vllm_config.parallel_config + self.enable_eplb = parallel_config.enable_eplb + + if self.enable_eplb and not isinstance(quant_config, Fp8Config): + # TODO(bowen): Add support for additional quantization methods. + # The implementation for other quantization methods does not + # contain essential differences, but the current quant API design + # causes duplicated work when extending to new + # quantization methods, so I'm leaving it for now. + # If you plan to add support for more quantization methods, + # please refer to the implementation in `Fp8MoEMethod`. 
+ raise NotImplementedError("EPLB is only supported for FP8 " + "quantization for now.") + + self.n_extra_experts = parallel_config.num_extra_experts + self.n_physical_experts = self.n_routed_experts + self.n_extra_experts + self.n_logical_experts = self.n_routed_experts + self.n_redundant_experts = (self.n_physical_experts - + self.n_logical_experts) + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = (self.ep_rank * + self.n_local_physical_experts) + self.physical_expert_end = (self.physical_expert_start + + self.n_local_physical_experts) + self.experts = FusedMoE( num_experts=config.n_routed_experts, top_k=config.num_experts_per_tok, @@ -180,6 +219,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return final_hidden_states.view(num_tokens, hidden_dim) + def get_expert_weights(self) -> Iterable[torch.Tensor]: + weights = self.experts.parameters() + assert all(weight.is_contiguous() for weight in weights) + return [ + weight.view(self.n_local_physical_experts, -1) + for weight in weights + ] + def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: import math @@ -539,6 +586,7 @@ def __init__( and layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0): self.mlp = DeepseekV2MoE( + layer_idx=self.layer_idx, config=config, quant_config=quant_config, prefix=f"{prefix}.mlp", @@ -680,7 +728,7 @@ def forward( return hidden_states -class DeepseekV2ForCausalLM(nn.Module, SupportsPP): +class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -699,6 +747,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + self.expert_weights = [] + + # Set MoE hyperparameters + # TODO(bowen): Add support for MTP layers + self.num_moe_layers = (config.num_hidden_layers - + config.first_k_dense_replace) + self.num_expert_groups = config.n_group + + example_moe = typing.cast( + DeepseekV2MoE, self.model.layers[config.num_hidden_layers - 1].mlp) + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -823,6 +887,15 @@ def load_weights(self, weights: Iterable[tuple[str, default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) + + # Register the expert weights. 
+ for layer in self.model.layers: + assert isinstance(layer, DeepseekV2DecoderLayer) + if isinstance(layer.mlp, DeepseekV2MoE): + self.expert_weights.append(layer.mlp.get_expert_weights()) + + # TODO(bowen): Add support for MTP layers + return loaded_params From ebcfcc769785cc1293fc449dafdfe3295ead4711 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Fri, 23 May 2025 20:20:23 +0000 Subject: [PATCH 11/57] [Chore] Remove expert load stats from forward context Moved into `FusedMoE` layers Signed-off-by: Bowen Wang --- vllm/forward_context.py | 9 ++------- vllm/v1/worker/gpu_model_runner.py | 8 +------- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index fa211a14aa69..5d2d95f18d2f 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -46,9 +46,6 @@ class ForwardContext: virtual_engine: int # set dynamically for each forward pass # set dynamically for each forward pass dp_metadata: Optional[DPMetadata] = None - # Used for EPLB to collect expert load metrics. - # TODO(bowen): see if we can find a better accommodation for this - expert_load_pass: Optional[torch.Tensor] = None _forward_context: Optional[ForwardContext] = None @@ -66,8 +63,7 @@ def get_forward_context() -> ForwardContext: def set_forward_context(attn_metadata: Any, vllm_config: VllmConfig, virtual_engine: int = 0, - num_tokens: int = 0, - expert_load_pass: Optional[torch.Tensor] = None): + num_tokens: int = 0): """A context manager that stores the current forward context, can be attention metadata, etc. Here we can inject common logic for every model forward pass. @@ -107,8 +103,7 @@ def set_forward_context(attn_metadata: Any, static_forward_context, virtual_engine=virtual_engine, attn_metadata=attn_metadata, - dp_metadata=dp_metadata, - expert_load_pass=expert_load_pass) + dp_metadata=dp_metadata) try: yield diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1e92ba228240..c35b77ee5e76 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1719,15 +1719,9 @@ def _dummy_run( intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_tokens, None, False) - if self.parallel_config.enable_eplb: - expert_load_pass = self.eplb_state.expert_load_pass - else: - expert_load_pass = None - with set_forward_context(attn_metadata, self.vllm_config, - num_tokens=num_tokens, - expert_load_pass=expert_load_pass): + num_tokens=num_tokens): outputs = model( input_ids=input_ids, positions=positions, From 620f59a5818c72bad2c6477340fe2baed1063b4f Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Fri, 23 May 2025 21:22:24 +0000 Subject: [PATCH 12/57] [Feature] Weight loading for redundant experts Signed-off-by: Bowen Wang --- vllm/distributed/eplb/states.py | 29 ++++- vllm/model_executor/layers/fused_moe/layer.py | 113 +++++++++++++++--- .../model_executor/layers/quantization/fp8.py | 6 + 3 files changed, 131 insertions(+), 17 deletions(-) diff --git a/vllm/distributed/eplb/states.py b/vllm/distributed/eplb/states.py index 11b63815d205..8ced2dd3f642 100644 --- a/vllm/distributed/eplb/states.py +++ b/vllm/distributed/eplb/states.py @@ -4,6 +4,7 @@ """ import time +from collections.abc import Sequence, Set from dataclasses import dataclass import torch @@ -75,10 +76,15 @@ class EplbState: def build_initial_global_physical_to_logical_map( num_routed_experts: int, num_redundant_experts: int, - ) -> list[int]: + ) -> Sequence[int]: """ Build an initial expert arrangement using the following 
structure: [original routed experts, redundant experts] + + Returns: + physical_to_logical_map (Sequence[int]): A list of integers, + where each integer is the index of the logical expert + that the corresponding physical expert maps to. """ global_physical_to_logical_map = list(range(num_routed_experts)) global_physical_to_logical_map += [ @@ -86,6 +92,27 @@ def build_initial_global_physical_to_logical_map( ] return global_physical_to_logical_map + @staticmethod + def build_initial_global_logical_to_physical_map( + num_routed_experts: int, + num_redundant_experts: int, + ) -> Sequence[Set[int]]: + """ + Build an initial expert arrangement using the following structure: + [original routed experts, redundant experts] + + Returns: + logical_to_physical_map (Sequence[set[int]]): A list of sets, + where each set contains the indices of the physical experts + that map to the corresponding logical expert. + """ + global_logical_to_physical_map = [{i} + for i in range(num_routed_experts)] + for i in range(num_redundant_experts): + global_logical_to_physical_map[i % num_routed_experts].add( + i + num_routed_experts) + return global_logical_to_physical_map + @classmethod def build( cls, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index f1cb77f64eae..fe436e7003d5 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -18,6 +18,7 @@ get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) +from vllm.distributed.eplb.states import EplbState from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp @@ -271,6 +272,9 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError @@ -738,6 +742,9 @@ class FusedMoE(torch.nn.Module): reduce_results: Whether to all all_reduce on the output of the layer renomalize: Whether to renormalize the logits in the fused_moe kernel quant_config: Quantization configure. + enable_eplb: Whether to enable expert parallelism load balancer. + logical_to_physical_map: Logical to physical expert mapping. + logical_replica_count: Count of physical replicas for logical experts. """ def __init__( @@ -762,6 +769,8 @@ def __init__( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + num_redundant_experts: int = 0, ): super().__init__() @@ -789,12 +798,26 @@ def __init__( compilation_config.static_forward_context[prefix] = self self.layer_name = prefix + self.enable_eplb = enable_eplb + # Determine expert maps if self.use_ep: + global_physical_experts = num_experts + num_redundant_experts + if self.enable_eplb: + assert global_physical_experts % self.ep_size == 0, \ + "EPLB currently only supports even distribution of " \ + "experts across ranks." + self.initial_global_logical_to_physical_map = \ + EplbState.build_initial_global_logical_to_physical_map( + num_experts, num_redundant_experts,) + '''Used in initial weight loading only in EPLB.''' + else: + assert num_redundant_experts == 0, \ + "Redundant experts are only supported with EPLB." 
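# Illustrative sketch, not taken from the patch: what the initial
# [original routed experts, redundant experts] arrangement described above
# looks like for toy sizes. Variable names are placeholders; only the layout
# mirrors the docstrings in this file.
num_routed_experts = 4
num_redundant_experts = 2

# Redundant physical slot i starts out as a replica of logical expert
# i % num_routed_experts, appended after the routed experts.
physical_to_logical = list(range(num_routed_experts)) + [
    i % num_routed_experts for i in range(num_redundant_experts)
]
assert physical_to_logical == [0, 1, 2, 3, 0, 1]

# The inverse view: each logical expert's set of physical replicas.
logical_to_physical: list[set[int]] = [{i} for i in range(num_routed_experts)]
for phys in range(num_routed_experts,
                  num_routed_experts + num_redundant_experts):
    logical_to_physical[physical_to_logical[phys]].add(phys)
assert logical_to_physical == [{0, 4}, {1, 5}, {2}, {3}]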
self.local_num_experts, self.expert_map = determine_expert_map( ep_size=self.ep_size, ep_rank=self.ep_rank, - global_num_experts=self.global_num_experts) + global_num_experts=global_physical_experts) else: self.local_num_experts, self.expert_map = (self.global_num_experts, None) @@ -1031,10 +1054,30 @@ def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, shard_id: str, expert_id: int) -> None: + if self.enable_eplb: + # `expert_id` is logical; with redundant experts, + # we need to convert it to physical id + global_physical_ids = self.initial_global_logical_to_physical_map[ + expert_id] + for global_physical_id in global_physical_ids: + local_physical_id = \ + self._map_global_expert_id_to_local_expert_id( + global_physical_id) + if local_physical_id != -1: + # Found a local replica of this logical expert + expert_id = local_physical_id + break + else: + # All of this logical expert's physical replica + # is not in our local space; skip loading + return + else: + expert_id = self._map_global_expert_id_to_local_expert_id( + expert_id) + if expert_id == -1: + return + # Hereafter, `expert_id` is local physical id - expert_id = self._map_global_expert_id_to_local_expert_id(expert_id) - if expert_id == -1: - return quant_method_name = self.quant_method.__class__.__name__ # compressed-tensors checkpoints with packed weights are stored flipped # TODO (mgoin): check self.quant_method.quant_config.quant_format @@ -1183,20 +1226,43 @@ def weight_loader(self, param: torch.nn.Parameter, return @staticmethod - def select_experts(hidden_states: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - use_grouped_topk: bool, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - indices_type: Optional[torch.dtype] = None): + def select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + indices_type: Optional[torch.dtype] = None, + enable_eplb: bool = False, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Route the input hidden states to the top-k experts based on the + router logits. + + Returns: + (topk_weights, topk_ids) (tuple[torch.Tensor, torch.Tensor]): + The weights and *global physical* expert ids of the top-k experts. + + **Compatibility**: When EPLB is not enabled, the returned ids are + equivalent to global logical ids, so should be compatible with + plain MoE implementations without redundant experts. 
+ """ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk - # DeekSeekv2 uses grouped_top_k + if enable_eplb: + assert logical_to_physical_map is not None + assert logical_replica_count is not None + # TODO(bowen): come back soon + raise NotImplementedError + + # DeepSeekv2 uses grouped_top_k if use_grouped_topk: assert topk_group is not None assert num_expert_group is not None @@ -1272,6 +1338,12 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): hidden_states = full_hidden_states[chunk_start:chunk_end, :] router_logits = full_router_logits[chunk_start:chunk_end, :] + eplb_kwargs = { + "enable_eplb": True, + "logical_to_physical_map": self.logical_to_physical_map, + "logical_replica_count": self.logical_replica_count, + } if self.enable_eplb else {} + # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, @@ -1288,6 +1360,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): scoring_func=self.scoring_func, e_score_correction_bias=self.e_score_correction_bias, activation=self.activation, + **eplb_kwargs, ) if not skip_result_store: @@ -1323,6 +1396,13 @@ def forward_impl(self, hidden_states: torch.Tensor, if self.dp_size > 1: hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits) + + eplb_kwargs = { + "enable_eplb": True, + "logical_to_physical_map": self.logical_to_physical_map, + "logical_replica_count": self.logical_replica_count, + } if self.enable_eplb else {} + # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, @@ -1340,6 +1420,7 @@ def forward_impl(self, hidden_states: torch.Tensor, e_score_correction_bias=self.e_score_correction_bias, activation=self.activation, apply_router_weight_on_input=self.apply_router_weight_on_input, + **eplb_kwargs, ) if self.dp_size > 1: diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f4cdc3db1a0d..cd7683b71df4 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -833,6 +833,9 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -845,6 +848,9 @@ def apply( custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, + enable_eplb=enable_eplb, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, ) if self.rocm_aiter_moe_enabled: From 90f3ed58af00fa17d15a4dfc8ca9e052fba3f7e4 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Tue, 27 May 2025 04:26:22 +0000 Subject: [PATCH 13/57] [Feature] Expert replica selection and load metrics recording Signed-off-by: Bowen Wang --- vllm/distributed/eplb/states.py | 41 +++------ vllm/model_executor/layers/fused_moe/layer.py | 85 +++++++++++++++++-- .../model_executor/layers/quantization/fp8.py | 8 ++ vllm/model_executor/models/deepseek_v2.py | 50 +++++------ vllm/model_executor/models/interfaces.py | 20 +++++ 5 files changed, 142 insertions(+), 62 deletions(-) diff --git a/vllm/distributed/eplb/states.py b/vllm/distributed/eplb/states.py index 8ced2dd3f642..294b9cfa6d86 100644 --- a/vllm/distributed/eplb/states.py +++ 
b/vllm/distributed/eplb/states.py @@ -37,7 +37,7 @@ class EplbState: This is a sparse matrix, where -1 indicates no mapping. - Shape: (num_moe_layers, num_logical_experts) + Shape: (num_moe_layers, num_logical_experts, num_redundant_experts + 1) """ logical_replica_count: torch.Tensor """ @@ -51,13 +51,13 @@ class EplbState: Expert load during this forward pass. We use the token count each expert processes as the load. - Shape: (num_moe_layers, num_local_physical_experts) + Shape: (num_moe_layers, num_logical_experts) """ expert_load_window: torch.Tensor """ A sliding window of expert load. - Shape: (window_size, num_moe_layers, num_local_physical_experts) + Shape: (window_size, num_moe_layers, num_logical_experts) """ expert_load_window_step: int = 0 """Current step in the sliding window.""" @@ -140,6 +140,7 @@ def build( logical_replica_count = torch.zeros( (model.num_logical_experts, ), device=device, + dtype=torch.long, ) for i in range(model.num_physical_experts): @@ -164,13 +165,13 @@ def build( ).contiguous() expert_load_pass = torch.zeros( - (model.num_moe_layers, model.num_local_physical_experts), + (model.num_moe_layers, model.num_logical_experts), device=device, ) expert_load_window_size = parallel_config.eplb_window_size expert_load_window = torch.zeros( (expert_load_window_size, model.num_moe_layers, - model.num_local_physical_experts), + model.num_logical_experts), device=device, ) @@ -179,6 +180,12 @@ def build( expert_rearrangement_step = max( 0, eplb_step_interval - eplb_step_interval // 4) + model.set_eplb_state( + expert_load_pass, + logical_to_physical_map, + logical_replica_count, + ) + return EplbState( physical_to_logical_map, logical_to_physical_map, @@ -225,32 +232,12 @@ def rearrange(self, model: MixtureOfExperts) -> None: time_start = time.perf_counter() logger.info("Rearranging experts...") - window_size, num_moe_layers, num_local_physical_experts = ( - self.expert_load_window.shape) - num_physical_experts = model.num_physical_experts - - local_expert_start = ep_rank * num_local_physical_experts - local_expert_end = local_expert_start + num_local_physical_experts - local_physical_to_logical_map = self.physical_to_logical_map[:, - local_expert_start: - local_expert_end] - device = local_physical_to_logical_map.device - # Perform all-reduce to get the expert load across all ranks - expert_load_window = self.expert_load_window - global_expert_load_window = torch.zeros(window_size, - num_moe_layers, - num_physical_experts, - device=device) - global_expert_load_window.scatter_add_( - -1, - local_physical_to_logical_map.expand_as(expert_load_window), - expert_load_window, - ) + global_expert_load_window = self.expert_load_window.clone() all_reduce(global_expert_load_window, group=ep_group) # TODO(bowen): Treat differently for prefill and decode nodes - num_replicas = num_physical_experts + num_replicas = model.num_physical_experts num_groups = model.num_expert_groups # TODO(bowen): Remove magic numbers num_nodes = (ep_group.size() + 7) // 8 diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index fe436e7003d5..4b5bb66f3f09 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -3,6 +3,7 @@ import importlib import threading from abc import abstractmethod +from collections.abc import Iterable from dataclasses import dataclass from enum import Enum from typing import Callable, Optional @@ -273,6 +274,7 @@ def apply( apply_router_weight_on_input: bool = 
False, activation: str = "silu", enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -743,8 +745,6 @@ class FusedMoE(torch.nn.Module): renomalize: Whether to renormalize the logits in the fused_moe kernel quant_config: Quantization configure. enable_eplb: Whether to enable expert parallelism load balancer. - logical_to_physical_map: Logical to physical expert mapping. - logical_replica_count: Count of physical replicas for logical experts. """ def __init__( @@ -799,6 +799,9 @@ def __init__( self.layer_name = prefix self.enable_eplb = enable_eplb + self.expert_load_view: Optional[torch.Tensor] = None + self.logical_to_physical_map: Optional[torch.Tensor] = None + self.logical_replica_count: Optional[torch.Tensor] = None # Determine expert maps if self.use_ep: @@ -873,6 +876,20 @@ def __init__( assert isinstance(quant_method, FusedMoEMethodBase) self.quant_method = quant_method + if self.enable_eplb: + from vllm.model_executor.layers.quantization.fp8 import ( + Fp8MoEMethod) + if not isinstance(quant_method, Fp8MoEMethod): + # TODO: Add support for additional quantization methods. + # The implementation for other quantization methods does not + # contain essential differences, but the current quant API + # design causes duplicated work when extending to new + # quantization methods, so I'm leaving it for now. + # If you plan to add support for more quantization methods, + # please refer to the implementation in `Fp8MoEMethod`. + raise NotImplementedError("EPLB is only supported for FP8 " + "quantization for now.") + if prepare_finalize is not None: world_size = moe.ep_size dp_size = int(moe.ep_size // moe.dp_size) @@ -1225,6 +1242,28 @@ def weight_loader(self, param: torch.nn.Parameter, tp_rank=self.tp_rank) return + def get_expert_weights(self) -> Iterable[torch.Tensor]: + weights = self.parameters() + assert all(weight.is_contiguous() for weight in weights) + return [weight.view(self.local_num_experts, -1) for weight in weights] + + def set_eplb_state( + self, + moe_layer_idx: int, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + """ + Register the EPLB state in this layer. + + This is used later in forward pass, where we get the expert mapping + and record the load metrics in `expert_load_view`. 
+ """ + self.expert_load_view = expert_load_view[moe_layer_idx] + self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx] + self.logical_replica_count = logical_replica_count[moe_layer_idx] + @staticmethod def select_experts( hidden_states: torch.Tensor, @@ -1239,6 +1278,7 @@ def select_experts( e_score_correction_bias: Optional[torch.Tensor] = None, indices_type: Optional[torch.dtype] = None, enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -1256,12 +1296,6 @@ def select_experts( """ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk - if enable_eplb: - assert logical_to_physical_map is not None - assert logical_replica_count is not None - # TODO(bowen): come back soon - raise NotImplementedError - # DeepSeekv2 uses grouped_top_k if use_grouped_topk: assert topk_group is not None @@ -1294,6 +1328,39 @@ def select_experts( if indices_type is not None: topk_ids = topk_ids.to(dtype=indices_type) + if enable_eplb: + assert expert_load_view is not None + assert logical_to_physical_map is not None + assert logical_replica_count is not None + + # 1. Record expert load metrics. + + # TODO(bowen): When using `FusedMoEModularKernel`, this + # can be done in a more unified way, since + # `FusedMoEPrepareAndFinalize` will return the expert + # token count, in some cases directly from the kernel. + # However, now there are many code paths not using + # the modular kernel, e.g. calling `fused_experts`, + # so we decide to keep the logic here. + # + # If later refactor moved all the MoE kernel calls + # to the modular kernel, we can move this logic there + # to achieve better efficiency. + + # (num_logical_experts,) + expert_load_view += topk_ids.bincount( + minlength=expert_load_view.shape[0], ) + + # 2. 
Convert the logical expert ids to physical expert ids + # Directly select a random replica for each logical expert + replica_indices = ( + torch.rand_like(topk_ids, dtype=torch.float) * + logical_replica_count[topk_ids]).long().unsqueeze(-1) + physical_ids = logical_to_physical_map[topk_ids].gather( + -1, replica_indices).squeeze(-1) + + topk_ids = physical_ids + return topk_weights, topk_ids def must_reduce_shared_expert_outputs(self) -> bool: @@ -1340,6 +1407,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): eplb_kwargs = { "enable_eplb": True, + "expert_load_view": self.expert_load_view, "logical_to_physical_map": self.logical_to_physical_map, "logical_replica_count": self.logical_replica_count, } if self.enable_eplb else {} @@ -1399,6 +1467,7 @@ def forward_impl(self, hidden_states: torch.Tensor, eplb_kwargs = { "enable_eplb": True, + "expert_load_view": self.expert_load_view, "logical_to_physical_map": self.logical_to_physical_map, "logical_replica_count": self.logical_replica_count, } if self.enable_eplb else {} diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index cd7683b71df4..4844d6eb0436 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -834,9 +834,16 @@ def apply( apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + assert expert_load_view is not None + assert logical_to_physical_map is not None + assert logical_replica_count is not None + assert isinstance(layer, FusedMoE) + topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, @@ -849,6 +856,7 @@ def apply( scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, enable_eplb=enable_eplb, + expert_load_view=expert_load_view, logical_to_physical_map=logical_to_physical_map, logical_replica_count=logical_replica_count, ) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index e67d486a1d92..19f081359d15 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -44,7 +44,6 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -98,13 +97,11 @@ class DeepseekV2MoE(nn.Module): def __init__( self, - layer_idx: int, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): super().__init__() - self.layer_idx = layer_idx self.tp_size = get_tensor_model_parallel_world_size() self.routed_scaling_factor = config.routed_scaling_factor @@ -135,17 +132,6 @@ def __init__( parallel_config = vllm_config.parallel_config self.enable_eplb = parallel_config.enable_eplb - if self.enable_eplb and not isinstance(quant_config, Fp8Config): - # TODO(bowen): Add support for additional quantization methods. 
- # The implementation for other quantization methods does not - # contain essential differences, but the current quant API design - # causes duplicated work when extending to new - # quantization methods, so I'm leaving it for now. - # If you plan to add support for more quantization methods, - # please refer to the implementation in `Fp8MoEMethod`. - raise NotImplementedError("EPLB is only supported for FP8 " - "quantization for now.") - self.n_extra_experts = parallel_config.num_extra_experts self.n_physical_experts = self.n_routed_experts + self.n_extra_experts self.n_logical_experts = self.n_routed_experts @@ -219,14 +205,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return final_hidden_states.view(num_tokens, hidden_dim) - def get_expert_weights(self) -> Iterable[torch.Tensor]: - weights = self.experts.parameters() - assert all(weight.is_contiguous() for weight in weights) - return [ - weight.view(self.n_local_physical_experts, -1) - for weight in weights - ] - def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: import math @@ -586,7 +564,6 @@ def __init__( and layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0): self.mlp = DeepseekV2MoE( - layer_idx=self.layer_idx, config=config, quant_config=quant_config, prefix=f"{prefix}.mlp", @@ -755,6 +732,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.first_k_dense_replace) self.num_expert_groups = config.n_group + self.moe_layers: list[FusedMoE] = [] + for layer in self.model.layers: + # TODO(bowen): Add support for MTP layers + assert isinstance(layer, DeepseekV2DecoderLayer) + if isinstance(layer.mlp, DeepseekV2MoE): + self.moe_layers.append(layer.mlp.experts) + example_moe = typing.cast( DeepseekV2MoE, self.model.layers[config.num_hidden_layers - 1].mlp) self.num_logical_experts = example_moe.n_logical_experts @@ -764,6 +748,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.num_shared_experts = example_moe.n_shared_experts self.num_redundant_experts = example_moe.n_redundant_experts + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -889,10 +887,8 @@ def load_weights(self, weights: Iterable[tuple[str, loaded_params.add(name) # Register the expert weights. 
- for layer in self.model.layers: - assert isinstance(layer, DeepseekV2DecoderLayer) - if isinstance(layer.mlp, DeepseekV2MoE): - self.expert_weights.append(layer.mlp.get_expert_weights()) + for layer in self.moe_layers: + self.expert_weights.append(layer.get_expert_weights()) # TODO(bowen): Add support for MTP layers diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index afa813ee89cc..8fcbc624a6a7 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -462,6 +462,26 @@ class MixtureOfExperts(Protocol): num_redundant_experts: int """Number of redundant experts in this model.""" + def set_eplb_state( + self, + expert_load_view: Tensor, + logical_to_physical_map: Tensor, + logical_replica_count: Tensor, + ) -> None: + """ + Register the EPLB state in the MoE model. + + Since these are views of the actual EPLB state, any changes made by + the EPLB algorithm are automatically reflected in the model's behavior + without requiring additional method calls to set new states. + + Args: + expert_load_view: A view of the expert load metrics tensor. + logical_to_physical_map: Mapping from logical to physical experts. + logical_replica_count: Count of replicas for each logical expert. + """ + ... + def is_mixture_of_experts(model: object) -> TypeIs[MixtureOfExperts]: return isinstance(model, MixtureOfExperts) From b3697dece7d0849628c9b5cdbba5b8a4900a4967 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Tue, 27 May 2025 06:12:18 +0000 Subject: [PATCH 14/57] [Feature] Map logical experts in weight loading Signed-off-by: Bowen Wang --- vllm/distributed/eplb/states.py | 23 +------- vllm/model_executor/layers/fused_moe/layer.py | 57 ++++++++----------- vllm/model_executor/models/deepseek_v2.py | 7 ++- 3 files changed, 30 insertions(+), 57 deletions(-) diff --git a/vllm/distributed/eplb/states.py b/vllm/distributed/eplb/states.py index 294b9cfa6d86..8d612809bb73 100644 --- a/vllm/distributed/eplb/states.py +++ b/vllm/distributed/eplb/states.py @@ -4,7 +4,7 @@ """ import time -from collections.abc import Sequence, Set +from collections.abc import Sequence from dataclasses import dataclass import torch @@ -92,27 +92,6 @@ def build_initial_global_physical_to_logical_map( ] return global_physical_to_logical_map - @staticmethod - def build_initial_global_logical_to_physical_map( - num_routed_experts: int, - num_redundant_experts: int, - ) -> Sequence[Set[int]]: - """ - Build an initial expert arrangement using the following structure: - [original routed experts, redundant experts] - - Returns: - logical_to_physical_map (Sequence[set[int]]): A list of sets, - where each set contains the indices of the physical experts - that map to the corresponding logical expert. 
- """ - global_logical_to_physical_map = [{i} - for i in range(num_routed_experts)] - for i in range(num_redundant_experts): - global_logical_to_physical_map[i % num_routed_experts].add( - i + num_routed_experts) - return global_logical_to_physical_map - @classmethod def build( cls, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 4b5bb66f3f09..fdca7fe601ab 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -787,7 +787,7 @@ def __init__( get_dp_group().world_size), vllm_parallel_config=vllm_config.parallel_config)) - self.global_num_experts = num_experts + self.global_num_experts = num_experts + num_redundant_experts # For smuggling this layer into the fused moe custom op self.use_direct_call = self.dp_size == 1 @@ -810,10 +810,6 @@ def __init__( assert global_physical_experts % self.ep_size == 0, \ "EPLB currently only supports even distribution of " \ "experts across ranks." - self.initial_global_logical_to_physical_map = \ - EplbState.build_initial_global_logical_to_physical_map( - num_experts, num_redundant_experts,) - '''Used in initial weight loading only in EPLB.''' else: assert num_redundant_experts == 0, \ "Redundant experts are only supported with EPLB." @@ -1071,28 +1067,9 @@ def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, shard_id: str, expert_id: int) -> None: - if self.enable_eplb: - # `expert_id` is logical; with redundant experts, - # we need to convert it to physical id - global_physical_ids = self.initial_global_logical_to_physical_map[ - expert_id] - for global_physical_id in global_physical_ids: - local_physical_id = \ - self._map_global_expert_id_to_local_expert_id( - global_physical_id) - if local_physical_id != -1: - # Found a local replica of this logical expert - expert_id = local_physical_id - break - else: - # All of this logical expert's physical replica - # is not in our local space; skip loading - return - else: - expert_id = self._map_global_expert_id_to_local_expert_id( - expert_id) - if expert_id == -1: - return + expert_id = self._map_global_expert_id_to_local_expert_id(expert_id) + if expert_id == -1: + return # Hereafter, `expert_id` is local physical id quant_method_name = self.quant_method.__class__.__name__ @@ -1348,8 +1325,8 @@ def select_experts( # to achieve better efficiency. # (num_logical_experts,) - expert_load_view += topk_ids.bincount( - minlength=expert_load_view.shape[0], ) + expert_load_view += topk_ids.flatten().bincount( + minlength=expert_load_view.shape[0]) # 2. 
Convert the logical expert ids to physical expert ids # Directly select a random replica for each logical expert @@ -1504,16 +1481,30 @@ def forward_impl(self, hidden_states: torch.Tensor, @classmethod def make_expert_params_mapping( - cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, + cls, + ckpt_gate_proj_name: str, + ckpt_down_proj_name: str, ckpt_up_proj_name: str, - num_experts: int) -> list[tuple[str, str, int, str]]: + num_experts: int, + num_redundant_experts: int = 0) -> list[tuple[str, str, int, str]]: + + num_physical_experts = num_experts + num_redundant_experts + + # In the returned mapping: + # - `expert_id` is the physical expert id + # - `weight_name` contains the weight name of the logical expert + # So that we should map the expert id to logical in `weight_name` + physical_to_logical_map = \ + EplbState.build_initial_global_physical_to_logical_map( + num_experts, num_redundant_experts) return [ # (param_name, weight_name, expert_id, shard_id) ("experts.w13_" if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", - f"experts.{expert_id}.{weight_name}.", expert_id, shard_id) - for expert_id in range(num_experts) for shard_id, weight_name in [ + f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.", + expert_id, shard_id) for expert_id in range(num_physical_experts) + for shard_id, weight_name in [ ("w1", ckpt_gate_proj_name), ("w2", ckpt_down_proj_name), ("w3", ckpt_up_proj_name), diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 19f081359d15..4a14842bd11a 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -157,7 +157,9 @@ def __init__( topk_group=config.topk_group, prefix=f"{prefix}.experts", scoring_func=config.scoring_func, - e_score_correction_bias=self.gate.e_score_correction_bias) + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts) if config.n_shared_experts is not None: intermediate_size = (config.moe_intermediate_size * @@ -813,7 +815,8 @@ def load_weights(self, weights: Iterable[tuple[str, ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts) + num_experts=self.config.n_routed_experts, + num_redundant_experts=self.num_redundant_experts) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() From 5d85f61aff1496979a4bd6c48e9eedfd3c4bb7c2 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Tue, 27 May 2025 07:13:45 +0000 Subject: [PATCH 15/57] [Bugfix] Use `scatter_add_` instead of `bincount` for compile Signed-off-by: Bowen Wang --- vllm/distributed/eplb/states.py | 2 ++ vllm/model_executor/layers/fused_moe/layer.py | 16 +++++++++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/eplb/states.py b/vllm/distributed/eplb/states.py index 8d612809bb73..a205cf051130 100644 --- a/vllm/distributed/eplb/states.py +++ b/vllm/distributed/eplb/states.py @@ -145,12 +145,14 @@ def build( expert_load_pass = torch.zeros( (model.num_moe_layers, model.num_logical_experts), + dtype=torch.int32, device=device, ) expert_load_window_size = parallel_config.eplb_window_size expert_load_window = torch.zeros( (expert_load_window_size, model.num_moe_layers, model.num_logical_experts), + dtype=torch.int32, device=device, ) diff --git a/vllm/model_executor/layers/fused_moe/layer.py 
b/vllm/model_executor/layers/fused_moe/layer.py index fdca7fe601ab..7b26589230ef 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1324,9 +1324,19 @@ def select_experts( # to the modular kernel, we can move this logic there # to achieve better efficiency. - # (num_logical_experts,) - expert_load_view += topk_ids.flatten().bincount( - minlength=expert_load_view.shape[0]) + # `expert_load_view`: (num_logical_experts,) + + # Should be equivalent to: + # ``` + # expert_load_view += topk_ids.flatten().bincount( + # minlength=expert_load_view.shape[0]) + # ``` + # We use `scatter_add_` since `bincount` cannot be compiled + topk_ids_flatten = topk_ids.flatten() + expert_load_view.scatter_add_( + dim=0, + index=topk_ids_flatten.long(), + src=torch.ones_like(topk_ids_flatten)) # 2. Convert the logical expert ids to physical expert ids # Directly select a random replica for each logical expert From e416e3cf82fbc87b33e86bc3ef9a930cc0d16e36 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Tue, 27 May 2025 07:42:28 +0000 Subject: [PATCH 16/57] [Bugfix] Add EPLB args in `EngineArgs` Signed-off-by: Bowen Wang --- vllm/engine/arg_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 37d876c37419..fbb2462e4ca4 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1080,6 +1080,8 @@ def create_engine_config( enable_expert_parallel=self.enable_expert_parallel, enable_eplb=self.enable_eplb, num_extra_experts=self.num_extra_experts, + eplb_window_size=self.eplb_window_size, + eplb_step_interval=self.eplb_step_interval, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, ray_workers_use_nsight=self.ray_workers_use_nsight, From 233741c58673b0f920d1ac79984aa0b0ada5196f Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Tue, 27 May 2025 07:50:23 +0000 Subject: [PATCH 17/57] [Bugfix] Sum up steps on EPLb rearrange Signed-off-by: Bowen Wang --- vllm/distributed/eplb/states.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/eplb/states.py b/vllm/distributed/eplb/states.py index a205cf051130..731cdd34e5cf 100644 --- a/vllm/distributed/eplb/states.py +++ b/vllm/distributed/eplb/states.py @@ -214,7 +214,7 @@ def rearrange(self, model: MixtureOfExperts) -> None: logger.info("Rearranging experts...") # Perform all-reduce to get the expert load across all ranks - global_expert_load_window = self.expert_load_window.clone() + global_expert_load_window = self.expert_load_window.sum(dim=0) all_reduce(global_expert_load_window, group=ep_group) # TODO(bowen): Treat differently for prefill and decode nodes From cfcd42cf9573b950d62d051cb92e1411894d845b Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Tue, 27 May 2025 08:06:09 +0000 Subject: [PATCH 18/57] [Bugfix] Collect expert weights into a list Signed-off-by: Bowen Wang --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 7b26589230ef..aabd2a0fe1f2 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1220,7 +1220,7 @@ def weight_loader(self, param: torch.nn.Parameter, return def get_expert_weights(self) -> Iterable[torch.Tensor]: - weights = self.parameters() + weights = list(self.parameters()) assert 
all(weight.is_contiguous() for weight in weights) return [weight.view(self.local_num_experts, -1) for weight in weights] From 36b0b1182d5abf923c0f0d97bce37eebbc4f8df0 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Tue, 27 May 2025 08:12:31 +0000 Subject: [PATCH 19/57] [Bugfix] Fix typo in assertion Signed-off-by: Bowen Wang --- vllm/distributed/eplb/rebalance_execute.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index 57fbf2b287f2..3c3b0bf7a89e 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -5,9 +5,8 @@ This involves the exchange of expert weights between GPUs. """ -from collections.abc import Iterable, Sequence +from collections.abc import Iterable, MutableSequence, Sequence from functools import partial -from typing import Dict, List, MutableSequence, Tuple import torch from torch.distributed import (P2POp, ProcessGroup, batch_isend_irecv, @@ -51,7 +50,7 @@ def get_ep_ranks_with_expert( num_local_experts: int, old_indices: Sequence[int], new_indices: Sequence[int], -) -> Tuple[MutableSequence[int], MutableSequence[int]]: +) -> tuple[MutableSequence[int], MutableSequence[int]]: """ Get the ranks of the experts that need to be exchanged. @@ -71,8 +70,8 @@ def get_ep_ranks_with_expert( local_cnt=num_local_experts, ) - ranks_to_send: List[int] = [] - ranks_to_recv: List[int] = [] + ranks_to_send: list[int] = [] + ranks_to_recv: list[int] = [] for i, e in enumerate(old_indices): if e == idx: @@ -133,10 +132,10 @@ def shuffle_layer( expert_weights_buffer): buffer[dst].copy_(weight[src]) - p2p_ops: List[P2POp] = [] + p2p_ops: list[P2POp] = [] # 2. Initiate sending of weights. - experts_send_loc: Dict[int, int] = {} + experts_send_loc: dict[int, int] = {} for src in range(num_local_experts): expert = old_indices[local2global(src)] if expert in experts_send_loc: @@ -173,7 +172,7 @@ def shuffle_layer( ] # 3. Initiate receiving of weights. - experts_recv_loc: Dict[int, int] = {} + experts_recv_loc: dict[int, int] = {} for dst in range(num_local_experts): if is_received_locally[dst]: continue @@ -254,7 +253,7 @@ def rearrange_expert_weights_inplace( num_local_physical_experts = next(iter(expert_weights[0])).shape[0] assert new_global_expert_indices.shape == (num_moe_layers, - num_local_physical_experts) + num_physical_experts) ep_rank = ep_group.rank() ep_size = ep_group.size() From d5add3a6bdf3d49fb6537601861662c1eef4a150 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Tue, 27 May 2025 09:08:59 +0000 Subject: [PATCH 20/57] [Bugfix] Pad `log2phy` magging in rebalance algo Signed-off-by: Bowen Wang --- vllm/distributed/eplb/rebalance_algo.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/vllm/distributed/eplb/rebalance_algo.py b/vllm/distributed/eplb/rebalance_algo.py index 6a15d09b7d74..a27678f3ec00 100644 --- a/vllm/distributed/eplb/rebalance_algo.py +++ b/vllm/distributed/eplb/rebalance_algo.py @@ -8,13 +8,11 @@ [DeepSeek EPLB](https://github.com/deepseek-ai/eplb). """ -from typing import Tuple - import torch def balanced_packing(weight: torch.Tensor, - num_packs: int) -> Tuple[torch.Tensor, torch.Tensor]: + num_packs: int) -> tuple[torch.Tensor, torch.Tensor]: """ Pack n weighted objects to m packs, such that each bin contains exactly n/m objects and the weights of all packs are as balanced as possible. 
@@ -63,7 +61,7 @@ def balanced_packing(weight: torch.Tensor, def replicate_experts( weight: torch.Tensor, - num_phy: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + num_phy: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized. @@ -180,7 +178,7 @@ def rebalance_experts( num_groups: int, num_nodes: int, num_gpus: int, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Entry point for expert-parallelism load balancer. @@ -200,7 +198,7 @@ def rebalance_experts( logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert expert_count: [layers, num_logical_experts], number of physical - replicasfor each logical expert + replicas for each logical expert """ num_layers, num_logical_experts = weight.shape weight = weight.float().cpu() @@ -212,7 +210,8 @@ def rebalance_experts( # use global load-balance policy phy2log, phyrank, logcnt = rebalance_experts_hierarchical( weight, num_replicas, 1, 1, num_gpus) - maxlogcnt = logcnt.max().item() + num_redundant_experts = num_replicas - num_logical_experts + maxlogcnt = num_redundant_experts + 1 log2phy: torch.Tensor = torch.full( (num_layers, num_logical_experts, maxlogcnt), -1, From b00bdb9c2a729794454d810e1002a1c60a6186d5 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Tue, 27 May 2025 10:31:32 +0000 Subject: [PATCH 21/57] [Bugfix] Fix EP group in `DeepseekV2MoE` Signed-off-by: Bowen Wang --- vllm/model_executor/models/deepseek_v2.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 4a14842bd11a..251bc489515e 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -34,7 +34,8 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import (CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config) -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.distributed import (get_ep_group, get_pp_group, + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -105,7 +106,7 @@ def __init__( self.tp_size = get_tensor_model_parallel_world_size() self.routed_scaling_factor = config.routed_scaling_factor - self.ep_group = get_pp_group().device_group + self.ep_group = get_ep_group().device_group self.ep_rank = self.ep_group.rank() self.ep_size = self.ep_group.size() self.n_routed_experts: int = config.n_routed_experts From c9cf2d4863f8ff65a46e44a30e4d0b860318af80 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Tue, 27 May 2025 22:46:26 +0000 Subject: [PATCH 22/57] [Refactor] Use local physical ids in expert load collection This is to facilitize the calculation of the load of a rank Signed-off-by: Bowen Wang --- vllm/distributed/eplb/states.py | 71 +++++++++++++++++-- vllm/model_executor/layers/fused_moe/layer.py | 11 ++- .../model_executor/layers/quantization/fp8.py | 1 + vllm/v1/worker/gpu_model_runner.py | 13 ++-- 4 files changed, 83 insertions(+), 13 deletions(-) diff --git a/vllm/distributed/eplb/states.py b/vllm/distributed/eplb/states.py index 731cdd34e5cf..57eca8f9e036 100644 --- a/vllm/distributed/eplb/states.py +++ b/vllm/distributed/eplb/states.py @@ 
-8,7 +8,7 @@ from dataclasses import dataclass import torch -from torch.distributed import all_reduce +from torch.distributed import all_gather, all_reduce from vllm.config import ParallelConfig from vllm.distributed.parallel_state import get_ep_group @@ -51,13 +51,13 @@ class EplbState: Expert load during this forward pass. We use the token count each expert processes as the load. - Shape: (num_moe_layers, num_logical_experts) + Shape: (num_moe_layers, num_local_physical_experts) """ expert_load_window: torch.Tensor """ A sliding window of expert load. - Shape: (window_size, num_moe_layers, num_logical_experts) + Shape: (window_size, num_moe_layers, num_local_physical_experts) """ expert_load_window_step: int = 0 """Current step in the sliding window.""" @@ -144,14 +144,14 @@ def build( ).contiguous() expert_load_pass = torch.zeros( - (model.num_moe_layers, model.num_logical_experts), + (model.num_moe_layers, model.num_local_physical_experts), dtype=torch.int32, device=device, ) expert_load_window_size = parallel_config.eplb_window_size expert_load_window = torch.zeros( (expert_load_window_size, model.num_moe_layers, - model.num_logical_experts), + model.num_local_physical_experts), dtype=torch.int32, device=device, ) @@ -178,11 +178,41 @@ def build( expert_rearrangement_step_interval=eplb_step_interval, ) - def step(self, model: MixtureOfExperts) -> None: + def step(self, model: MixtureOfExperts) -> tuple[float, float, float]: """ Step the EPLB state. + + Returns: + (avg_tokens, max_tokens, balancedness) (tuple[float, float, float]): + The returned metrics are all summed up across layers. + - `avg_tokens`: The average load across ranks. + - `max_tokens`: The maximum load across ranks. + - `balancedness`: The ratio of average load to maximum load. """ + # Collect load metrics from all ranks + ep_group = get_ep_group().device_group + # (num_moe_layers,) + num_tokens = self.expert_load_pass.sum(dim=-1) + num_tokens_list = [ + torch.empty_like(num_tokens) for _ in range(ep_group.size()) + ] + all_gather(num_tokens_list, num_tokens, group=ep_group) + # Stack to get (num_ranks, num_moe_layers) + num_tokens_per_rank = torch.stack(num_tokens_list).float() + + # Compute balancedness ratio: + # for each layer: + # (mean load across ranks) / (max load across ranks) + avg_tokens_tensor = num_tokens_per_rank.mean(dim=0).sum(dim=0) + max_tokens_tensor = num_tokens_per_rank.max(dim=0).values.sum(dim=0) + + # Just to make type checker happy + tokens_tensors: list[float] = torch.stack( + [avg_tokens_tensor, max_tokens_tensor]).tolist() + avg_tokens, max_tokens = tokens_tensors + balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0 + # Update the expert load sliding window self.expert_load_window[self.expert_load_window_step] = ( self.expert_load_pass.clone()) @@ -198,6 +228,8 @@ def step(self, model: MixtureOfExperts) -> None: self.expert_rearrangement_step = 0 self.rearrange(model) + return avg_tokens, max_tokens, balancedness + def rearrange(self, model: MixtureOfExperts) -> None: """ Rearrange the experts according to the current load. 
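For intuition, the balancedness metric computed in `step()` above reduces to a few tensor ops. A standalone sketch with made-up per-rank token counts (two EP ranks, three MoE layers; numbers are illustrative only):

```python
import torch

# Hypothetical per-rank token counts, shape (num_ranks, num_moe_layers).
num_tokens_per_rank = torch.tensor([[100., 80., 120.],
                                    [60., 90., 140.]])

avg_tokens = num_tokens_per_rank.mean(dim=0).sum()        # 295.0
max_tokens = num_tokens_per_rank.max(dim=0).values.sum()  # 330.0
balancedness = (avg_tokens / max_tokens).item()           # ~0.89
```

The closer the ratio is to 1, the more evenly the EP ranks are loaded during that forward pass.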
@@ -213,8 +245,33 @@ def rearrange(self, model: MixtureOfExperts) -> None: time_start = time.perf_counter() logger.info("Rearranging experts...") + # This mapping is only used here, so we do not store it in the state + physical_expert_start = ep_rank * model.num_local_physical_experts + physical_expert_end = (physical_expert_start + + model.num_local_physical_experts) + # (num_moe_layers, num_local_physical_experts) + local_physical_to_logical_map = self.physical_to_logical_map[ + :, + physical_expert_start:physical_expert_end, + ] + + # Map the local physical expert load to global logical experts + logical_expert_load_window = torch.zeros( + self.expert_load_window_size, + model.num_moe_layers, + model.num_logical_experts, + dtype=self.expert_load_window.dtype, + device=self.expert_load_window.device, + ) + logical_expert_load_window.scatter_add_( + dim=-1, + index=local_physical_to_logical_map.unsqueeze(0).expand_as( + self.expert_load_window).long(), + src=self.expert_load_window, + ) + # Perform all-reduce to get the expert load across all ranks - global_expert_load_window = self.expert_load_window.sum(dim=0) + global_expert_load_window = logical_expert_load_window.sum(dim=0) all_reduce(global_expert_load_window, group=ep_group) # TODO(bowen): Treat differently for prefill and decode nodes diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index aabd2a0fe1f2..8cc93164592d 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1255,6 +1255,7 @@ def select_experts( e_score_correction_bias: Optional[torch.Tensor] = None, indices_type: Optional[torch.dtype] = None, enable_eplb: bool = False, + expert_map: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, @@ -1326,13 +1327,19 @@ def select_experts( # `expert_load_view`: (num_logical_experts,) + # Mask out non-local experts + if expert_map is not None: + topk_ids_local = expert_map[topk_ids] + topk_ids_flatten = topk_ids_local[topk_ids_local >= 0] + else: + topk_ids_flatten = topk_ids.flatten() + # Should be equivalent to: # ``` - # expert_load_view += topk_ids.flatten().bincount( + # expert_load_view += topk_ids_flatten.bincount( # minlength=expert_load_view.shape[0]) # ``` # We use `scatter_add_` since `bincount` cannot be compiled - topk_ids_flatten = topk_ids.flatten() expert_load_view.scatter_add_( dim=0, index=topk_ids_flatten.long(), diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 4844d6eb0436..3bc1629b584e 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -856,6 +856,7 @@ def apply( scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, enable_eplb=enable_eplb, + expert_map=expert_map, expert_load_view=expert_load_view, logical_to_physical_map=logical_to_physical_map, logical_replica_count=logical_replica_count, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c35b77ee5e76..12c1fd77b6c1 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -23,7 +23,7 @@ has_kv_transfer_group) from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.parallel_state import ( - get_pp_group, get_tp_group, graph_capture, + 
get_ep_group, get_pp_group, get_tp_group, graph_capture, prepare_communication_buffer_for_model) from vllm.forward_context import get_forward_context, set_forward_context from vllm.logger import init_logger @@ -1419,9 +1419,14 @@ def execute_model( # EPLB step if self.parallel_config.enable_eplb: - self.eplb_state.step(self.model) - - # TODO(bowen): Log balancedness + assert is_mixture_of_experts(self.model) + avg_tokens, max_tokens, balancedness = \ + self.eplb_state.step(self.model) + + if get_ep_group().is_first_rank: + logger.debug( + "Model step: avg_tokens=%.2f, max_tokens=%d, " + "balancedness=%.4f", avg_tokens, max_tokens, balancedness) return ModelRunnerOutput( req_ids=self.input_batch.req_ids, From 4f79fef2f4569b753e8045ffe02094fe30434527 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Tue, 27 May 2025 23:46:35 +0000 Subject: [PATCH 23/57] [Bugfix] Map physical id before recording expert load metrics Signed-off-by: Bowen Wang --- vllm/model_executor/layers/fused_moe/layer.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 8cc93164592d..3639fe5b2fdc 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1311,7 +1311,17 @@ def select_experts( assert logical_to_physical_map is not None assert logical_replica_count is not None - # 1. Record expert load metrics. + # 1. Convert the logical expert ids to physical expert ids + # Directly select a random replica for each logical expert + replica_indices = ( + torch.rand_like(topk_ids, dtype=torch.float) * + logical_replica_count[topk_ids]).long().unsqueeze(-1) + physical_ids = logical_to_physical_map[topk_ids].gather( + -1, replica_indices).squeeze(-1) + + topk_ids = physical_ids + + # 2. Record expert load metrics. # TODO(bowen): When using `FusedMoEModularKernel`, this # can be done in a more unified way, since @@ -1345,16 +1355,6 @@ def select_experts( index=topk_ids_flatten.long(), src=torch.ones_like(topk_ids_flatten)) - # 2. 
Convert the logical expert ids to physical expert ids - # Directly select a random replica for each logical expert - replica_indices = ( - torch.rand_like(topk_ids, dtype=torch.float) * - logical_replica_count[topk_ids]).long().unsqueeze(-1) - physical_ids = logical_to_physical_map[topk_ids].gather( - -1, replica_indices).squeeze(-1) - - topk_ids = physical_ids - return topk_weights, topk_ids def must_reduce_shared_expert_outputs(self) -> bool: From a97ee39140ca176c369d2523a5964d4488360cd8 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Wed, 28 May 2025 08:37:43 +0000 Subject: [PATCH 24/57] [Perf] Reduce overhead of expert load recording Signed-off-by: Bowen Wang --- vllm/model_executor/layers/fused_moe/layer.py | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 3639fe5b2fdc..a11b4830b912 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1340,20 +1340,30 @@ def select_experts( # Mask out non-local experts if expert_map is not None: topk_ids_local = expert_map[topk_ids] - topk_ids_flatten = topk_ids_local[topk_ids_local >= 0] + topk_ids_flatten = topk_ids_local.flatten() else: topk_ids_flatten = topk_ids.flatten() # Should be equivalent to: # ``` - # expert_load_view += topk_ids_flatten.bincount( + # topk_ids_masked = topk_ids_local[topk_ids_local >= 0] + # expert_load_view += topk_ids_masked.bincount( # minlength=expert_load_view.shape[0]) # ``` # We use `scatter_add_` since `bincount` cannot be compiled - expert_load_view.scatter_add_( - dim=0, - index=topk_ids_flatten.long(), - src=torch.ones_like(topk_ids_flatten)) + + # Performance optimization: + # `masked_fill` is significantly faster than `masked_select` + invalid_mask = topk_ids_flatten < 0 + # Replace invalid expert ids with 0 (just a dummy position) + # to avoid out-of-bounds errors in scatter_add_ + index = topk_ids_flatten.masked_fill_(invalid_mask, 0) + # `src` is the valid mask, which is 1 for valid and 0 for invalid + src = ~invalid_mask + + expert_load_view.scatter_add_(dim=0, + index=index.long(), + src=src.to(expert_load_view)) return topk_weights, topk_ids From 2b14d5149cf748881fac7169c669a680e88bb350 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Thu, 29 May 2025 23:12:25 +0000 Subject: [PATCH 25/57] [Bugfix] Step EPLB state in dummy run to avoid blocking DP Signed-off-by: Bowen Wang --- vllm/v1/worker/gpu_model_runner.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7f35b1c1f113..279ea51603eb 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1117,6 +1117,22 @@ def sync_and_slice_intermediate_tensors( for k, v in self.intermediate_tensors.items() }) + def eplb_step(self) -> None: + """ + Step for the EPLB (Expert Parallelism Load Balancing) state. 
+ """ + if not self.parallel_config.enable_eplb: + return + + assert is_mixture_of_experts(self.model) + avg_tokens, max_tokens, balancedness = \ + self.eplb_state.step(self.model) + + if get_ep_group().is_first_rank: + logger.debug( + "Model step: avg_tokens=%.2f, max_tokens=%d, " + "balancedness=%.4f", avg_tokens, max_tokens, balancedness) + @torch.inference_mode() def execute_model( self, @@ -1441,16 +1457,7 @@ def execute_model( if has_kv_transfer_group(): get_kv_transfer_group().clear_connector_metadata() - # EPLB step - if self.parallel_config.enable_eplb: - assert is_mixture_of_experts(self.model) - avg_tokens, max_tokens, balancedness = \ - self.eplb_state.step(self.model) - - if get_ep_group().is_first_rank: - logger.debug( - "Model step: avg_tokens=%.2f, max_tokens=%d, " - "balancedness=%.4f", avg_tokens, max_tokens, balancedness) + self.eplb_step() return ModelRunnerOutput( req_ids=self.input_batch.req_ids, @@ -1775,6 +1782,9 @@ def _dummy_run( assert isinstance(self.drafter, EagleProposer) self.drafter.dummy_run(num_tokens) + # This is necessary to avoid blocking DP + self.eplb_step() + logit_indices = np.cumsum(num_scheduled_tokens) - 1 return hidden_states[logit_indices] From 306b21a3168d250c82b2c238755e52a9b6103e2b Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Fri, 30 May 2025 01:24:32 +0000 Subject: [PATCH 26/57] [Feature] Do not record expert loads for dummy batches Signed-off-by: Bowen Wang --- vllm/distributed/eplb/states.py | 34 ++++++++++++++++++++++-------- vllm/v1/worker/gpu_model_runner.py | 6 +++--- 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/vllm/distributed/eplb/states.py b/vllm/distributed/eplb/states.py index 57eca8f9e036..ac40ec668b89 100644 --- a/vllm/distributed/eplb/states.py +++ b/vllm/distributed/eplb/states.py @@ -178,10 +178,18 @@ def build( expert_rearrangement_step_interval=eplb_step_interval, ) - def step(self, model: MixtureOfExperts) -> tuple[float, float, float]: + def step(self, + model: MixtureOfExperts, + is_dummy: bool = False) -> tuple[float, float, float]: """ Step the EPLB state. + Args: + model (MixtureOfExperts): The MoE model. + is_dummy (bool): If `True`, this is a dummy step and the load + metrics recorded in this forward pass will not count. Defaults + to `False`. + Returns: (avg_tokens, max_tokens, balancedness) (tuple[float, float, float]): The returned metrics are all summed up across layers. @@ -190,10 +198,14 @@ def step(self, model: MixtureOfExperts) -> tuple[float, float, float]: - `balancedness`: The ratio of average load to maximum load. 
""" + if is_dummy: + # Do not record load metrics for dummy steps + self.expert_load_pass.zero_() + # `num_tokens`: (num_moe_layers,) + num_tokens = self.expert_load_pass.sum(dim=-1) + # Collect load metrics from all ranks ep_group = get_ep_group().device_group - # (num_moe_layers,) - num_tokens = self.expert_load_pass.sum(dim=-1) num_tokens_list = [ torch.empty_like(num_tokens) for _ in range(ep_group.size()) ] @@ -214,14 +226,18 @@ def step(self, model: MixtureOfExperts) -> tuple[float, float, float]: balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0 # Update the expert load sliding window - self.expert_load_window[self.expert_load_window_step] = ( - self.expert_load_pass.clone()) - self.expert_load_window_step += 1 - if self.expert_load_window_step >= self.expert_load_window_size: - self.expert_load_window_step = 0 - self.expert_load_pass.zero_() + if not is_dummy: + self.expert_load_window[self.expert_load_window_step] = ( + self.expert_load_pass.clone()) + self.expert_load_window_step += 1 + if self.expert_load_window_step >= self.expert_load_window_size: + self.expert_load_window_step = 0 + self.expert_load_pass.zero_() # Step the expert rearrangement step + # Note that even if this is a dummy step, we still increment the + # rearrangement step and perform rearrangement to ensure all ranks are + # performing collective communication. self.expert_rearrangement_step += 1 if (self.expert_rearrangement_step >= self.expert_rearrangement_step_interval): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 279ea51603eb..ee12e050cd29 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1117,7 +1117,7 @@ def sync_and_slice_intermediate_tensors( for k, v in self.intermediate_tensors.items() }) - def eplb_step(self) -> None: + def eplb_step(self, is_dummy: bool = False) -> None: """ Step for the EPLB (Expert Parallelism Load Balancing) state. """ @@ -1126,7 +1126,7 @@ def eplb_step(self) -> None: assert is_mixture_of_experts(self.model) avg_tokens, max_tokens, balancedness = \ - self.eplb_state.step(self.model) + self.eplb_state.step(self.model, is_dummy) if get_ep_group().is_first_rank: logger.debug( @@ -1783,7 +1783,7 @@ def _dummy_run( self.drafter.dummy_run(num_tokens) # This is necessary to avoid blocking DP - self.eplb_step() + self.eplb_step(is_dummy=True) logit_indices = np.cumsum(num_scheduled_tokens) - 1 return hidden_states[logit_indices] From 021578e4913afeaf2d25b486c08ec61d490dd5ef Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Mon, 2 Jun 2025 14:07:19 -0700 Subject: [PATCH 27/57] [Bugfix] Collect expert weights after weight post-processing Instead collecting expert weights during weight loading, we collect it later after the post-processing, since some processing like quantization will make the original weights unavailable. 
Signed-off-by: Bowen Wang --- vllm/distributed/eplb/rebalance_execute.py | 2 +- vllm/model_executor/models/deepseek_v2.py | 10 ++++------ vllm/model_executor/models/interfaces.py | 4 ++++ 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index 3c3b0bf7a89e..c8bac63b30c2 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -167,7 +167,7 @@ def shuffle_layer( for dst in recv_ranks: dst_global = get_global_rank(ep_group, dst) p2p_ops += [ - P2POp(torch.distributed.isend, weight[dst], dst_global) + P2POp(torch.distributed.isend, weight[src], dst_global) for weight in expert_weights ] diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 251bc489515e..b983030e473e 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -758,6 +758,8 @@ def set_eplb_state( logical_replica_count: torch.Tensor, ) -> None: for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. + self.expert_weights.append(layer.get_expert_weights()) layer.set_eplb_state( moe_layer_idx=layer_idx, expert_load_view=expert_load_view, @@ -765,6 +767,8 @@ def set_eplb_state( logical_replica_count=logical_replica_count, ) + # TODO(bowen): Add support for MTP layers + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -890,12 +894,6 @@ def load_weights(self, weights: Iterable[tuple[str, weight_loader(param, loaded_weight) loaded_params.add(name) - # Register the expert weights. - for layer in self.moe_layers: - self.expert_weights.append(layer.get_expert_weights()) - - # TODO(bowen): Add support for MTP layers - return loaded_params diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index dd7ca38e1088..c18cbf4d8709 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -477,6 +477,10 @@ def set_eplb_state( the EPLB algorithm are automatically reflected in the model's behavior without requiring additional method calls to set new states. + You should also collect model's `expert_weights` here instead of in + the weight loader, since after initial weight loading, further + processing like quantization may be applied to the weights. + Args: expert_load_view: A view of the expert load metrics tensor. logical_to_physical_map: Mapping from logical to physical experts. 
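The collection described above works because `get_expert_weights()` exposes each MoE layer's weights as `(num_local_experts, -1)` row views, so a rearrangement reduces to row-wise copies (or sends/receives of single rows across ranks). A minimal sketch with hypothetical shapes:

```python
import torch

num_local_experts = 4
# Hypothetical per-expert weight tensors; only the leading expert dim matters.
w13 = torch.randn(num_local_experts, 64, 32)
w2 = torch.randn(num_local_experts, 32, 64)

# Flatten everything except the expert dimension, as get_expert_weights() does.
expert_weights = [w.view(num_local_experts, -1) for w in (w13, w2)]
buffers = [torch.empty_like(w) for w in expert_weights]

# A local move: physical slot 0 should now hold the expert currently in slot 2.
src, dst = 2, 0
for weight, buffer in zip(expert_weights, buffers):
    buffer[dst].copy_(weight[src])
# Cross-rank moves send/receive the same row views (weight[src], buffer[dst])
# through the batched P2P ops instead of a local copy.
```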
From c2e051659f4c6b6b048b1067fc4a2d7213c93289 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Mon, 2 Jun 2025 19:10:48 -0700 Subject: [PATCH 28/57] [Bugfix] Fix weight loading of replica experts Signed-off-by: Bowen Wang --- vllm/distributed/eplb/states.py | 2 +- vllm/model_executor/layers/fused_moe/layer.py | 41 +++++++++++++----- vllm/model_executor/models/deepseek_v2.py | 42 ++++++++++++++----- 3 files changed, 63 insertions(+), 22 deletions(-) diff --git a/vllm/distributed/eplb/states.py b/vllm/distributed/eplb/states.py index ac40ec668b89..d081fae052bc 100644 --- a/vllm/distributed/eplb/states.py +++ b/vllm/distributed/eplb/states.py @@ -167,7 +167,7 @@ def build( logical_replica_count, ) - return EplbState( + return cls( physical_to_logical_map, logical_to_physical_map, logical_replica_count, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index c9dcddf15cc2..71ee66d7a589 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -5,7 +5,7 @@ from collections.abc import Iterable from dataclasses import dataclass from enum import Enum -from typing import Callable, Optional +from typing import Callable, Literal, Optional, overload import torch import torch.nn.functional as F @@ -1016,12 +1016,31 @@ def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: return expert_id return self.expert_map[expert_id].item() + @overload def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, - shard_id: str, expert_id: int) -> None: + shard_id: str, expert_id: int, + return_success: Literal[False]) -> None: + ... + + @overload + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, weight_name: str, + shard_id: str, expert_id: int, + return_success: Literal[True]) -> bool: + ... + + def weight_loader(self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, + return_success: bool = False) -> Optional[bool]: expert_id = self._map_global_expert_id_to_local_expert_id(expert_id) if expert_id == -1: - return + # Failed to load this param since it's not local to this rank + return False if return_success else None # Hereafter, `expert_id` is local physical id quant_method_name = self.quant_method.__class__.__name__ @@ -1050,7 +1069,7 @@ def weight_loader(self, param: torch.nn.Parameter, if is_gguf_weight_type: param.weight_type = loaded_weight.item() param.data.copy_(loaded_weight) - return + return True if return_success else None # is_transposed: if the dim to shard the weight # should be flipped. 
Required by GPTQ, compressed-tensors @@ -1089,7 +1108,7 @@ def weight_loader(self, param: torch.nn.Parameter, self._load_single_value(param=param, loaded_weight=loaded_weight, expert_id=expert_id) - return + return True if return_success else None # Case g_idx if "g_idx" in weight_name: @@ -1098,7 +1117,7 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight=loaded_weight, expert_data=expert_data, tp_rank=self.tp_rank) - return + return True if return_success else None if "ModelOpt" in quant_method_name: if ('weight_scale_2' in weight_name @@ -1114,7 +1133,7 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight=loaded_weight, expert_data=expert_data, tp_rank=self.tp_rank) - return + return True if return_success else None # Case weight scales, zero_points and offset if ("scale" in weight_name or "zero" in weight_name @@ -1151,7 +1170,7 @@ def weight_loader(self, param: torch.nn.Parameter, else: raise ValueError( f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}") - return + return True if return_success else None # Case weight_shape if "weight_shape" in weight_name: @@ -1159,7 +1178,7 @@ def weight_loader(self, param: torch.nn.Parameter, self._load_single_value(param=param, loaded_weight=loaded_weight, expert_id=expert_id) - return + return True if return_success else None # Case model weights if "weight" in weight_name: @@ -1169,7 +1188,9 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight=loaded_weight, expert_data=expert_data, tp_rank=self.tp_rank) - return + return True if return_success else None + + return False if return_success else None def get_expert_weights(self) -> Iterable[torch.Tensor]: weights = list(self.parameters()) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index b983030e473e..7b9a17e3fe07 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -23,7 +23,7 @@ # limitations under the License. """Inference-only DeepseekV2/DeepseekV3 model.""" import typing -from collections.abc import Iterable +from collections.abc import Callable, Iterable from typing import Any, Optional, Union import torch @@ -858,24 +858,44 @@ def load_weights(self, weights: Iterable[tuple[str, weight_loader(param, loaded_weight, shard_id) break else: + is_expert_weight = False for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue - name = name.replace(weight_name, param_name) - if is_pp_missing_parameter(name, self): + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify `name` since the loop may continue here + # Instead, create a new variable + name_mapped = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name_mapped, self): continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) - break + param = params_dict[name_mapped] + # We should ask the weight loader to return success or not + # here since otherwise we may skip experts with other + # available replicas. 
+ weight_loader = typing.cast(Callable[..., bool], + param.weight_loader) + success = weight_loader(param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True) + if success: + break else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue From 79c0d4196eaefc931f7fc17769dc4d985b3123ca Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Mon, 9 Jun 2025 16:42:10 -0700 Subject: [PATCH 29/57] [Bugfix] Remove `e_score_correction_bias` in expert weights Signed-off-by: Bowen Wang --- vllm/model_executor/layers/fused_moe/layer.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 8ab8d6c1214a..cf0ba7da3ebc 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1334,9 +1334,20 @@ def weight_loader(self, return False if return_success else None def get_expert_weights(self) -> Iterable[torch.Tensor]: - weights = list(self.parameters()) - assert all(weight.is_contiguous() for weight in weights) - return [weight.view(self.local_num_experts, -1) for weight in weights] + weights = list(self.named_parameters()) + assert all(weight.is_contiguous() for _, weight in weights) + + # Filter out the non-expert weights. + # `e_score_correction_bias` is a bias for each logical expert, + # with shape (num_logical_experts,), not an expert weight. + NON_EXPERT_WEIGHTS = { + "e_score_correction_bias", + } + + return [ + weight.view(self.local_num_experts, -1) for name, weight in weights + if name not in NON_EXPERT_WEIGHTS + ] def set_eplb_state( self, From b011065b8ebbe13f430ab6b12d25f7222043e6e5 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Mon, 9 Jun 2025 18:18:41 -0700 Subject: [PATCH 30/57] [Bugfix] Fix shapes and dtypes in `FusedMoE` Signed-off-by: Bowen Wang --- vllm/model_executor/layers/fused_moe/layer.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index cf0ba7da3ebc..724e7e7e9ed9 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -876,9 +876,8 @@ def __init__( # Determine expert maps if self.use_ep: - global_physical_experts = num_experts + num_redundant_experts if self.enable_eplb: - assert global_physical_experts % self.ep_size == 0, \ + assert self.global_num_experts % self.ep_size == 0, \ "EPLB currently only supports even distribution of " \ "experts across ranks." else: @@ -887,7 +886,7 @@ def __init__( self.local_num_experts, self.expert_map = determine_expert_map( ep_size=self.ep_size, ep_rank=self.ep_rank, - global_num_experts=global_physical_experts) + global_num_experts=self.global_num_experts) else: self.local_num_experts, self.expert_map = (self.global_num_experts, None) @@ -992,8 +991,9 @@ def __init__( dtype=act_dtype, device=torch.cuda.current_device()) + # Note here we use `num_experts` which is logical expert count self.batched_router_logits = torch.zeros( - (MOE_DP_CHUNK_SIZE, self.global_num_experts), + (MOE_DP_CHUNK_SIZE, num_experts), dtype=act_dtype, device=torch.cuda.current_device()) @@ -1438,10 +1438,17 @@ def select_experts( # 1. 
Convert the logical expert ids to physical expert ids # Directly select a random replica for each logical expert + + # TODO: maybe optimize this by using specified kernels, + # or compute pseudo-random indices by modulo + + # In case `indices_type` is not `torch.long` or `torch.int`, + # e.g. `torch.uint32` as required by dispatch/combine kernels + topk_ids_long = topk_ids.long() replica_indices = ( torch.rand_like(topk_ids, dtype=torch.float) * - logical_replica_count[topk_ids]).long().unsqueeze(-1) - physical_ids = logical_to_physical_map[topk_ids].gather( + logical_replica_count[topk_ids_long]).long().unsqueeze(-1) + physical_ids = logical_to_physical_map[topk_ids_long].gather( -1, replica_indices).squeeze(-1) topk_ids = physical_ids @@ -1490,6 +1497,8 @@ def select_experts( index=index.long(), src=src.to(expert_load_view)) + topk_ids = topk_ids.to(dtype=indices_type) + return topk_weights, topk_ids def must_reduce_shared_expert_outputs(self) -> bool: From 90706aab6cbd13c6b809d61d3f10ad7517dca1fc Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Mon, 16 Jun 2025 15:19:49 -0700 Subject: [PATCH 31/57] [Feature] Disable EPLb step during profile run Signed-off-by: Bowen Wang --- vllm/v1/worker/gpu_model_runner.py | 12 +++++++++--- vllm/v1/worker/gpu_worker.py | 4 ++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7a62d3da0dde..605050b56501 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1807,6 +1807,7 @@ def _dummy_run( self, num_tokens: int, skip_attn: bool = True, + skip_eplb: bool = False, ) -> torch.Tensor: # Padding for DP @@ -1903,7 +1904,8 @@ def _dummy_run( self.drafter.dummy_run(num_tokens) # This is necessary to avoid blocking DP - self.eplb_step(is_dummy=True) + if not skip_eplb: + self.eplb_step(is_dummy=True) logit_indices = np.cumsum(num_scheduled_tokens) - 1 return hidden_states[logit_indices] @@ -2085,8 +2087,12 @@ def capture_model(self) -> None: total=len(self.cudagraph_batch_sizes)): for _ in range(self.vllm_config.compilation_config. cudagraph_num_of_warmups): - self._dummy_run(num_tokens, skip_attn=skip_attn) - self._dummy_run(num_tokens, skip_attn=skip_attn) + self._dummy_run(num_tokens, + skip_attn=skip_attn, + skip_eplb=True) + self._dummy_run(num_tokens, + skip_attn=skip_attn, + skip_eplb=True) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index b7d244f27045..9e29176f1c2f 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -256,7 +256,7 @@ def compile_or_warm_up_model(self) -> None: ] for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) - self.model_runner._dummy_run(size) + self.model_runner._dummy_run(size, skip_eplb=True) if not self.model_config.enforce_eager: self.model_runner.capture_model() @@ -270,7 +270,7 @@ def compile_or_warm_up_model(self) -> None: self.scheduler_config.max_num_batched_tokens) self.model_runner._dummy_sampler_run( hidden_states=self.model_runner._dummy_run( - num_tokens=max_num_reqs)) + num_tokens=max_num_reqs, skip_eplb=True)) # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. 
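The replica selection reworked in patch 30 above picks a uniformly random physical replica for every routed token, so replicas of a hot expert share its traffic roughly evenly. A standalone sketch with made-up mappings (four logical experts, at most two replicas each; `-1` pads experts that have a single replica):

```python
import torch

logical_to_physical_map = torch.tensor([
    [0, 4],    # logical expert 0 has replicas at physical slots 0 and 4
    [1, -1],
    [2, -1],
    [3, 5],
])
logical_replica_count = torch.tensor([2, 1, 1, 2])

topk_ids = torch.tensor([[0, 3], [1, 0]])  # (num_tokens, top_k), logical ids
topk_ids_long = topk_ids.long()
replica_indices = (torch.rand_like(topk_ids, dtype=torch.float) *
                   logical_replica_count[topk_ids_long]).long().unsqueeze(-1)
physical_ids = logical_to_physical_map[topk_ids_long].gather(
    -1, replica_indices).squeeze(-1)
print(physical_ids)  # e.g. tensor([[4, 3], [1, 0]])
```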
From f1f62b23becf34abb39d659fc2b7e9a45e991dde Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Tue, 17 Jun 2025 01:41:21 -0700 Subject: [PATCH 32/57] [Bugfix] Synchronize CUDA before shuffling layer to avoid hang Signed-off-by: Bowen Wang --- vllm/distributed/eplb/rebalance_execute.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index c8bac63b30c2..18ae72e1446d 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -265,6 +265,9 @@ def rearrange_expert_weights_inplace( expert_weights_buffer = [torch.empty_like(w) for w in expert_weights[0]] for layer in range(num_moe_layers): + # NOTE(bowen): We need this synchronize to run, but I don't know why. + # If you figure out the reason, please let me know -- thank you! + torch.cuda.synchronize() shuffle_layer( num_local_physical_experts, ep_rank, From 993d7d7f6c857a0541c3fde174f4c871db0cac84 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Thu, 19 Jun 2025 15:12:05 -0700 Subject: [PATCH 33/57] [Style] Rename module `eplb.states` to `eplb.eplb_state` This is because the logger will display the filename. [INFO] [states.py] is kind of confusing. Signed-off-by: Bowen Wang --- vllm/distributed/eplb/__init__.py | 5 ++++- vllm/distributed/eplb/{states.py => eplb_state.py} | 0 vllm/distributed/eplb/rebalance_execute.py | 7 +++++-- vllm/model_executor/layers/fused_moe/layer.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 2 +- 5 files changed, 11 insertions(+), 5 deletions(-) rename vllm/distributed/eplb/{states.py => eplb_state.py} (100%) diff --git a/vllm/distributed/eplb/__init__.py b/vllm/distributed/eplb/__init__.py index be6006b4ecb8..c87b039afd73 100644 --- a/vllm/distributed/eplb/__init__.py +++ b/vllm/distributed/eplb/__init__.py @@ -1,4 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +''' +Expert parallelism load balancer (EPLB). +''' +from .eplb_state import * from .rebalance_algo import * -from .states import * diff --git a/vllm/distributed/eplb/states.py b/vllm/distributed/eplb/eplb_state.py similarity index 100% rename from vllm/distributed/eplb/states.py rename to vllm/distributed/eplb/eplb_state.py diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index 18ae72e1446d..cad7f3f0abc8 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -167,8 +167,11 @@ def shuffle_layer( for dst in recv_ranks: dst_global = get_global_rank(ep_group, dst) p2p_ops += [ - P2POp(torch.distributed.isend, weight[src], dst_global) - for weight in expert_weights + P2POp( + torch.distributed.isend, + weight[src], + dst_global, + ) for weight in expert_weights ] # 3. Initiate receiving of weights. 
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 652adb770b42..fa30155f6501 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -21,7 +21,7 @@ get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) -from vllm.distributed.eplb.states import EplbState +from vllm.distributed.eplb.eplb_state import EplbState from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f21379d51bc7..2cbf761d5a80 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -20,7 +20,7 @@ from vllm.attention.layer import Attention from vllm.config import (CompilationLevel, VllmConfig, get_layers_from_vllm_config) -from vllm.distributed.eplb.states import EplbState +from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 From 90afdaf1913c31bad5f7963139183903b12b49be Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Thu, 19 Jun 2025 21:32:11 -0700 Subject: [PATCH 34/57] [Feature] Run a dummy rearrangement during profile run for CUDA graphs I'm actually not very sure why we need this. When CUDA graph in on, we need to perform a communication during the profile run, maybe to init the NCCL buffer (even though this is outside the graph itself) to comply with memory constraints of CUDA graphs. Signed-off-by: Bowen Wang --- vllm/distributed/eplb/rebalance_execute.py | 24 +++++++++++++++++-- vllm/distributed/eplb/states.py | 28 ++++++++++++++++------ vllm/v1/worker/gpu_model_runner.py | 12 ++++++---- 3 files changed, 51 insertions(+), 13 deletions(-) diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index c8bac63b30c2..1af5cef309ad 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -9,8 +9,8 @@ from functools import partial import torch -from torch.distributed import (P2POp, ProcessGroup, batch_isend_irecv, - get_global_rank) +from torch.distributed import (P2POp, ProcessGroup, all_gather, + batch_isend_irecv, get_global_rank) def idx_local_to_global( @@ -232,6 +232,7 @@ def rearrange_expert_weights_inplace( new_global_expert_indices: torch.Tensor, expert_weights: Sequence[Iterable[torch.Tensor]], ep_group: ProcessGroup, + is_profile: bool = False, ) -> None: """ Rearranges the expert weights in place according to the new expert indices. @@ -247,6 +248,9 @@ def rearrange_expert_weights_inplace( For example, a linear layer may have up and down projection, so weight_count = 2. Each weight's hidden size can be different. ep_group: The device process group for expert parallelism. + is_profile (bool): If `True`, do not perform any actual weight copy. + This is used during profile run, where we only perform dummy + communications to reserve enough memory for the buffers. """ num_moe_layers, num_physical_experts = old_global_expert_indices.shape assert len(expert_weights) == num_moe_layers @@ -264,6 +268,22 @@ def rearrange_expert_weights_inplace( # have the same shape. 
expert_weights_buffer = [torch.empty_like(w) for w in expert_weights[0]] + if is_profile: + # Maximum send size is to send all local experts to all ranks, + # So we use a dummy `all_gather` to reserve enough communication buffer + for weight, buffer in zip(expert_weights[0], expert_weights_buffer): + # A `/dev/null`-like buffer to avoid real memory allocation + dummy_recv_buffer = [buffer for _ in range(ep_size)] + # NOTE(bowen): Needed this barrier to avoid OOM during actual + # execution. I'm not very sure why this is needed + torch.distributed.barrier() + all_gather( + dummy_recv_buffer, + weight, + group=ep_group, + ) + return + for layer in range(num_moe_layers): shuffle_layer( num_local_physical_experts, diff --git a/vllm/distributed/eplb/states.py b/vllm/distributed/eplb/states.py index d081fae052bc..248c4adeade4 100644 --- a/vllm/distributed/eplb/states.py +++ b/vllm/distributed/eplb/states.py @@ -180,7 +180,8 @@ def build( def step(self, model: MixtureOfExperts, - is_dummy: bool = False) -> tuple[float, float, float]: + is_dummy: bool = False, + is_profile: bool = False) -> tuple[float, float, float]: """ Step the EPLB state. @@ -189,6 +190,9 @@ def step(self, is_dummy (bool): If `True`, this is a dummy step and the load metrics recorded in this forward pass will not count. Defaults to `False`. + is_profile (bool): If `True`, perform a dummy rearrangement + with maximum communication cost. This is used in `profile_run` + to reserve enough memory for the communication buffer. Returns: (avg_tokens, max_tokens, balancedness) (tuple[float, float, float]): @@ -198,6 +202,10 @@ def step(self, - `balancedness`: The ratio of average load to maximum load. """ + if is_profile: + self.rearrange(model, is_profile=True) + return 0.0, 0.0, 0.0 + if is_dummy: # Do not record load metrics for dummy steps self.expert_load_pass.zero_() @@ -246,7 +254,9 @@ def step(self, return avg_tokens, max_tokens, balancedness - def rearrange(self, model: MixtureOfExperts) -> None: + def rearrange(self, + model: MixtureOfExperts, + is_profile: bool = False) -> None: """ Rearrange the experts according to the current load. 
""" @@ -259,7 +269,8 @@ def rearrange(self, model: MixtureOfExperts) -> None: if is_main_rank: torch.cuda.synchronize() time_start = time.perf_counter() - logger.info("Rearranging experts...") + logger.info("Rearranging experts %s...", + "(profile)" if is_profile else "") # This mapping is only used here, so we do not store it in the state physical_expert_start = ep_rank * model.num_local_physical_experts @@ -316,17 +327,20 @@ def rearrange(self, model: MixtureOfExperts) -> None: new_physical_to_logical_map, model.expert_weights, ep_group, + is_profile, ) - self.physical_to_logical_map.copy_(new_physical_to_logical_map) - self.logical_to_physical_map.copy_(new_logical_to_physical_map) - self.logical_replica_count.copy_(new_logical_replica_count) + if not is_profile: + self.physical_to_logical_map.copy_(new_physical_to_logical_map) + self.logical_to_physical_map.copy_(new_logical_to_physical_map) + self.logical_replica_count.copy_(new_logical_replica_count) if is_main_rank: assert time_start is not None torch.cuda.synchronize() time_end = time.perf_counter() logger.info( - "Rearranged experts in %.2f seconds.", + "Rearranged experts %s in %.2f seconds.", + "(profile)" if is_profile else "", time_end - time_start, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 605050b56501..ef7b6f4aff36 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1156,7 +1156,9 @@ def sync_and_slice_intermediate_tensors( for k, v in self.intermediate_tensors.items() }) - def eplb_step(self, is_dummy: bool = False) -> None: + def eplb_step(self, + is_dummy: bool = False, + is_profile: bool = False) -> None: """ Step for the EPLB (Expert Parallelism Load Balancing) state. """ @@ -1165,7 +1167,7 @@ def eplb_step(self, is_dummy: bool = False) -> None: assert is_mixture_of_experts(self.model) avg_tokens, max_tokens, balancedness = \ - self.eplb_state.step(self.model, is_dummy) + self.eplb_state.step(self.model, is_dummy, is_profile) if get_ep_group().is_first_rank: logger.debug( @@ -1808,6 +1810,7 @@ def _dummy_run( num_tokens: int, skip_attn: bool = True, skip_eplb: bool = False, + is_profile: bool = False, ) -> torch.Tensor: # Padding for DP @@ -1905,7 +1908,7 @@ def _dummy_run( # This is necessary to avoid blocking DP if not skip_eplb: - self.eplb_step(is_dummy=True) + self.eplb_step(is_dummy=True, is_profile=is_profile) logit_indices = np.cumsum(num_scheduled_tokens) - 1 return hidden_states[logit_indices] @@ -2057,7 +2060,8 @@ def profile_run(self) -> None: # Cache the dummy encoder outputs. self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) - hidden_states = self._dummy_run(self.max_num_tokens) + # Do not skip EPLB in profile run + hidden_states = self._dummy_run(self.max_num_tokens, is_profile=True) if get_pp_group().is_last_rank: sampler_output = self._dummy_sampler_run(hidden_states) else: From f5d171f5edefab408f07953a3cff3632c54edf33 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Thu, 19 Jun 2025 21:53:42 -0700 Subject: [PATCH 35/57] [Feature] Constrain EPLB to main models Otherwise, if EPLB is enabled, `FusedMoE` in MTP model will also try to use expert mapping and record load metrics, which is not wanted. Later work should consider adding support for `EplbState` to register multiple models, hence we can do EPLB on both the main model and the MTP model. 
Signed-off-by: Bowen Wang --- vllm/model_executor/models/deepseek_v2.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 898371e3e681..57cd6fc5213a 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -102,6 +102,7 @@ def __init__( config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + enable_eplb: bool = False, ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() @@ -132,7 +133,7 @@ def __init__( # Currently, `n_redundant_experts` equals to `n_extra_experts`. vllm_config = get_current_vllm_config() parallel_config = vllm_config.parallel_config - self.enable_eplb = parallel_config.enable_eplb + self.enable_eplb = enable_eplb self.n_extra_experts = parallel_config.num_extra_experts self.n_physical_experts = self.n_routed_experts + self.n_extra_experts @@ -531,6 +532,7 @@ def __init__( model_config: ModelConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + enable_eplb: bool = False, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -571,6 +573,7 @@ def __init__( config=config, quant_config=quant_config, prefix=f"{prefix}.mlp", + enable_eplb=enable_eplb, ) else: self.mlp = DeepseekV2MLP( @@ -643,6 +646,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config + enable_eplb = vllm_config.parallel_config.enable_eplb self.config = config self.vocab_size = config.vocab_size @@ -664,6 +668,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): model_config=model_config, cache_config=cache_config, quant_config=quant_config, + enable_eplb=enable_eplb, ), prefix=f"{prefix}.layers") @@ -731,14 +736,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.expert_weights = [] # Set MoE hyperparameters - # TODO(bowen): Add support for MTP layers self.num_moe_layers = (config.num_hidden_layers - config.first_k_dense_replace) self.num_expert_groups = config.n_group self.moe_layers: list[FusedMoE] = [] for layer in self.model.layers: - # TODO(bowen): Add support for MTP layers assert isinstance(layer, DeepseekV2DecoderLayer) if isinstance(layer.mlp, DeepseekV2MoE): self.moe_layers.append(layer.mlp.experts) @@ -768,8 +771,6 @@ def set_eplb_state( logical_replica_count=logical_replica_count, ) - # TODO(bowen): Add support for MTP layers - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) From aaa66a26477d41c7041806401b68bfbe0a3d992a Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Thu, 19 Jun 2025 22:24:22 -0700 Subject: [PATCH 36/57] [Refactor] Move out `EplbState` in model runner from classvars Signed-off-by: Bowen Wang --- vllm/v1/worker/gpu_model_runner.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index fa4480185ad5..ad9382720a49 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -85,17 +85,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): - enable_eplb: bool = False - """ - Whether the expert parallelism load balancer is enabled. 
- """ - eplb_state: EplbState - """ - State of the expert parallelism load balancer. - - Will be lazily initialized when the model is loaded. - """ - def __init__( self, vllm_config: VllmConfig, @@ -158,6 +147,13 @@ def __init__( # Sampler self.sampler = Sampler() + self.eplb_state: Optional[EplbState] = None + """ + State of the expert parallelism load balancer. + + Will be lazily initialized when the model is loaded. + """ + # Lazy initializations # self.model: nn.Module # Set after load_model # Initialize in initialize_kv_cache @@ -1191,6 +1187,7 @@ def eplb_step(self, if not self.parallel_config.enable_eplb: return + assert self.eplb_state is not None assert is_mixture_of_experts(self.model) avg_tokens, max_tokens, balancedness = \ self.eplb_state.step(self.model, is_dummy, is_profile) @@ -1696,7 +1693,6 @@ def load_model(self) -> None: if is_mixture_of_experts( self.model) and self.parallel_config.enable_eplb: - self.enable_eplb = True logger.info("EPLB is enabled for model %s.", self.model_config.model) self.eplb_state = EplbState.build( From 4e346be58b13973c05194593657334c2d815d804 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Mon, 23 Jun 2025 15:00:35 -0700 Subject: [PATCH 37/57] [Style] Rename `--num-extra-experts` to `--num-redundant-experts` Signed-off-by: Bowen Wang --- vllm/config.py | 15 ++++++++------- vllm/engine/arg_utils.py | 8 ++++---- vllm/model_executor/models/deepseek_v2.py | 8 +++----- 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 8fa8a42b9248..2021192da099 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1768,7 +1768,7 @@ class ParallelConfig: """Use expert parallelism instead of tensor parallelism for MoE layers.""" enable_eplb: bool = False """Enable expert parallelism load balancing for MoE layers.""" - num_extra_experts: int = 0 + num_redundant_experts: int = 0 """Number of redundant experts to use for expert parallelism.""" eplb_window_size: int = 1000 """Window size for expert load recording.""" @@ -1918,14 +1918,15 @@ def __post_init__(self) -> None: raise ValueError( "Expert parallelism load balancing is only supported on " "CUDA devices now.") - if self.num_extra_experts < 0: + if self.num_redundant_experts < 0: raise ValueError( - "num_extra_experts must be non-negative, but got " - f"{self.num_extra_experts}.") + "num_redundant_experts must be non-negative, but got " + f"{self.num_redundant_experts}.") else: - if self.num_extra_experts != 0: - raise ValueError("num_extra_experts should be used with EPLB." - f"{self.num_extra_experts}.") + if self.num_redundant_experts != 0: + raise ValueError( + "num_redundant_experts should be used with EPLB." + f"{self.num_redundant_experts}.") if self.distributed_executor_backend is None and self.world_size > 1: # We use multiprocessing by default if world_size fits on the # current node and we aren't in a ray placement group. 
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8af66236d9a2..aae85e411349 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -321,7 +321,7 @@ class EngineArgs: data_parallel_backend: str = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel enable_eplb: bool = ParallelConfig.enable_eplb - num_extra_experts: int = ParallelConfig.num_extra_experts + num_redundant_experts: int = ParallelConfig.num_redundant_experts eplb_window_size: int = ParallelConfig.eplb_window_size eplb_step_interval: int = ParallelConfig.eplb_step_interval max_parallel_loading_workers: Optional[ @@ -672,8 +672,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **parallel_kwargs["enable_expert_parallel"]) parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"]) - parallel_group.add_argument("--num-extra-experts", - **parallel_kwargs["num_extra_experts"]) + parallel_group.add_argument("--num-redundant-experts", + **parallel_kwargs["num_redundant_experts"]) parallel_group.add_argument("--eplb-window-size", **parallel_kwargs["eplb_window_size"]) parallel_group.add_argument("--eplb-step-interval", @@ -1148,7 +1148,7 @@ def create_engine_config( data_parallel_backend=data_parallel_backend, enable_expert_parallel=self.enable_expert_parallel, enable_eplb=self.enable_eplb, - num_extra_experts=self.num_extra_experts, + num_redundant_experts=self.num_redundant_experts, eplb_window_size=self.eplb_window_size, eplb_step_interval=self.eplb_step_interval, max_parallel_loading_workers=self.max_parallel_loading_workers, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 57cd6fc5213a..3c37ac36df52 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -130,16 +130,14 @@ def __init__( self.gate.e_score_correction_bias = None # Load balancing settings. - # Currently, `n_redundant_experts` equals to `n_extra_experts`. vllm_config = get_current_vllm_config() parallel_config = vllm_config.parallel_config self.enable_eplb = enable_eplb - self.n_extra_experts = parallel_config.num_extra_experts - self.n_physical_experts = self.n_routed_experts + self.n_extra_experts + self.n_redundant_experts = parallel_config.num_redundant_experts self.n_logical_experts = self.n_routed_experts - self.n_redundant_experts = (self.n_physical_experts - - self.n_logical_experts) + self.n_physical_experts = (self.n_logical_experts + + self.n_redundant_experts) self.n_local_physical_experts = self.n_physical_experts // self.ep_size self.physical_expert_start = (self.ep_rank * From 2496a543a3f6c94d776fcd78a3ebaf9f3a6b43b2 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Mon, 23 Jun 2025 15:38:40 -0700 Subject: [PATCH 38/57] [Doc] Add glossary for different types of experts Signed-off-by: Bowen Wang --- vllm/distributed/eplb/eplb_state.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 248c4adeade4..dd42463aec02 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -1,6 +1,28 @@ # SPDX-License-Identifier: Apache-2.0 """ Expert parallelism load balancer (EPLB) metrics and states. + +# Glossary + +- **Logical Expert**: An expert that is part of the model's logical structure. + It holds a set of weights and is replicated across multiple physical + experts. 
+- **Redundant Expert**: To achieve load balancing, for some popular logical + experts, we create additional copies of the expert weights. During inference, + each of these copies can be routed to by the same set of tokens. +- **Physical Expert**: An expert that is instantiated on a specific device. + It is a replica of a logical expert and can be rearranged across devices. + I.e., one logical expert may have multiple sets of weights initialized on + different devices, and each of these sets is a physical expert. +- **Local Physical Expert**: A physical expert that is instantiated on the + current device. + +For example: DeepSeek-R1 has 256 logical experts, so each MoE layer +has 256 sets of linear layer weights in the model parameters. If we add 32 +redundant experts, DeepSeek-R1 will have 256 + 32 = 288 physical experts in +total. And when deploying, we'll have 288 sets of linear layer weights for each +MoE layer. If we have 32 EP ranks, then each GPU will hold 288 / 32 = 9 local +physical experts. """ import time From 9916913147269b5fbe4531d837670461672b5d9c Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Mon, 23 Jun 2025 15:41:59 -0700 Subject: [PATCH 39/57] [Doc] Add staatements in `EplbState` that some var is just config Signed-off-by: Bowen Wang --- vllm/distributed/eplb/eplb_state.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index dd42463aec02..5d4595bd79b8 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -84,7 +84,10 @@ class EplbState: expert_load_window_step: int = 0 """Current step in the sliding window.""" expert_load_window_size: int = 0 - """Size of the expert load sliding window.""" + """ + Size of the expert load sliding window. + This is a constant and is taken from the config. + """ expert_rearrangement_step: int = 0 """ @@ -92,7 +95,10 @@ class EplbState: Will trigger a rearrangement if it exceeds the threshold. """ expert_rearrangement_step_interval: int = 0 - """Interval for expert rearrangement steps.""" + """ + Interval for expert rearrangement steps. + This is a constant and is taken from the config. + """ @staticmethod def build_initial_global_physical_to_logical_map( From 420cb99fa193b13005fb15e2ffdc8231bca2df7b Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Mon, 23 Jun 2025 15:46:00 -0700 Subject: [PATCH 40/57] [Doc] Add notes on synchronization of rearrangement step Signed-off-by: Bowen Wang --- vllm/distributed/eplb/eplb_state.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 5d4595bd79b8..86e4b6a19f47 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -82,7 +82,12 @@ class EplbState: Shape: (window_size, num_moe_layers, num_local_physical_experts) """ expert_load_window_step: int = 0 - """Current step in the sliding window.""" + """ + Current step in the sliding window. + + Different from `expert_rearrangement_step`, each EP rank may have its own + `expert_load_window_step`. + """ expert_load_window_size: int = 0 """ Size of the expert load sliding window. @@ -93,6 +98,11 @@ class EplbState: """ Steps after last rearrangement. Will trigger a rearrangement if it exceeds the threshold. + + NOTE: Keep in mind that all EP ranks need to have the same + `expert_rearrangement_step` value to ensure synchronization. 
+ Otherwise, the rearrangement will hang at collective + communication calls. """ expert_rearrangement_step_interval: int = 0 """ From ff368a16143563caf5a99f27d7e4fbc92a009de8 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Mon, 23 Jun 2025 16:02:13 -0700 Subject: [PATCH 41/57] [Doc] Add examples for expert mappings Signed-off-by: Bowen Wang --- vllm/distributed/eplb/eplb_state.py | 35 +++++++++++++++++++++++++ vllm/distributed/eplb/rebalance_algo.py | 3 +++ 2 files changed, 38 insertions(+) diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 86e4b6a19f47..f710568b276b 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -52,6 +52,16 @@ class EplbState: Mapping from physical experts to logical experts. Shape: (num_moe_layers, num_physical_experts) + + # Example + + For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3 + EP ranks, the mapping could look like this: + + ``` + [[0, 1, 2, 3, 0, 1], + [0, 2, 0, 1, 0, 3]] + ``` """ logical_to_physical_map: torch.Tensor """ @@ -60,12 +70,37 @@ class EplbState: This is a sparse matrix, where -1 indicates no mapping. Shape: (num_moe_layers, num_logical_experts, num_redundant_experts + 1) + + # Example + + For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3 + EP ranks, the mapping could look like this: + + ``` + [[[0, 4, -1], + [1, 5, -1], + [2, -1, -1], + [3, -1, -1]], + [[0, 2, 4], + [3, -1, -1], + [1, -1, -1], + [5, -1, -1]]] + ``` """ logical_replica_count: torch.Tensor """ Number of replicas for each logical expert. + This is exactly the non-`-1` count in the `logical_to_physical_map`. Shape: (num_moe_layers, num_logical_experts) + + # Example + For a 2-layer MoE model with 6 physical experts and 4 logical experts on 3 + EP ranks, the count could look like this: + + ``` + [[2, 2, 1, 1], + [3, 1, 1, 1]] """ expert_load_pass: torch.Tensor diff --git a/vllm/distributed/eplb/rebalance_algo.py b/vllm/distributed/eplb/rebalance_algo.py index a27678f3ec00..7ad6d566b55b 100644 --- a/vllm/distributed/eplb/rebalance_algo.py +++ b/vllm/distributed/eplb/rebalance_algo.py @@ -6,6 +6,9 @@ The rearrangement algorithm is adapted from [DeepSeek EPLB](https://github.com/deepseek-ai/eplb). + +Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example +on how the EPLB algorithm works. """ import torch From 425d56c3da772ad466f46dc2cc1ff9c1654e6e34 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Mon, 23 Jun 2025 16:06:17 -0700 Subject: [PATCH 42/57] [Doc] Add explanation on why picking the last layer for MoE config Signed-off-by: Bowen Wang --- vllm/model_executor/models/deepseek_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 3c37ac36df52..f712b626c74c 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -744,6 +744,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if isinstance(layer.mlp, DeepseekV2MoE): self.moe_layers.append(layer.mlp.experts) + # Pick last one layer since the first ones may be dense layers. 
example_moe = typing.cast( DeepseekV2MoE, self.model.layers[config.num_hidden_layers - 1].mlp) self.num_logical_experts = example_moe.n_logical_experts From 76fbdf82dbeaaaf1d811e463b31b75b5347fb186 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Mon, 23 Jun 2025 16:13:01 -0700 Subject: [PATCH 43/57] [Refactor] Revert `fused_moe.py` since not used Signed-off-by: Bowen Wang --- .../layers/fused_moe/fused_moe.py | 53 ++++++++----------- 1 file changed, 23 insertions(+), 30 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index d15f408e8c2a..437e80696ac6 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -4,7 +4,7 @@ import functools import json import os -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional import torch @@ -1431,8 +1431,7 @@ def fused_moe( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[list[int]] = None, - return_topk_ids: bool = False, -) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: +) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. @@ -1478,7 +1477,6 @@ def fused_moe( a2. - block_shape: (Optional[list[int]]): Optional block size for block-wise quantization. - - return_topk_ids (bool): If True, return the top-k expert IDs Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -1496,32 +1494,27 @@ def fused_moe( topk_weights, topk_ids = custom_routing_function( hidden_states, gating_output, topk, renormalize) - result = fused_experts(hidden_states, - w1, - w2, - topk_weights, - topk_ids, - inplace=inplace, - activation=activation, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a8=use_int8_w8a8, - use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16, - per_channel_quant=per_channel_quant, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1_scale=a1_scale, - a2_scale=a2_scale, - block_shape=block_shape) - - if return_topk_ids: - return result, topk_ids - else: - return result + return fused_experts(hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace=inplace, + activation=activation, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, + use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, + per_channel_quant=per_channel_quant, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1_scale=a1_scale, + a2_scale=a2_scale, + block_shape=block_shape) class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): From 6777877eb963b63f0776ebc6161e6839b1ab9aaa Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Mon, 23 Jun 2025 16:21:36 -0700 Subject: [PATCH 44/57] [Doc] Add explanations for calling points of `_dummy_run` Signed-off-by: Bowen Wang --- vllm/v1/worker/gpu_model_runner.py | 10 +++++++++- vllm/v1/worker/gpu_worker.py | 2 ++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 88bcefd77de5..4b1183824c07 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2026,7 +2026,13 @@ def _dummy_run( assert isinstance(self.drafter, EagleProposer) self.drafter.dummy_run(num_tokens) - # This 
is necessary to avoid blocking DP + # This is necessary to avoid blocking DP. + # For dummy runs, we typically skip EPLB since we don't have any real + # requests to process. + # However, in DP settings, there may be cases when some DP ranks do + # not have any requests to process, so they're executing dummy batches. + # In such cases, we still have to trigger EPLB to make sure + # ranks execute the rearrangement in synchronization. if not skip_eplb: self.eplb_step(is_dummy=True, is_profile=is_profile) @@ -2222,6 +2228,7 @@ def profile_run(self) -> None: # Cache the dummy encoder outputs. self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) + # Add `is_profile` here to pre-allocate communication buffers hidden_states, last_hidden_states \ = self._dummy_run(self.max_num_tokens, is_profile=True) if get_pp_group().is_last_rank: @@ -2257,6 +2264,7 @@ def capture_model(self) -> None: for num_tokens in tqdm(reversed(self.cudagraph_batch_sizes), desc="Capturing CUDA graphs", total=len(self.cudagraph_batch_sizes)): + # We skip EPLB here since we don't want to record dummy metrics for _ in range( self.compilation_config.cudagraph_num_of_warmups): self._dummy_run(num_tokens, diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 4f12a4620555..9e7e44d06861 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -259,6 +259,7 @@ def compile_or_warm_up_model(self) -> None: x for x in warmup_sizes if x not in self.vllm_config.compilation_config.cudagraph_capture_sizes ] + # We skip EPLB here since we don't want to record dummy metrics for size in sorted(warmup_sizes, reverse=True): logger.info("Compile and warming up model for size %d", size) self.model_runner._dummy_run(size, skip_eplb=True) @@ -274,6 +275,7 @@ def compile_or_warm_up_model(self) -> None: max_num_reqs = min(self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens) + # We skip EPLB here since we don't want to record dummy metrics hidden_states, last_hidden_states = \ self.model_runner._dummy_run( num_tokens=max_num_reqs, From 12401b19b11850ec8e775c4e5ec7d7fad7efd8d0 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Mon, 23 Jun 2025 16:29:30 -0700 Subject: [PATCH 45/57] [Doc] Add comments on when do real communication happen Signed-off-by: Bowen Wang --- vllm/distributed/eplb/rebalance_execute.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index b1ae79897bec..cf173c734afd 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -211,12 +211,13 @@ def shuffle_layer( ) for weight in expert_weights_buffer ] + # 4. Execute the P2P operations. The real communication happens here. if p2p_ops: reqs = batch_isend_irecv(p2p_ops) for req in reqs: req.wait() - # 4. Copy the weights from the buffer back to the original weights. + # 5. Copy the weights from the buffer back to the original weights. for dst in range(num_local_experts): if is_unchanged[dst]: continue From 80b3a1bd44814ad257cb339d65b03348f72fa6b1 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Mon, 23 Jun 2025 16:33:32 -0700 Subject: [PATCH 46/57] [Doc] Add comments on only last `eplb_window_size` steps will be used When the rearrangement interval is larger than window size. 
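A minimal illustration of this interaction, using a plain `deque` in place of the tensor-based window kept in `EplbState`: when the rearrangement interval exceeds the window size, older per-step load measurements are overwritten before a rearrangement is ever triggered, so only the most recent `eplb_window_size` passes influence the rebalance.

```python
# Sketch only: fixed-size window vs. a longer rearrangement interval.
from collections import deque

eplb_window_size = 4
eplb_step_interval = 10

window: deque = deque(maxlen=eplb_window_size)
for step in range(eplb_step_interval):
    expert_load_pass = [step, step + 1]  # fake per-expert token counts
    window.append(expert_load_pass)

# Only the last 4 steps (6..9) remain; steps 0..5 were discarded.
assert len(window) == eplb_window_size
print(list(window))
```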
Signed-off-by: Bowen Wang --- vllm/config.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 2021192da099..ff2b80385be9 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1773,7 +1773,12 @@ class ParallelConfig: eplb_window_size: int = 1000 """Window size for expert load recording.""" eplb_step_interval: int = 3000 - """Interval for rearranging experts in expert parallelism.""" + """ + Interval for rearranging experts in expert parallelism. + + Note that if this is greater than the EPLB window size, only the metrics + of the last `eplb_window_size` steps will be used for rearranging experts. + """ max_parallel_loading_workers: Optional[int] = None """Maximum number of parallel loading workers when loading model From 3ea6f2c5dde98cbfa5519d17c9b0afb6d303ad27 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Mon, 23 Jun 2025 16:50:30 -0700 Subject: [PATCH 47/57] [Feature] Disable balancedness logging by default Since this adds some considerable amount of communication overhead Signed-off-by: Bowen Wang --- vllm/config.py | 5 +++ vllm/distributed/eplb/eplb_state.py | 65 ++++++++++++++++------------- vllm/engine/arg_utils.py | 4 ++ vllm/v1/worker/gpu_model_runner.py | 15 ++++--- 4 files changed, 52 insertions(+), 37 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index ff2b80385be9..d35fb6452daf 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1779,6 +1779,11 @@ class ParallelConfig: Note that if this is greater than the EPLB window size, only the metrics of the last `eplb_window_size` steps will be used for rearranging experts. """ + eplb_log_balancedness: bool = False + """ + Log the balancedness each step of expert parallelism. + This is turned off by default since it will cause communication overhead. + """ max_parallel_loading_workers: Optional[int] = None """Maximum number of parallel loading workers when loading model diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index f710568b276b..2a30f0b63243 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -254,7 +254,8 @@ def build( def step(self, model: MixtureOfExperts, is_dummy: bool = False, - is_profile: bool = False) -> tuple[float, float, float]: + is_profile: bool = False, + log_stats: bool = False) -> None: """ Step the EPLB state. @@ -266,10 +267,10 @@ def step(self, is_profile (bool): If `True`, perform a dummy rearrangement with maximum communication cost. This is used in `profile_run` to reserve enough memory for the communication buffer. + log_stats (bool): If `True`, log the expert load metrics. - Returns: - (avg_tokens, max_tokens, balancedness) (tuple[float, float, float]): - The returned metrics are all summed up across layers. + # Stats + The metrics are all summed up across layers. - `avg_tokens`: The average load across ranks. - `max_tokens`: The maximum load across ranks. - `balancedness`: The ratio of average load to maximum load. 
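A toy example of the balancedness statistic described above, with made-up per-rank token counts (the shapes mirror the gathered metrics, but this is not the vLLM implementation): loads are averaged and maxed across ranks per layer, summed over layers, and the ratio of the two is reported.

```python
# Assumed shapes: (num_ranks, num_moe_layers) of routed-token counts.
import torch

num_tokens_per_rank = torch.tensor([
    [100., 120.],   # rank 0
    [140., 100.],   # rank 1
    [ 60., 180.],   # rank 2
])

avg_tokens = num_tokens_per_rank.mean(dim=0).sum()        # average load across ranks
max_tokens = num_tokens_per_rank.max(dim=0).values.sum()  # worst-rank load
balancedness = (avg_tokens / max_tokens).item() if max_tokens > 0 else 0.0
print(f"balancedness={balancedness:.4f}")  # 1.0 means perfectly balanced
```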
@@ -277,34 +278,42 @@ def step(self, if is_profile: self.rearrange(model, is_profile=True) - return 0.0, 0.0, 0.0 + return if is_dummy: # Do not record load metrics for dummy steps self.expert_load_pass.zero_() - # `num_tokens`: (num_moe_layers,) - num_tokens = self.expert_load_pass.sum(dim=-1) - # Collect load metrics from all ranks - ep_group = get_ep_group().device_group - num_tokens_list = [ - torch.empty_like(num_tokens) for _ in range(ep_group.size()) - ] - all_gather(num_tokens_list, num_tokens, group=ep_group) - # Stack to get (num_ranks, num_moe_layers) - num_tokens_per_rank = torch.stack(num_tokens_list).float() - - # Compute balancedness ratio: - # for each layer: - # (mean load across ranks) / (max load across ranks) - avg_tokens_tensor = num_tokens_per_rank.mean(dim=0).sum(dim=0) - max_tokens_tensor = num_tokens_per_rank.max(dim=0).values.sum(dim=0) - - # Just to make type checker happy - tokens_tensors: list[float] = torch.stack( - [avg_tokens_tensor, max_tokens_tensor]).tolist() - avg_tokens, max_tokens = tokens_tensors - balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0 + if log_stats: + # `num_tokens`: (num_moe_layers,) + num_tokens = self.expert_load_pass.sum(dim=-1) + + # Collect load metrics from all ranks + ep_group = get_ep_group().device_group + num_tokens_list = [ + torch.empty_like(num_tokens) for _ in range(ep_group.size()) + ] + all_gather(num_tokens_list, num_tokens, group=ep_group) + # Stack to get (num_ranks, num_moe_layers) + num_tokens_per_rank = torch.stack(num_tokens_list).float() + + # Compute balancedness ratio: + # for each layer: + # (mean load across ranks) / (max load across ranks) + avg_tokens_tensor = num_tokens_per_rank.mean(dim=0).sum(dim=0) + max_tokens_tensor = num_tokens_per_rank.max(dim=0).values.sum( + dim=0) + + # Just to make type checker happy + tokens_tensors: list[float] = torch.stack( + [avg_tokens_tensor, max_tokens_tensor]).tolist() + avg_tokens, max_tokens = tokens_tensors + balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0 + + if ep_group.rank() == 0: + logger.info( + "EPLB step: avg_tokens=%.2f, max_tokens=%d, " + "balancedness=%.4f", avg_tokens, max_tokens, balancedness) # Update the expert load sliding window if not is_dummy: @@ -325,8 +334,6 @@ def step(self, self.expert_rearrangement_step = 0 self.rearrange(model) - return avg_tokens, max_tokens, balancedness - def rearrange(self, model: MixtureOfExperts, is_profile: bool = False) -> None: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index aae85e411349..14660b2840a5 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -324,6 +324,7 @@ class EngineArgs: num_redundant_experts: int = ParallelConfig.num_redundant_experts eplb_window_size: int = ParallelConfig.eplb_window_size eplb_step_interval: int = ParallelConfig.eplb_step_interval + eplb_log_balancedness: bool = ParallelConfig.eplb_log_balancedness max_parallel_loading_workers: Optional[ int] = ParallelConfig.max_parallel_loading_workers block_size: Optional[BlockSize] = CacheConfig.block_size @@ -678,6 +679,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **parallel_kwargs["eplb_window_size"]) parallel_group.add_argument("--eplb-step-interval", **parallel_kwargs["eplb_step_interval"]) + parallel_group.add_argument("--eplb-log-balancedness", + **parallel_kwargs["eplb_log_balancedness"]) parallel_group.add_argument( "--max-parallel-loading-workers", **parallel_kwargs["max_parallel_loading_workers"]) @@ -1151,6 +1154,7 @@ def 
create_engine_config( num_redundant_experts=self.num_redundant_experts, eplb_window_size=self.eplb_window_size, eplb_step_interval=self.eplb_step_interval, + eplb_log_balancedness=self.eplb_log_balancedness, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, ray_workers_use_nsight=self.ray_workers_use_nsight, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4b1183824c07..e4584a118257 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -26,7 +26,7 @@ has_kv_transfer_group) from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.parallel_state import ( - get_ep_group, get_pp_group, get_tp_group, graph_capture, + get_pp_group, get_tp_group, graph_capture, prepare_communication_buffer_for_model) from vllm.forward_context import (DPMetadata, get_forward_context, set_forward_context) @@ -1199,13 +1199,12 @@ def eplb_step(self, assert self.eplb_state is not None assert is_mixture_of_experts(self.model) - avg_tokens, max_tokens, balancedness = \ - self.eplb_state.step(self.model, is_dummy, is_profile) - - if get_ep_group().is_first_rank: - logger.debug( - "Model step: avg_tokens=%.2f, max_tokens=%d, " - "balancedness=%.4f", avg_tokens, max_tokens, balancedness) + self.eplb_state.step( + self.model, + is_dummy, + is_profile, + log_stats=self.parallel_config.eplb_log_balancedness, + ) def get_dp_padding(self, num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: From aff799187fae2531e7fe464eb5314a7d615d2a20 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Mon, 23 Jun 2025 17:32:18 -0700 Subject: [PATCH 48/57] [Style] Rename shadowed variables to make linter happy Signed-off-by: Bowen Wang --- vllm/distributed/eplb/eplb_state.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 2a30f0b63243..3af6ed9b37da 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -175,13 +175,13 @@ def build( """ Build the initial EPLB state. 
""" - physical_to_logical_map = ( + physical_to_logical_map_list = ( cls.build_initial_global_physical_to_logical_map( model.num_routed_experts, model.num_redundant_experts, )) physical_to_logical_map = torch.tensor( - physical_to_logical_map, + physical_to_logical_map_list, device=device, ) logical_to_physical_map = torch.full( From 8ac089e7a2e8349949f254c5467109feb7777f7b Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Mon, 23 Jun 2025 18:14:10 -0700 Subject: [PATCH 49/57] [Style] Add parameters of `apply` for subclasses of `FusedMoEMethodBase` To make linter happy Signed-off-by: Bowen Wang --- vllm/model_executor/layers/fused_moe/layer.py | 31 +++++++------- .../layers/quantization/awq_marlin.py | 8 ++++ .../compressed_tensors_moe.py | 42 +++++++++++++++++++ .../layers/quantization/experts_int8.py | 8 ++++ .../layers/quantization/gguf.py | 8 ++++ .../layers/quantization/gptq_marlin.py | 8 ++++ .../layers/quantization/modelopt.py | 8 ++++ .../layers/quantization/moe_wna16.py | 8 ++++ .../layers/quantization/quark/quark_moe.py | 8 ++++ 9 files changed, 114 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 812c4b6fa803..cf29ce5592f8 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -578,7 +578,15 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `UnquantizedFusedMoEMethod` yet.") + return self.forward( x=x, layer=layer, @@ -1564,12 +1572,6 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): hidden_states = full_hidden_states[chunk_start:chunk_end, :] router_logits = full_router_logits[chunk_start:chunk_end, :] - eplb_kwargs = { - "enable_eplb": True, - "expert_load_view": self.expert_load_view, - "logical_to_physical_map": self.logical_to_physical_map, - "logical_replica_count": self.logical_replica_count, - } if self.enable_eplb else {} assert (self.batched_hidden_states.size(0) # type: ignore >= chunk_size) assert (self.batched_router_logits.size(0) # type: ignore @@ -1597,7 +1599,10 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): scoring_func=self.scoring_func, e_score_correction_bias=self.e_score_correction_bias, activation=self.activation, - **eplb_kwargs, + enable_eplb=self.enable_eplb, + expert_load_view=self.expert_load_view, + logical_to_physical_map=self.logical_to_physical_map, + logical_replica_count=self.logical_replica_count, ) if not skip_result_store: @@ -1638,13 +1643,6 @@ def forward_impl(self, hidden_states: torch.Tensor, hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits) - eplb_kwargs = { - "enable_eplb": True, - "expert_load_view": self.expert_load_view, - "logical_to_physical_map": self.logical_to_physical_map, - "logical_replica_count": self.logical_replica_count, - } if self.enable_eplb else {} - # Matrix multiply. 
final_hidden_states = self.quant_method.apply( layer=self, @@ -1662,7 +1660,10 @@ def forward_impl(self, hidden_states: torch.Tensor, e_score_correction_bias=self.e_score_correction_bias, activation=self.activation, apply_router_weight_on_input=self.apply_router_weight_on_input, - **eplb_kwargs, + enable_eplb=self.enable_eplb, + expert_load_view=self.expert_load_view, + logical_to_physical_map=self.logical_to_physical_map, + logical_replica_count=self.logical_replica_count, ) if do_naive_dispatch_combine: diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 56d803c6baf1..aff54bc495b2 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -482,7 +482,15 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `AWQMoEMethod` yet.") + assert activation == "silu", "Only SiLU activation is supported." if apply_router_weight_on_input: diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index f14131c5f05b..7703b9e687c4 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -331,7 +331,15 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for " + "`CompressedTensorsW8A8Fp8MoEMethod` yet.") topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -593,7 +601,15 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for " + "`CompressedTensorsW8A8Fp8MoECutlassMethod` yet.") topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -722,7 +738,16 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for " + "`CompressedTensorsW8A8Int8MoEMethod` yet.") + from vllm.model_executor.layers.fused_moe import fused_experts topk_weights, topk_ids = FusedMoE.select_experts( @@ -1012,7 +1037,16 @@ def apply( 
e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for " + "`CompressedTensorsWNA16MarlinMoEMethod` yet.") + assert activation == "silu", ( f"{activation} not supported for Marlin MoE.") assert not apply_router_weight_on_input, ( @@ -1228,7 +1262,15 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError("EPLB not supported for " + "`CompressedTensorsWNA16MoEMethod` yet.") + from vllm.model_executor.layers.fused_moe import fused_experts topk_weights, topk_ids = FusedMoE.select_experts( diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 01b0064f0805..47eca80609e0 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -117,7 +117,15 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `ExpertsInt8MoEMethod` yet.") + from vllm.model_executor.layers.fused_moe import fused_experts topk_weights, topk_ids = FusedMoE.select_experts( diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 9c8f74545d37..86da04c39989 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -520,7 +520,15 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ): + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `GGUFMoEMethod` yet.") + assert activation == "silu", "Only SiLU activation is supported." 
if apply_router_weight_on_input: raise NotImplementedError( diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index f92ebdea986d..9cbaf5d4a6e9 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -609,7 +609,15 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `GPTQMarlinMoEMethod` yet.") + assert activation == "silu", "Only SiLU activation is supported." if apply_router_weight_on_input: raise NotImplementedError( diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 3f79b203aa17..e35db5b31dba 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -664,7 +664,15 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ): + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.") + if self.use_marlin: topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 3aa23f068257..c5055a02fa3d 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -297,7 +297,15 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `MoeWNA16Method` yet.") + from vllm.model_executor.layers.fused_moe import fused_experts assert activation == "silu", "Only SiLU activation is supported." 
topk_weights, topk_ids = FusedMoE.select_experts( diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 4c2da4c8b04e..a040c430cbca 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -205,7 +205,15 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.") + from vllm.model_executor.layers.fused_moe import fused_experts topk_weights, topk_ids = FusedMoE.select_experts( From a6a4a3aca083bfe1afd42deb6a5e356738b6f693 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Tue, 24 Jun 2025 16:07:53 -0700 Subject: [PATCH 50/57] [Test] Add test for EPLB algo Signed-off-by: Bowen Wang --- .buildkite/test-pipeline.yaml | 8 + tests/distributed/test_eplb_algo.py | 292 ++++++++++++++++++++++++++++ 2 files changed, 300 insertions(+) create mode 100644 tests/distributed/test_eplb_algo.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index fe775bb370f2..1f1f7a0927c0 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -168,6 +168,14 @@ steps: - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py - popd +- label: EPLB Algorithm Test + working_dir: "/vllm-workspace/tests" + source_file_dependencies: + - vllm/distributed/eplb + - tests/distributed/test_eplb_algo.py + commands: + - pytest -v -s distributed/test_eplb_algo.py + - label: Metrics, Tracing Test # 10min mirror_hardwares: [amdexperimental, amdproduction] num_gpus: 2 diff --git a/tests/distributed/test_eplb_algo.py b/tests/distributed/test_eplb_algo.py new file mode 100644 index 000000000000..e47ccba99c81 --- /dev/null +++ b/tests/distributed/test_eplb_algo.py @@ -0,0 +1,292 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm.distributed.eplb.rebalance_algo import rebalance_experts + + +def test_basic_rebalance(): + """Test basic rebalancing functionality""" + # Example from https://github.com/deepseek-ai/eplb + weight = torch.tensor([ + [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86], + [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27], + ]) + + num_layers = weight.shape[0] + num_replicas = 16 + num_groups = 4 + num_nodes = 2 + num_gpus = 8 + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, + num_gpus) + + # Verify output shapes + assert phy2log.shape == ( + 2, + 16, + ), f"Expected `phy2log` shape (2, 16), got {phy2log.shape}" + assert (log2phy.shape[0] == 2 + ), f"Expected `log2phy` first dimension 2, got {log2phy.shape[0]}" + assert ( + log2phy.shape[1] == 12 + ), f"Expected `log2phy` second dimension 12, got {log2phy.shape[1]}" + assert logcnt.shape == ( + 2, + 12, + ), f"Expected `logcnt` shape (2, 12), got {logcnt.shape}" + + # Verify physical to logical expert mapping range is correct + assert torch.all(phy2log >= 0) and torch.all( + phy2log < 12), "Physical to logical mapping should be in range [0, 12)" + + # Verify expert count 
reasonableness + assert torch.all( + logcnt >= 1), "Each logical expert should have at least 1 replica" + assert ( + torch.sum(logcnt, dim=1).sum() == num_replicas * + num_layers), f"Total replicas should be {num_replicas * num_layers}" + + # Verify expected output + expected_phy2log = torch.tensor([ + [5, 6, 5, 7, 8, 4, 3, 4, 10, 9, 10, 2, 0, 1, 11, 1], + [7, 10, 6, 8, 6, 11, 8, 9, 2, 4, 5, 1, 5, 0, 3, 1], + ]) + assert torch.all(phy2log == expected_phy2log) + + expected_logcnt = torch.tensor([[1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1], + [1, 2, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1]]) + assert torch.all(logcnt == expected_logcnt) + + +def test_single_gpu_case(): + """Test single GPU case""" + weight = torch.tensor([[10, 20, 30, 40]]) + num_replicas = 4 + num_groups = 1 + num_nodes = 1 + num_gpus = 1 + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, + num_gpus) + + # Verify shapes + assert phy2log.shape == (1, 4) + assert log2phy.shape[0] == 1 + assert log2phy.shape[1] == 4 + assert logcnt.shape == (1, 4) + + # Verify all logical experts are mapped + assert set(phy2log[0].tolist()) == {0, 1, 2, 3} + + +def test_equal_weights(): + """Test case with equal weights""" + weight = torch.tensor([[50, 50, 50, 50, 50, 50, 50, 50]]) + num_replicas = 8 + num_groups = 2 + num_nodes = 2 + num_gpus = 4 + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, + num_gpus) + + # Verify shapes + assert phy2log.shape == (1, 8) + assert logcnt.shape == (1, 8) + + # With equal weights, each expert should have exactly one replica + assert torch.all( + logcnt == 1 + ), "With equal weights and no replication, " \ + "each expert should have exactly 1 replica" + + +def test_extreme_weight_imbalance(): + """Test extreme weight imbalance case""" + weight = torch.tensor([[1000, 1, 1, 1, 1, 1, 1, 1]]) + num_replicas = 12 + num_groups = 2 + num_nodes = 2 + num_gpus = 4 + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, + num_gpus) + + # Verify shapes + assert phy2log.shape == (1, 12) + assert logcnt.shape == (1, 8) + + # Expert with highest weight (index 0) should have more replicas + assert ( + logcnt[0, 0] + > logcnt[0, 1]), "Expert with highest weight should have more replicas" + + +def test_multiple_layers(): + """Test multiple layers case""" + weight = torch.tensor([ + [10, 20, 30, 40, 50, 60], # First layer + [60, 50, 40, 30, 20, 10], # Second layer (opposite weight pattern) + [25, 25, 25, 25, 25, 25], # Third layer (equal weights) + ]) + num_replicas = 8 + num_groups = 2 + num_nodes = 2 + num_gpus = 4 + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, + num_gpus) + + # Verify shapes + assert phy2log.shape == (3, 8) + assert logcnt.shape == (3, 6) + + # Verify expert allocation is reasonable for each layer + for layer in range(3): + assert torch.all(phy2log[layer] >= 0) and torch.all( + phy2log[layer] < 6 + ), f"Layer {layer} physical to logical mapping" \ + "should be in range [0, 6)" + assert (torch.sum(logcnt[layer]) == num_replicas + ), f"Layer {layer} total replicas should be {num_replicas}" + + +def test_parameter_validation(): + """Test parameter validation""" + weight = torch.tensor([[10, 20, 30, 40]]) + + # Test non-divisible case - this should handle normally without throwing + # errors because the function will fall back to global load balancing + # strategy + phy2log, log2phy, logcnt = rebalance_experts(weight, 8, 3, 2, 4) + assert 
phy2log.shape == (1, 8) + assert logcnt.shape == (1, 4) + + # Test cases that will actually cause errors: + # num_physical_experts not divisible by num_gpus + with pytest.raises(AssertionError): + rebalance_experts(weight, 7, 2, 2, 4) # 7 not divisible by 4 + + +def test_small_scale_hierarchical(): + """Test small-scale hierarchical load balancing""" + weight = torch.tensor([ + [100, 50, 200, 75, 150, 25, 300, 80], # 8 experts + ]) + num_replicas = 12 + num_groups = 4 # 4 groups, 2 experts each + num_nodes = 2 # 2 nodes + num_gpus = 4 # 4 GPUs + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, + num_gpus) + + # Verify basic constraints + assert phy2log.shape == (1, 12) + assert logcnt.shape == (1, 8) + assert torch.sum(logcnt) == num_replicas + assert torch.all(logcnt >= 1) + + # Expert with highest weight should have more replicas + max_weight_expert = torch.argmax(weight[0]) + assert (logcnt[0, max_weight_expert] + >= 2), "Highest weight expert should have multiple replicas" + + +def test_global_load_balance_fallback(): + """Test global load balancing fallback case""" + # When num_groups % num_nodes != 0, should fall back to global load + # balancing + weight = torch.tensor([[10, 20, 30, 40, 50, 60]]) + num_replicas = 8 + num_groups = 3 # Cannot be divided evenly by num_nodes=2 + num_nodes = 2 + num_gpus = 4 + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, + num_gpus) + + # Should work normally, just using global load balancing strategy + assert phy2log.shape == (1, 8) + assert logcnt.shape == (1, 6) + assert torch.sum(logcnt) == num_replicas + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_device_compatibility(device): + """Test device compatibility""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + weight = torch.tensor([[10, 20, 30, 40]], device=device) + num_replicas = 6 + num_groups = 2 + num_nodes = 1 + num_gpus = 2 + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, + num_gpus) + + # Function will convert to CPU internally, but should handle different + # device inputs normally + assert phy2log.shape == (1, 6) + assert logcnt.shape == (1, 4) + + +def test_additional_cases(): + """Test more edge cases and different parameter combinations""" + + # Test case 1: Large-scale distributed setup + weight1 = torch.tensor( + [[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]]) + phy2log1, log2phy1, logcnt1 = rebalance_experts(weight1, 24, 8, 4, 8) + + assert phy2log1.shape == (1, 24) + assert logcnt1.shape == (1, 16) + assert torch.sum(logcnt1) == 24 + + # Test case 2: Different weight distributions + weight2 = torch.tensor([ + [200, 150, 100, 50, 25, 12], # Decreasing weights + [12, 25, 50, 100, 150, 200], # Increasing weights + ]) + phy2log2, log2phy2, logcnt2 = rebalance_experts(weight2, 10, 3, 1, 2) + + assert phy2log2.shape == (2, 10) + assert logcnt2.shape == (2, 6) + + # Verify high-weight experts have more replicas + for layer in range(2): + max_weight_idx = torch.argmax(weight2[layer]) + assert logcnt2[layer, max_weight_idx] >= 2 + + +if __name__ == "__main__": + weight = torch.tensor([ + [90, 132, 40, 61, 104, 165, 39, 4, 73, 56, 183, 86], + [20, 107, 104, 64, 19, 197, 187, 157, 172, 86, 16, 27], + ]) + + num_replicas = 16 + num_groups = 4 + num_nodes = 2 + num_gpus = 8 + + phy2log, log2phy, logcnt = rebalance_experts(weight, num_replicas, + num_groups, num_nodes, 
+ num_gpus) + print(phy2log) + + test_basic_rebalance() From 1ed45b257b2e658db1bea77c2eeada5390c10d38 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Wed, 25 Jun 2025 03:35:12 -0700 Subject: [PATCH 51/57] [Test] Add test for EPLB execute Signed-off-by: Bowen Wang --- .buildkite/test-pipeline.yaml | 9 + tests/distributed/test_eplb_execute.py | 502 +++++++++++++++++++++++++ 2 files changed, 511 insertions(+) create mode 100644 tests/distributed/test_eplb_execute.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 1f1f7a0927c0..f4e75c66b73d 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -176,6 +176,15 @@ steps: commands: - pytest -v -s distributed/test_eplb_algo.py +- label: EPLB Execution Test # 5min + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/distributed/eplb + - tests/distributed/test_eplb_execute.py + commands: + - pytest -v -s distributed/test_eplb_execute.py + - label: Metrics, Tracing Test # 10min mirror_hardwares: [amdexperimental, amdproduction] num_gpus: 2 diff --git a/tests/distributed/test_eplb_execute.py b/tests/distributed/test_eplb_execute.py new file mode 100644 index 000000000000..1a14688c9807 --- /dev/null +++ b/tests/distributed/test_eplb_execute.py @@ -0,0 +1,502 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import multiprocessing +import os +import random + +import pytest +import torch +import torch.distributed + +from vllm.distributed.eplb.rebalance_execute import ( + rearrange_expert_weights_inplace) +from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, + get_tp_group, + init_distributed_environment) +from vllm.utils import update_environment_variables + + +def distributed_run(fn, world_size): + number_of_processes = world_size + processes: list[multiprocessing.Process] = [] + for i in range(number_of_processes): + env: dict[str, str] = {} + env['RANK'] = str(i) + env['LOCAL_RANK'] = str(i) + env['WORLD_SIZE'] = str(number_of_processes) + env['LOCAL_WORLD_SIZE'] = str(number_of_processes) + env['MASTER_ADDR'] = 'localhost' + env['MASTER_PORT'] = '12345' + p = multiprocessing.Process(target=fn, args=(env, )) + processes.append(p) + p.start() + + for p in processes: + p.join() + + for p in processes: + assert p.exitcode == 0 + + +def worker_fn_wrapper(fn): + # `multiprocessing.Process` cannot accept environment variables directly + # so we need to pass the environment variables as arguments + # and update the environment variables in the function + def wrapped_fn(env): + update_environment_variables(env) + local_rank = os.environ['LOCAL_RANK'] + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + init_distributed_environment() + + # Ensure each worker process has the same random seed + random.seed(42) + torch.manual_seed(42) + + fn() + + return wrapped_fn + + +def create_expert_indices_with_redundancy( + num_layers: int, + num_logical_experts: int, + total_physical_experts: int, + redundancy_config: list[int], # redundancy for each logical expert +) -> torch.Tensor: + """ + Create expert indices with redundancy. 
+ + Args: + num_layers: number of layers + num_logical_experts: number of logical experts + total_physical_experts: total number of physical experts + redundancy_config: redundancy for each logical expert + + Returns: + indices: Shape (num_layers, total_physical_experts) + """ + assert sum(redundancy_config) == total_physical_experts + assert len(redundancy_config) == num_logical_experts + + indices = torch.zeros(num_layers, total_physical_experts, dtype=torch.long) + + for layer in range(num_layers): + physical_pos = 0 + for logical_expert_id, redundancy in enumerate(redundancy_config): + for _ in range(redundancy): + indices[layer, physical_pos] = logical_expert_id + physical_pos += 1 + + # Shuffle the indices at dim 1 + for layer in range(num_layers): + indices[layer] = indices[layer][torch.randperm(indices.shape[1])] + + return indices + + +def create_expert_weights( + num_layers: int, + num_local_experts: int, + hidden_sizes: list[int], + rank: int, + device: torch.device, + physical_to_logical_mapping: torch.Tensor, +) -> list[list[torch.Tensor]]: + """ + Create fake expert weights tensor for testing. + + Use `arange` to generate predictable weights values, based on logical + expert ID. + All replicas of the same logical expert should have the same weights. + + Args: + physical_to_logical_mapping: Shape (num_layers, num_local_experts) + mapping[layer, physical_pos] = logical_expert_id + """ + expert_weights = [] + + for layer in range(num_layers): + layer_weights = [] + for weight_idx, hidden_size in enumerate(hidden_sizes): + weight_tensor = torch.zeros(num_local_experts, + hidden_size, + device=device, + dtype=torch.float32) + + for local_expert in range(num_local_experts): + # Get the logical expert ID for this physical expert + global_pos = rank * num_local_experts + local_expert + logical_expert_id = physical_to_logical_mapping[ + layer, global_pos].item() + + # Generate weights based on logical expert ID + # (so that all replicas of the same logical expert have the + # same weights) + base_value = (logical_expert_id * 1000 + layer * 100 + + weight_idx * 10) + weight_tensor[local_expert] = torch.arange(base_value, + base_value + + hidden_size, + device=device, + dtype=torch.float32) + + layer_weights.append(weight_tensor) + expert_weights.append(layer_weights) + + return expert_weights + + +def create_redundancy_config( + num_logical_experts: int, + num_physical_experts: int, +) -> list[int]: + """Create a redundancy configuration.""" + redundancy_config = [1] * num_logical_experts + remaining = num_physical_experts - num_logical_experts + # Randomly assign the remaining physical experts to the logical experts + for _ in range(remaining): + redundancy_config[random.choice(range(num_logical_experts))] += 1 + return redundancy_config + + +def verify_expert_weights_after_shuffle( + expert_weights: list[list[torch.Tensor]], + new_indices: torch.Tensor, + hidden_sizes: list[int], + ep_rank: int, + num_local_experts: int, +): + """Verify the weights after shuffling are correct.""" + num_layers = len(expert_weights) + + for layer in range(num_layers): + for weight_idx, hidden_size in enumerate(hidden_sizes): + weight_tensor = expert_weights[layer][weight_idx] + + for local_expert in range(num_local_experts): + # Calculate the global expert ID for this local expert + global_pos = ep_rank * num_local_experts + local_expert + expected_logical_expert = new_indices[layer, global_pos].item() + + # Check if the weights are correct + actual_weights = weight_tensor[local_expert] + expected_base 
= (expected_logical_expert * 1000 + layer * 100 + + weight_idx * 10) + expected_weights = torch.arange(expected_base, + expected_base + hidden_size, + device=actual_weights.device, + dtype=actual_weights.dtype) + + torch.testing.assert_close( + actual_weights, + expected_weights, + msg=f"Layer {layer}, weight {weight_idx}," + f"local expert {local_expert}: " + f"weights do not match. " + f"Expected logical expert {expected_logical_expert}") + + +def verify_redundant_experts_have_same_weights( + expert_weights: list[list[torch.Tensor]], + indices: torch.Tensor, + hidden_sizes: list[int], + world_size: int, + num_local_experts: int, +): + """ + Verify that all replicas of the same logical expert have the same weights. + """ + num_layers = len(expert_weights) + total_physical_experts = world_size * num_local_experts + + for layer in range(num_layers): + # Collect weights for all physical experts for each weight matrix + all_weights: list[torch.Tensor] = [] + + for weight_idx, hidden_size in enumerate(hidden_sizes): + # Create tensor to store all expert weights + # Shape: [total_physical_experts, hidden_size] + gathered_weights = torch.zeros( + total_physical_experts, + hidden_size, + device=expert_weights[layer][weight_idx].device, + dtype=expert_weights[layer][weight_idx].dtype) + + # Use all_gather to collect expert weights from current node + # expert_weights[layer][weight_idx] shape: + # [num_local_experts, hidden_size] + local_weights = expert_weights[layer][ + weight_idx] # [num_local_experts, hidden_size] + + # Split tensor along dim 0 into a list for all_gather + gathered_weights_list = torch.chunk(gathered_weights, + world_size, + dim=0) + + torch.distributed.all_gather( + # Output list: each element corresponds to one rank's weights + list(gathered_weights_list), + local_weights # Input: current rank's local weights + ) + + all_weights.append(gathered_weights) + + # Verify that all replicas of the same logical expert have the same + # weights + logical_expert_weights: dict[int, dict[int, torch.Tensor]] = {} + + for physical_pos in range(total_physical_experts): + logical_expert_id = int(indices[layer, physical_pos].item()) + + if logical_expert_id not in logical_expert_weights: + # First time encountering this logical expert, save its weights + logical_expert_weights[logical_expert_id] = { + weight_idx: all_weights[weight_idx][physical_pos] + for weight_idx in range(len(hidden_sizes)) + } + else: + # Verify that current physical expert's weights match the + # previously saved logical expert weights + for weight_idx in range(len(hidden_sizes)): + torch.testing.assert_close( + all_weights[weight_idx][physical_pos], + logical_expert_weights[logical_expert_id][weight_idx], + msg=f"Layer {layer}, weight {weight_idx}," + f"logical expert {logical_expert_id}: " + f"Physical expert {physical_pos} has different weights" + f"than expected") + + +@pytest.mark.parametrize( + "world_size,num_layers,num_local_experts,num_logical_experts", + [ + # 2 GPU, 2 experts per GPU + # 3 logical experts, 4 physical experts, 1 redundant experts + (2, 1, 2, 3), + # 2 GPU, 3 experts per GPU + # 4 logical experts, 6 physical experts, 2 redundant experts + (2, 2, 3, 4), + # 2 GPU, 8 experts per GPU + # 16 logical experts, 16 physical experts, 0 redundant experts + (2, 4, 8, 16), + # 4 GPU, 2 experts per GPU + # 6 logical experts, 8 physical experts, 2 redundant experts + (4, 1, 2, 6), + # 4 GPU, 2 experts per GPU + # 5 logical experts, 8 physical experts, 3 redundant experts + (4, 2, 2, 5), + # 4 GPU, 8 experts 
per GPU + # 16 logical experts, 32 physical experts, 16 redundant experts + (4, 8, 8, 16), + ]) +def test_rearrange_expert_weights_with_redundancy(world_size, num_layers, + num_local_experts, + num_logical_experts): + """Test the functionality of rearranging expert weights with redundancy.""" + + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need at least {world_size} GPUs to run the test") + + @worker_fn_wrapper + def worker_fn(): + # Initialize model parallel (using tensor parallel as an entrypoint + # to expert parallel) + ensure_model_parallel_initialized(tensor_model_parallel_size=world_size, + pipeline_model_parallel_size=1) + + ep_group = get_tp_group().cpu_group + ep_rank = torch.distributed.get_rank() + device = torch.device(f"cuda:{ep_rank}") + + # Test parameters + total_physical_experts = world_size * num_local_experts + hidden_sizes = [32, 64] # Two different weight matrices + + # Create old expert indices (with redundancy) + redundancy_config = create_redundancy_config(num_logical_experts, + total_physical_experts) + + old_indices = create_expert_indices_with_redundancy( + num_layers, + num_logical_experts, + total_physical_experts, + redundancy_config, + ) + + # Create new expert indices (with redundancy) + new_redundancy_config = create_redundancy_config( + num_logical_experts, total_physical_experts) + new_indices = create_expert_indices_with_redundancy( + num_layers, + num_logical_experts, + total_physical_experts, + new_redundancy_config, + ) + + # Create expert weights + expert_weights = create_expert_weights(num_layers, num_local_experts, + hidden_sizes, ep_rank, device, + old_indices) + + # Execute weight rearrangement + rearrange_expert_weights_inplace( + old_indices, + new_indices, + expert_weights, + ep_group, + is_profile=False, + ) + + # Verify the rearrangement result + verify_expert_weights_after_shuffle( + expert_weights, + new_indices, + hidden_sizes, + ep_rank, + num_local_experts, + ) + + verify_redundant_experts_have_same_weights( + expert_weights, + new_indices, + hidden_sizes, + world_size, + num_local_experts, + ) + + distributed_run(worker_fn, world_size) + + +@pytest.mark.parametrize("world_size", [2, 4]) +def test_rearrange_expert_weights_no_change(world_size): + """ + Test that when the indices do not change, the weights should remain + unchanged. 
+ """ + + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need at least {world_size} GPUs to run the test") + + @worker_fn_wrapper + def worker_fn(): + ensure_model_parallel_initialized(tensor_model_parallel_size=world_size, + pipeline_model_parallel_size=1) + + ep_group = get_tp_group().cpu_group + ep_rank = torch.distributed.get_rank() + device = torch.device(f"cuda:{ep_rank}") + + num_layers = 2 + num_local_experts = 2 + total_physical_experts = world_size * num_local_experts + num_logical_experts = total_physical_experts // 2 # Some redundancy + hidden_sizes = [32, 64] + + # Create redundancy configuration + redundancy_config = [2] * num_logical_experts + + # Same indices - no change + indices = create_expert_indices_with_redundancy(num_layers, + num_logical_experts, + total_physical_experts, + redundancy_config) + + expert_weights = create_expert_weights(num_layers, num_local_experts, + hidden_sizes, ep_rank, device, + indices) + + # Save original weights + original_weights = [] + for layer_weights in expert_weights: + layer_copy = [] + for weight in layer_weights: + layer_copy.append(weight.clone()) + original_weights.append(layer_copy) + + # Execute rearrangement (should be no change) + rearrange_expert_weights_inplace( + indices, + indices, # Same indices + expert_weights, + ep_group, + is_profile=False) + + # Verify that the weights have not changed + for layer in range(num_layers): + for weight_idx in range(len(hidden_sizes)): + torch.testing.assert_close( + expert_weights[layer][weight_idx], + original_weights[layer][weight_idx], + msg=f"Layer {layer}, weight {weight_idx} should remain " + f"unchanged") + + distributed_run(worker_fn, world_size) + + +@pytest.mark.parametrize("world_size", [2, 4]) +def test_rearrange_expert_weights_profile_mode(world_size): + """Test profile mode (should not copy actual weights)""" + + if torch.cuda.device_count() < world_size: + pytest.skip(f"Need at least {world_size} GPUs to run the test") + + @worker_fn_wrapper + def worker_fn(): + ensure_model_parallel_initialized(tensor_model_parallel_size=world_size, + pipeline_model_parallel_size=1) + + ep_group = get_tp_group().cpu_group + ep_rank = torch.distributed.get_rank() + device = torch.device(f"cuda:{ep_rank}") + + num_layers = 1 + num_local_experts = 2 + total_physical_experts = world_size * num_local_experts + num_logical_experts = total_physical_experts // 2 + hidden_sizes = [32] + + # Create different index distributions + old_redundancy = create_redundancy_config(num_logical_experts, + total_physical_experts) + new_redundancy = create_redundancy_config(num_logical_experts, + total_physical_experts) + + old_indices = create_expert_indices_with_redundancy( + num_layers, num_logical_experts, total_physical_experts, + old_redundancy) + new_indices = create_expert_indices_with_redundancy( + num_layers, num_logical_experts, total_physical_experts, + new_redundancy) + + expert_weights = create_expert_weights(num_layers, num_local_experts, + hidden_sizes, ep_rank, device, + old_indices) + + # Save original weights + original_weights = [] + for layer_weights in expert_weights: + layer_copy = [] + for weight in layer_weights: + layer_copy.append(weight.clone()) + original_weights.append(layer_copy) + + # Execute profile mode rearrangement + rearrange_expert_weights_inplace( + old_indices, + new_indices, + expert_weights, + ep_group, + is_profile=True # Profile mode + ) + + # In profile mode, the weights should remain unchanged + for layer in range(num_layers): + for weight_idx in 
range(len(hidden_sizes)): + torch.testing.assert_close( + expert_weights[layer][weight_idx], + original_weights[layer][weight_idx], + msg="In profile mode, the weights should remain unchanged") + + distributed_run(worker_fn, world_size) From 4eeb0ff8f1476ed0a04a704609f51dfb1fde48a9 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Wed, 25 Jun 2025 14:56:42 -0700 Subject: [PATCH 52/57] [Style] Split some long lines Signed-off-by: Bowen Wang --- tests/distributed/test_eplb_execute.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/distributed/test_eplb_execute.py b/tests/distributed/test_eplb_execute.py index 1a14688c9807..de9ed1eabbac 100644 --- a/tests/distributed/test_eplb_execute.py +++ b/tests/distributed/test_eplb_execute.py @@ -303,8 +303,9 @@ def test_rearrange_expert_weights_with_redundancy(world_size, num_layers, def worker_fn(): # Initialize model parallel (using tensor parallel as an entrypoint # to expert parallel) - ensure_model_parallel_initialized(tensor_model_parallel_size=world_size, - pipeline_model_parallel_size=1) + ensure_model_parallel_initialized( + tensor_model_parallel_size=world_size, + pipeline_model_parallel_size=1) ep_group = get_tp_group().cpu_group ep_rank = torch.distributed.get_rank() @@ -381,8 +382,9 @@ def test_rearrange_expert_weights_no_change(world_size): @worker_fn_wrapper def worker_fn(): - ensure_model_parallel_initialized(tensor_model_parallel_size=world_size, - pipeline_model_parallel_size=1) + ensure_model_parallel_initialized( + tensor_model_parallel_size=world_size, + pipeline_model_parallel_size=1) ep_group = get_tp_group().cpu_group ep_rank = torch.distributed.get_rank() @@ -398,10 +400,9 @@ def worker_fn(): redundancy_config = [2] * num_logical_experts # Same indices - no change - indices = create_expert_indices_with_redundancy(num_layers, - num_logical_experts, - total_physical_experts, - redundancy_config) + indices = create_expert_indices_with_redundancy( + num_layers, num_logical_experts, total_physical_experts, + redundancy_config) expert_weights = create_expert_weights(num_layers, num_local_experts, hidden_sizes, ep_rank, device, @@ -444,8 +445,9 @@ def test_rearrange_expert_weights_profile_mode(world_size): @worker_fn_wrapper def worker_fn(): - ensure_model_parallel_initialized(tensor_model_parallel_size=world_size, - pipeline_model_parallel_size=1) + ensure_model_parallel_initialized( + tensor_model_parallel_size=world_size, + pipeline_model_parallel_size=1) ep_group = get_tp_group().cpu_group ep_rank = torch.distributed.get_rank() From 5b1e3543ffd0d7c58b7db6b874fe0c8583dde2a2 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Wed, 25 Jun 2025 15:03:54 -0700 Subject: [PATCH 53/57] [Feature] Use `get_node_count` and remove magic number Signed-off-by: Bowen Wang --- vllm/distributed/eplb/eplb_state.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 3af6ed9b37da..2185df865c1f 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -33,7 +33,7 @@ from torch.distributed import all_gather, all_reduce from vllm.config import ParallelConfig -from vllm.distributed.parallel_state import get_ep_group +from vllm.distributed.parallel_state import get_ep_group, get_node_count from vllm.logger import init_logger from vllm.model_executor.models.interfaces import MixtureOfExperts @@ -384,10 +384,15 @@ def rearrange(self, # TODO(bowen): Treat differently for 
prefill and decode nodes num_replicas = model.num_physical_experts num_groups = model.num_expert_groups - # TODO(bowen): Remove magic numbers - num_nodes = (ep_group.size() + 7) // 8 + num_nodes = get_node_count() num_gpus = ep_group.size() + if num_gpus % num_nodes != 0: + logger.warning_once( + f"num_gpus % num_nodes != 0, " + "not using hierarchical rearrangement algorithm.\n" + f"{num_gpus=}, {num_nodes=}") + # Get new expert mappings ( new_physical_to_logical_map, @@ -420,7 +425,7 @@ def rearrange(self, torch.cuda.synchronize() time_end = time.perf_counter() logger.info( - "Rearranged experts %s in %.2f seconds.", - "(profile)" if is_profile else "", + "Rearranged experts%sin %.2f seconds.", + " (profile) " if is_profile else " ", time_end - time_start, ) From 495f782bc734360580d0ea9f2e748029e073888c Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Wed, 25 Jun 2025 18:26:21 -0700 Subject: [PATCH 54/57] [Test] Disable `first_k_dense_replace` in `test_initialization` Otherwise, there will not be any experts for DeepSeek models, thus we're unable to get MoE expert information from example MoE layer Signed-off-by: Bowen Wang --- tests/models/test_initialization.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index 54e8cd597bfc..c3d1dcde1fa3 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -37,6 +37,8 @@ def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig: "num_experts": 2, "num_experts_per_tok": 2, "num_local_experts": 2, + # Otherwise there will not be any expert layers + "first_k_dense_replace": 0, }) if hasattr(hf_config, "vision_config"): From 66fe93f3be7dc880f9193f2100d7c12e2d380106 Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Thu, 26 Jun 2025 01:55:10 -0700 Subject: [PATCH 55/57] [Test] Use only 2 experts in `test_initialization` For DeepSeek-V3, on the CI env, if we use full experts this will cause an OOM Signed-off-by: Bowen Wang --- tests/models/test_initialization.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index c3d1dcde1fa3..c3d2368625b9 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -39,6 +39,8 @@ def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig: "num_local_experts": 2, # Otherwise there will not be any expert layers "first_k_dense_replace": 0, + # To avoid OOM on DeepSeek-V3 + "n_routed_experts": 2, }) if hasattr(hf_config, "vision_config"): From 3ec903233a5c3b8021a8aa10fbc0c5507bfed0ab Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Thu, 26 Jun 2025 07:15:56 -0700 Subject: [PATCH 56/57] [Test] Get at least `n_group` experts in `test_initialization` Otherwise, topk operation during compilation will fail Signed-off-by: Bowen Wang --- tests/models/test_initialization.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index c3d2368625b9..6928a0b0f625 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -31,16 +31,19 @@ def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig: text_config = hf_config.get_text_config() + # There should be at least one expert per group + num_experts = getattr(text_config, 'n_group', 2) + text_config.update({ "num_layers": 1, "num_hidden_layers": 1, - "num_experts": 2, + "num_experts": num_experts, 
"num_experts_per_tok": 2, - "num_local_experts": 2, + "num_local_experts": num_experts, # Otherwise there will not be any expert layers "first_k_dense_replace": 0, # To avoid OOM on DeepSeek-V3 - "n_routed_experts": 2, + "n_routed_experts": num_experts, }) if hasattr(hf_config, "vision_config"): From c479d2c1456def25ac04f26f4d480c63049a536c Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Thu, 26 Jun 2025 11:30:07 -0700 Subject: [PATCH 57/57] [Test] Allow 2 experts per group in `test_initialization` Since `grouped_topk` will assume top-2 for DeepSeek-V3 Signed-off-by: Bowen Wang --- tests/models/test_initialization.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index 6928a0b0f625..e56bc925c9c4 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -31,8 +31,9 @@ def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig: text_config = hf_config.get_text_config() - # There should be at least one expert per group - num_experts = getattr(text_config, 'n_group', 2) + # Ensure at least 2 expert per group + # Since `grouped_topk` assums top-2 + num_experts = getattr(text_config, 'n_group', 1) * 2 text_config.update({ "num_layers": 1,