diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 280ae60c91ff..9d2c0b0a872b 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -33,7 +33,8 @@ PrefixCachingHashAlgo) from vllm.config.compilation import (CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig) -from vllm.config.parallel import DistributedExecutorBackend, ParallelConfig +from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig, + ParallelConfig) from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy from vllm.config.utils import ConfigType, config from vllm.logger import init_logger diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index bac1e63800d7..d474426a4818 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union import torch -from pydantic import model_validator +from pydantic import TypeAdapter, model_validator from pydantic.dataclasses import dataclass from torch.distributed import ProcessGroup, ReduceOp from typing_extensions import Self @@ -32,6 +32,38 @@ DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"] +@config +@dataclass +class EPLBConfig: + """Configuration for Expert Parallel Load Balancing (EP).""" + + window_size: int = 1000 + """Window size for expert load recording.""" + step_interval: int = 3000 + """ + Interval for rearranging experts in expert parallelism. + + Note that if this is greater than the EPLB window size, only the metrics + of the last `lb_window_size` steps will be used for rearranging experts. + """ + + num_redundant_experts: int = 0 + """Number of redundant experts to use for expert parallelism.""" + + log_balancedness: bool = False + """ + Log the balancedness each step of expert parallelism. + This is turned off by default since it will cause communication overhead. + """ + + @classmethod + def from_cli(cls, cli_value: str) -> "EPLBConfig": + """Parse the CLI value for the compilation config. + -O1, -O2, -O3, etc. is handled in FlexibleArgumentParser. + """ + return TypeAdapter(EPLBConfig).validate_json(cli_value) + + @config @dataclass class ParallelConfig: @@ -75,22 +107,24 @@ class ParallelConfig: """Use expert parallelism instead of tensor parallelism for MoE layers.""" enable_eplb: bool = False """Enable expert parallelism load balancing for MoE layers.""" - num_redundant_experts: int = 0 - """Number of redundant experts to use for expert parallelism.""" - eplb_window_size: int = 1000 - """Window size for expert load recording.""" - eplb_step_interval: int = 3000 - """ - Interval for rearranging experts in expert parallelism. - - Note that if this is greater than the EPLB window size, only the metrics - of the last `eplb_window_size` steps will be used for rearranging experts. - """ - eplb_log_balancedness: bool = False - """ - Log the balancedness each step of expert parallelism. - This is turned off by default since it will cause communication overhead. - """ + eplb_config: EPLBConfig = field(default_factory=EPLBConfig) + """Expert parallelism configuration.""" + num_redundant_experts: Optional[int] = None + """`num_redundant_experts` is deprecated and has been replaced with + `eplb_config.num_redundant_experts`. This will be removed in v0.12.0. + Please use `eplb_config.num_redundant_experts` instead.""" + eplb_window_size: Optional[int] = None + """`eplb_window_size` is deprecated and has been replaced with + `eplb_config.window_size`. This will be removed in v0.12.0. + Please use `eplb_config.window_size` instead.""" + eplb_step_interval: Optional[int] = None + """`eplb_step_interval` is deprecated and has been replaced with + `eplb_config.step_interval`. This will be removed in v0.12.0. + Please use `eplb_config.step_interval` instead.""" + eplb_log_balancedness: Optional[bool] = None + """`eplb_log_balancedness` is deprecated and has been replaced with + `eplb_config.log_balancedness`. This will be removed in v0.12.0. + Please use `eplb_config.log_balancedness` instead.""" max_parallel_loading_workers: Optional[int] = None """Maximum number of parallel loading workers when loading model @@ -241,6 +275,38 @@ def compute_hash(self): return hashlib.sha256(str(factors).encode()).hexdigest() def __post_init__(self) -> None: + # Forward deprecated fields to their new location + if self.num_redundant_experts is not None: + self.eplb_config.num_redundant_experts = ( + self.num_redundant_experts) + logger.warning_once( + "num_redundant_experts is deprecated and has been replaced " + "with eplb_config.num_redundant_experts. This will be removed " + "in v0.12.0. Changing this field after initialization will " + "have no effect.") + if self.eplb_window_size is not None: + self.eplb_config.window_size = self.eplb_window_size + logger.warning_once( + "eplb_window_size is deprecated and has been replaced " + "with eplb_config.window_size. This will be removed " + "in v0.12.0. Changing this field after initialization will " + "have no effect.") + if self.eplb_step_interval is not None: + self.eplb_config.step_interval = self.eplb_step_interval + logger.warning_once( + "eplb_step_interval is deprecated and has been replaced " + "with eplb_config.step_interval. This will be removed " + "in v0.12.0. Changing this field after initialization will " + "have no effect.") + if self.eplb_log_balancedness is not None: + self.eplb_config.log_balancedness = self.eplb_log_balancedness + logger.warning_once( + "eplb_log_balancedness is deprecated and has been replaced " + "with eplb_config.log_balancedness. This will be removed " + "in v0.12.0. Changing this field after initialization will " + "have no effect.") + + # Continue with the rest of the initialization self.world_size = self.pipeline_parallel_size * \ self.tensor_parallel_size @@ -279,10 +345,10 @@ def __post_init__(self) -> None: raise ValueError( "Expert parallelism load balancing is only supported on " "CUDA devices now.") - if self.num_redundant_experts < 0: + if self.eplb_config.num_redundant_experts < 0: raise ValueError( "num_redundant_experts must be non-negative, but got " - f"{self.num_redundant_experts}.") + f"{self.eplb_config.num_redundant_experts}.") if not self.enable_expert_parallel: raise ValueError( "enable_expert_parallel must be True to use EPLB.") @@ -293,10 +359,10 @@ def __post_init__(self) -> None: f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}." ) else: - if self.num_redundant_experts != 0: + if self.eplb_config.num_redundant_experts != 0: raise ValueError( "num_redundant_experts should be used with EPLB." - f"{self.num_redundant_experts}.") + f"{self.eplb_config.num_redundant_experts}.") if self.distributed_executor_backend is None and self.world_size > 1: # We use multiprocessing by default if world_size fits on the # current node and we aren't in a ray placement group. diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 979f2a06cec9..042acf40d67c 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -244,7 +244,7 @@ def build( dtype=torch.int32, device=device, ) - expert_load_window_size = parallel_config.eplb_window_size + expert_load_window_size = parallel_config.eplb_config.window_size expert_load_window = torch.zeros( (expert_load_window_size, model.num_moe_layers, model.num_physical_experts), @@ -253,7 +253,7 @@ def build( ) # Set the initial progress of rearrangement to 3/4 - eplb_step_interval = parallel_config.eplb_step_interval + eplb_step_interval = parallel_config.eplb_config.step_interval expert_rearrangement_step = max( 0, eplb_step_interval - eplb_step_interval // 4) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f8af6d36e0c0..b02bcbc6b54d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -24,7 +24,7 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, ConfigFormat, ConfigType, ConvertOption, DecodingConfig, DetailedTraceModules, Device, - DeviceConfig, DistributedExecutorBackend, + DeviceConfig, DistributedExecutorBackend, EPLBConfig, GuidedDecodingBackend, HfOverrides, KVEventsConfig, KVTransferConfig, LoadConfig, LogprobsMode, LoRAConfig, MambaDType, ModelConfig, ModelDType, @@ -303,11 +303,12 @@ class EngineArgs: data_parallel_hybrid_lb: bool = False data_parallel_backend: str = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel + eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config") enable_eplb: bool = ParallelConfig.enable_eplb - num_redundant_experts: int = ParallelConfig.num_redundant_experts - eplb_window_size: int = ParallelConfig.eplb_window_size - eplb_step_interval: int = ParallelConfig.eplb_step_interval - eplb_log_balancedness: bool = ParallelConfig.eplb_log_balancedness + num_redundant_experts: int = EPLBConfig.num_redundant_experts + eplb_window_size: int = EPLBConfig.window_size + eplb_step_interval: int = EPLBConfig.step_interval + eplb_log_balancedness: bool = EPLBConfig.log_balancedness max_parallel_loading_workers: Optional[ int] = ParallelConfig.max_parallel_loading_workers block_size: Optional[BlockSize] = CacheConfig.block_size @@ -449,6 +450,9 @@ def __post_init__(self): if isinstance(self.compilation_config, dict): self.compilation_config = CompilationConfig( **self.compilation_config) + if isinstance(self.eplb_config, dict): + self.eplb_config = EPLBConfig.from_cli(json.dumps( + self.eplb_config)) # Setup plugins from vllm.plugins import load_general_plugins load_general_plugins() @@ -647,14 +651,32 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **parallel_kwargs["enable_expert_parallel"]) parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"]) - parallel_group.add_argument("--num-redundant-experts", - **parallel_kwargs["num_redundant_experts"]) - parallel_group.add_argument("--eplb-window-size", - **parallel_kwargs["eplb_window_size"]) - parallel_group.add_argument("--eplb-step-interval", - **parallel_kwargs["eplb_step_interval"]) - parallel_group.add_argument("--eplb-log-balancedness", - **parallel_kwargs["eplb_log_balancedness"]) + parallel_group.add_argument("--eplb-config", + **parallel_kwargs["eplb_config"]) + parallel_group.add_argument( + "--num-redundant-experts", + type=int, + help= + "[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.", + deprecated=True) + parallel_group.add_argument( + "--eplb-window-size", + type=int, + help="[DEPRECATED] --eplb-window-size will be removed in v0.12.0.", + deprecated=True) + parallel_group.add_argument( + "--eplb-step-interval", + type=int, + help= + "[DEPRECATED] --eplb-step-interval will be removed in v0.12.0.", + deprecated=True) + parallel_group.add_argument( + "--eplb-log-balancedness", + action=argparse.BooleanOptionalAction, + help= + "[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.", + deprecated=True) + parallel_group.add_argument( "--max-parallel-loading-workers", **parallel_kwargs["max_parallel_loading_workers"]) @@ -1216,6 +1238,16 @@ def create_engine_config( "Currently, speculative decoding is not supported with " "async scheduling.") + # Forward the deprecated CLI args to the EPLB config. + if self.num_redundant_experts is not None: + self.eplb_config.num_redundant_experts = self.num_redundant_experts + if self.eplb_window_size is not None: + self.eplb_config.window_size = self.eplb_window_size + if self.eplb_step_interval is not None: + self.eplb_config.step_interval = self.eplb_step_interval + if self.eplb_log_balancedness is not None: + self.eplb_config.log_balancedness = self.eplb_log_balancedness + parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, @@ -1229,10 +1261,7 @@ def create_engine_config( data_parallel_hybrid_lb=self.data_parallel_hybrid_lb, enable_expert_parallel=self.enable_expert_parallel, enable_eplb=self.enable_eplb, - num_redundant_experts=self.num_redundant_experts, - eplb_window_size=self.eplb_window_size, - eplb_step_interval=self.eplb_step_interval, - eplb_log_balancedness=self.eplb_log_balancedness, + eplb_config=self.eplb_config, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, ray_workers_use_nsight=self.ray_workers_use_nsight, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index f199da135ec7..d56224b4b7b3 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -132,10 +132,10 @@ def __init__( # Load balancing settings. vllm_config = get_current_vllm_config() - parallel_config = vllm_config.parallel_config + eplb_config = vllm_config.parallel_config.eplb_config self.enable_eplb = enable_eplb - self.n_redundant_experts = parallel_config.num_redundant_experts + self.n_redundant_experts = eplb_config.num_redundant_experts self.n_logical_experts = self.n_routed_experts self.n_physical_experts = (self.n_logical_experts + self.n_redundant_experts) diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index aff491f9596c..fe5e46a99826 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -131,10 +131,10 @@ def __init__( # Load balancing settings. vllm_config = get_current_vllm_config() - parallel_config = vllm_config.parallel_config + eplb_config = vllm_config.parallel_config.eplb_config self.enable_eplb = enable_eplb - self.n_redundant_experts = parallel_config.num_redundant_experts + self.n_redundant_experts = eplb_config.num_redundant_experts self.n_logical_experts = self.n_routed_experts self.n_physical_experts = (self.n_logical_experts + self.n_redundant_experts) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 61b16b6a1d2d..d38ee3405ace 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -121,11 +121,11 @@ def __init__( # Load balancing settings. vllm_config = get_current_vllm_config() - parallel_config = vllm_config.parallel_config + eplb_config = vllm_config.parallel_config.eplb_config self.enable_eplb = enable_eplb self.n_logical_experts = self.n_routed_experts - self.n_redundant_experts = parallel_config.num_redundant_experts + self.n_redundant_experts = eplb_config.num_redundant_experts self.n_physical_experts = (self.n_logical_experts + self.n_redundant_experts) self.n_local_physical_experts = self.n_physical_experts // self.ep_size @@ -367,7 +367,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config parallel_config = vllm_config.parallel_config enable_eplb = parallel_config.enable_eplb - self.num_redundant_experts = parallel_config.num_redundant_experts + eplb_config = parallel_config.eplb_config + self.num_redundant_experts = eplb_config.num_redundant_experts self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9460d91c5832..5759795aedb3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1422,7 +1422,7 @@ def eplb_step(self, model, is_dummy, is_profile, - log_stats=self.parallel_config.eplb_log_balancedness, + log_stats=self.parallel_config.eplb_config.log_balancedness, ) def get_dp_padding(self, @@ -1956,7 +1956,7 @@ def load_model(self, eep_scale_up: bool = False) -> None: global_expert_load, old_global_expert_indices = ( EplbState.recv_state()) num_logical_experts = global_expert_load.shape[1] - self.parallel_config.num_redundant_experts = ( + self.parallel_config.eplb_config.num_redundant_experts = ( num_local_physical_experts * new_ep_size - num_logical_experts) assert old_global_expert_indices.shape[ 1] % num_local_physical_experts == 0 diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 04de8d36680a..a546dfd78ff2 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -511,7 +511,7 @@ def _reconfigure_moe(self, old_ep_size: int, assert self.model_runner.eplb_state is not None new_physical_experts = \ self.model_runner.eplb_state.physical_to_logical_map.shape[1] - parallel_config.num_redundant_experts = ( + parallel_config.eplb_config.num_redundant_experts = ( new_physical_experts - self.model_runner.eplb_state.logical_replica_count.shape[1]) global_expert_load = None @@ -527,7 +527,7 @@ def _reconfigure_moe(self, old_ep_size: int, assert self.model_runner.eplb_state is not None global_expert_load = self.model_runner.eplb_state.rearrange( self.model_runner.model, execute_shuffle=False) - parallel_config.num_redundant_experts = ( + parallel_config.eplb_config.num_redundant_experts = ( new_physical_experts - global_expert_load.shape[1]) prepare_communication_buffer_for_model(self.model_runner.model) self.model_runner.model.update_physical_experts_metadata(