3 changes: 2 additions & 1 deletion vllm/config/__init__.py
@@ -33,7 +33,8 @@
PrefixCachingHashAlgo)
from vllm.config.compilation import (CompilationConfig, CompilationLevel,
CUDAGraphMode, PassConfig)
from vllm.config.parallel import DistributedExecutorBackend, ParallelConfig
from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
ParallelConfig)
from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
from vllm.config.utils import ConfigType, config
from vllm.logger import init_logger
108 changes: 87 additions & 21 deletions vllm/config/parallel.py
@@ -6,7 +6,7 @@
from typing import TYPE_CHECKING, Any, Literal, Optional, Union

import torch
from pydantic import model_validator
from pydantic import TypeAdapter, model_validator
from pydantic.dataclasses import dataclass
from torch.distributed import ProcessGroup, ReduceOp
from typing_extensions import Self
@@ -32,6 +32,38 @@
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]


@config
@dataclass
class EPLBConfig:
"""Configuration for Expert Parallel Load Balancing (EP)."""

window_size: int = 1000
"""Window size for expert load recording."""
step_interval: int = 3000
"""
Interval for rearranging experts in expert parallelism.

Note that if this is greater than the EPLB window size, only the metrics
of the last `window_size` steps will be used for rearranging experts.
"""

num_redundant_experts: int = 0
"""Number of redundant experts to use for expert parallelism."""

log_balancedness: bool = False
"""
Log the balancedness of expert parallelism at each step.
This is turned off by default since it incurs communication overhead.
"""

@classmethod
def from_cli(cls, cli_value: str) -> "EPLBConfig":
"""Parse the CLI value for the compilation config.
-O1, -O2, -O3, etc. is handled in FlexibleArgumentParser.
"""
return TypeAdapter(EPLBConfig).validate_json(cli_value)
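
As a quick illustration of the parsing path above (a minimal sketch, not part of this diff; it assumes the `EPLBConfig` export added to `vllm/config/__init__.py` in the first hunk), `from_cli` simply delegates JSON validation to pydantic's `TypeAdapter`:

```python
# Minimal sketch: parse the JSON string that --eplb-config would pass in.
from vllm.config import EPLBConfig

cfg = EPLBConfig.from_cli('{"window_size": 2000, "step_interval": 6000}')
assert cfg.window_size == 2000
assert cfg.num_redundant_experts == 0  # unspecified fields keep their defaults
```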


@config
@dataclass
class ParallelConfig:
@@ -75,22 +107,24 @@ class ParallelConfig:
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
enable_eplb: bool = False
"""Enable expert parallelism load balancing for MoE layers."""
num_redundant_experts: int = 0
"""Number of redundant experts to use for expert parallelism."""
eplb_window_size: int = 1000
"""Window size for expert load recording."""
eplb_step_interval: int = 3000
"""
Interval for rearranging experts in expert parallelism.

Note that if this is greater than the EPLB window size, only the metrics
of the last `eplb_window_size` steps will be used for rearranging experts.
"""
eplb_log_balancedness: bool = False
"""
Log the balancedness each step of expert parallelism.
This is turned off by default since it will cause communication overhead.
"""
eplb_config: EPLBConfig = field(default_factory=EPLBConfig)
"""Expert parallelism configuration."""
num_redundant_experts: Optional[int] = None
"""`num_redundant_experts` is deprecated and has been replaced with
`eplb_config.num_redundant_experts`. This will be removed in v0.12.0.
Please use `eplb_config.num_redundant_experts` instead."""
eplb_window_size: Optional[int] = None
"""`eplb_window_size` is deprecated and has been replaced with
`eplb_config.window_size`. This will be removed in v0.12.0.
Please use `eplb_config.window_size` instead."""
eplb_step_interval: Optional[int] = None
"""`eplb_step_interval` is deprecated and has been replaced with
`eplb_config.step_interval`. This will be removed in v0.12.0.
Please use `eplb_config.step_interval` instead."""
eplb_log_balancedness: Optional[bool] = None
"""`eplb_log_balancedness` is deprecated and has been replaced with
`eplb_config.log_balancedness`. This will be removed in v0.12.0.
Please use `eplb_config.log_balancedness` instead."""

max_parallel_loading_workers: Optional[int] = None
"""Maximum number of parallel loading workers when loading model
@@ -241,6 +275,38 @@ def compute_hash(self):
return hashlib.sha256(str(factors).encode()).hexdigest()

def __post_init__(self) -> None:
# Forward deprecated fields to their new location
if self.num_redundant_experts is not None:
self.eplb_config.num_redundant_experts = (
self.num_redundant_experts)
logger.warning_once(
"num_redundant_experts is deprecated and has been replaced "
"with eplb_config.num_redundant_experts. This will be removed "
"in v0.12.0. Changing this field after initialization will "
"have no effect.")
if self.eplb_window_size is not None:
self.eplb_config.window_size = self.eplb_window_size
logger.warning_once(
"eplb_window_size is deprecated and has been replaced "
"with eplb_config.window_size. This will be removed "
"in v0.12.0. Changing this field after initialization will "
"have no effect.")
if self.eplb_step_interval is not None:
self.eplb_config.step_interval = self.eplb_step_interval
logger.warning_once(
"eplb_step_interval is deprecated and has been replaced "
"with eplb_config.step_interval. This will be removed "
"in v0.12.0. Changing this field after initialization will "
"have no effect.")
if self.eplb_log_balancedness is not None:
self.eplb_config.log_balancedness = self.eplb_log_balancedness
logger.warning_once(
"eplb_log_balancedness is deprecated and has been replaced "
"with eplb_config.log_balancedness. This will be removed "
"in v0.12.0. Changing this field after initialization will "
"have no effect.")

# Continue with the rest of the initialization
self.world_size = self.pipeline_parallel_size * \
self.tensor_parallel_size

@@ -279,10 +345,10 @@ def __post_init__(self) -> None:
raise ValueError(
"Expert parallelism load balancing is only supported on "
"CUDA devices now.")
if self.num_redundant_experts < 0:
if self.eplb_config.num_redundant_experts < 0:
raise ValueError(
"num_redundant_experts must be non-negative, but got "
f"{self.num_redundant_experts}.")
f"{self.eplb_config.num_redundant_experts}.")
if not self.enable_expert_parallel:
raise ValueError(
"enable_expert_parallel must be True to use EPLB.")
@@ -293,10 +359,10 @@ def __post_init__(self) -> None:
f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}."
)
else:
if self.num_redundant_experts != 0:
if self.eplb_config.num_redundant_experts != 0:
raise ValueError(
"num_redundant_experts should be used with EPLB."
f"{self.num_redundant_experts}.")
f"{self.eplb_config.num_redundant_experts}.")
if self.distributed_executor_backend is None and self.world_size > 1:
# We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group.
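To make the deprecation forwarding in `ParallelConfig.__post_init__` concrete, here is a hedged sketch of the intended behavior (illustrative values only; the exact set of constructor arguments `ParallelConfig` requires, and any platform checks it performs, may differ):

```python
# Sketch: a deprecated top-level field is copied into the nested EPLBConfig
# by __post_init__, which also emits a one-time deprecation warning.
from vllm.config import ParallelConfig

pc = ParallelConfig(eplb_window_size=2000)  # deprecated spelling
assert pc.eplb_config.window_size == 2000   # new location of the value
```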
4 changes: 2 additions & 2 deletions vllm/distributed/eplb/eplb_state.py
@@ -244,7 +244,7 @@ def build(
dtype=torch.int32,
device=device,
)
expert_load_window_size = parallel_config.eplb_window_size
expert_load_window_size = parallel_config.eplb_config.window_size
expert_load_window = torch.zeros(
(expert_load_window_size, model.num_moe_layers,
model.num_physical_experts),
@@ -253,7 +253,7 @@
)

# Set the initial progress of rearrangement to 3/4
eplb_step_interval = parallel_config.eplb_step_interval
eplb_step_interval = parallel_config.eplb_config.step_interval
expert_rearrangement_step = max(
0, eplb_step_interval - eplb_step_interval // 4)
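
With the default `step_interval` of 3000, this works out as in the small worked example below, so the first rearrangement is triggered after roughly a quarter of the interval rather than a full one:

```python
step_interval = 3000
expert_rearrangement_step = max(0, step_interval - step_interval // 4)  # 2250
steps_until_first_rearrangement = step_interval - expert_rearrangement_step  # 750
```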

63 changes: 46 additions & 17 deletions vllm/engine/arg_utils.py
@@ -24,7 +24,7 @@
from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
ConfigFormat, ConfigType, ConvertOption,
DecodingConfig, DetailedTraceModules, Device,
DeviceConfig, DistributedExecutorBackend,
DeviceConfig, DistributedExecutorBackend, EPLBConfig,
GuidedDecodingBackend, HfOverrides, KVEventsConfig,
KVTransferConfig, LoadConfig, LogprobsMode,
LoRAConfig, MambaDType, ModelConfig, ModelDType,
@@ -303,11 +303,12 @@ class EngineArgs:
data_parallel_hybrid_lb: bool = False
data_parallel_backend: str = ParallelConfig.data_parallel_backend
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config")
enable_eplb: bool = ParallelConfig.enable_eplb
num_redundant_experts: int = ParallelConfig.num_redundant_experts
eplb_window_size: int = ParallelConfig.eplb_window_size
eplb_step_interval: int = ParallelConfig.eplb_step_interval
eplb_log_balancedness: bool = ParallelConfig.eplb_log_balancedness
num_redundant_experts: int = EPLBConfig.num_redundant_experts
eplb_window_size: int = EPLBConfig.window_size
eplb_step_interval: int = EPLBConfig.step_interval
eplb_log_balancedness: bool = EPLBConfig.log_balancedness
max_parallel_loading_workers: Optional[
int] = ParallelConfig.max_parallel_loading_workers
block_size: Optional[BlockSize] = CacheConfig.block_size
@@ -449,6 +450,9 @@ def __post_init__(self):
if isinstance(self.compilation_config, dict):
self.compilation_config = CompilationConfig(
**self.compilation_config)
if isinstance(self.eplb_config, dict):
self.eplb_config = EPLBConfig.from_cli(json.dumps(
self.eplb_config))
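
For clarity, a small sketch of what this branch does when `eplb_config` arrives as a plain dict (for example after deserializing `EngineArgs`): the dict is dumped back to JSON and validated through `EPLBConfig.from_cli` (this assumes `json` is already imported in this module):

```python
import json
from vllm.config import EPLBConfig

raw = {"step_interval": 6000, "log_balancedness": True}
cfg = EPLBConfig.from_cli(json.dumps(raw))
assert cfg.step_interval == 6000 and cfg.log_balancedness is True
```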
# Setup plugins
from vllm.plugins import load_general_plugins
load_general_plugins()
@@ -647,14 +651,32 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
**parallel_kwargs["enable_expert_parallel"])
parallel_group.add_argument("--enable-eplb",
**parallel_kwargs["enable_eplb"])
parallel_group.add_argument("--num-redundant-experts",
**parallel_kwargs["num_redundant_experts"])
parallel_group.add_argument("--eplb-window-size",
**parallel_kwargs["eplb_window_size"])
parallel_group.add_argument("--eplb-step-interval",
**parallel_kwargs["eplb_step_interval"])
parallel_group.add_argument("--eplb-log-balancedness",
**parallel_kwargs["eplb_log_balancedness"])
parallel_group.add_argument("--eplb-config",
**parallel_kwargs["eplb_config"])
parallel_group.add_argument(
"--num-redundant-experts",
type=int,
help=
"[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.",
deprecated=True)
parallel_group.add_argument(
"--eplb-window-size",
type=int,
help="[DEPRECATED] --eplb-window-size will be removed in v0.12.0.",
deprecated=True)
parallel_group.add_argument(
"--eplb-step-interval",
type=int,
help=
"[DEPRECATED] --eplb-step-interval will be removed in v0.12.0.",
deprecated=True)
parallel_group.add_argument(
"--eplb-log-balancedness",
action=argparse.BooleanOptionalAction,
help=
"[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.",
deprecated=True)

parallel_group.add_argument(
"--max-parallel-loading-workers",
**parallel_kwargs["max_parallel_loading_workers"])
@@ -1216,6 +1238,16 @@ def create_engine_config(
"Currently, speculative decoding is not supported with "
"async scheduling.")

# Forward the deprecated CLI args to the EPLB config.
if self.num_redundant_experts is not None:
self.eplb_config.num_redundant_experts = self.num_redundant_experts
if self.eplb_window_size is not None:
self.eplb_config.window_size = self.eplb_window_size
if self.eplb_step_interval is not None:
self.eplb_config.step_interval = self.eplb_step_interval
if self.eplb_log_balancedness is not None:
self.eplb_config.log_balancedness = self.eplb_log_balancedness

parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size,
@@ -1229,10 +1261,7 @@ def create_engine_config(
data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
enable_expert_parallel=self.enable_expert_parallel,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.num_redundant_experts,
eplb_window_size=self.eplb_window_size,
eplb_step_interval=self.eplb_step_interval,
eplb_log_balancedness=self.eplb_log_balancedness,
eplb_config=self.eplb_config,
max_parallel_loading_workers=self.max_parallel_loading_workers,
disable_custom_all_reduce=self.disable_custom_all_reduce,
ray_workers_use_nsight=self.ray_workers_use_nsight,
4 changes: 2 additions & 2 deletions vllm/model_executor/models/deepseek_v2.py
@@ -132,10 +132,10 @@ def __init__(

# Load balancing settings.
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
eplb_config = vllm_config.parallel_config.eplb_config
self.enable_eplb = enable_eplb

self.n_redundant_experts = parallel_config.num_redundant_experts
self.n_redundant_experts = eplb_config.num_redundant_experts
self.n_logical_experts = self.n_routed_experts
self.n_physical_experts = (self.n_logical_experts +
self.n_redundant_experts)
4 changes: 2 additions & 2 deletions vllm/model_executor/models/glm4_moe.py
@@ -131,10 +131,10 @@ def __init__(

# Load balancing settings.
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
eplb_config = vllm_config.parallel_config.eplb_config
self.enable_eplb = enable_eplb

self.n_redundant_experts = parallel_config.num_redundant_experts
self.n_redundant_experts = eplb_config.num_redundant_experts
self.n_logical_experts = self.n_routed_experts
self.n_physical_experts = (self.n_logical_experts +
self.n_redundant_experts)
7 changes: 4 additions & 3 deletions vllm/model_executor/models/qwen3_moe.py
@@ -121,11 +121,11 @@ def __init__(

# Load balancing settings.
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
eplb_config = vllm_config.parallel_config.eplb_config
self.enable_eplb = enable_eplb

self.n_logical_experts = self.n_routed_experts
self.n_redundant_experts = parallel_config.num_redundant_experts
self.n_redundant_experts = eplb_config.num_redundant_experts
self.n_physical_experts = (self.n_logical_experts +
self.n_redundant_experts)
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
@@ -367,7 +367,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
enable_eplb = parallel_config.enable_eplb
self.num_redundant_experts = parallel_config.num_redundant_experts
eplb_config = parallel_config.eplb_config
self.num_redundant_experts = eplb_config.num_redundant_experts

self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
4 changes: 2 additions & 2 deletions vllm/v1/worker/gpu_model_runner.py
@@ -1422,7 +1422,7 @@ def eplb_step(self,
model,
is_dummy,
is_profile,
log_stats=self.parallel_config.eplb_log_balancedness,
log_stats=self.parallel_config.eplb_config.log_balancedness,
)

def get_dp_padding(self,
Expand Down Expand Up @@ -1956,7 +1956,7 @@ def load_model(self, eep_scale_up: bool = False) -> None:
global_expert_load, old_global_expert_indices = (
EplbState.recv_state())
num_logical_experts = global_expert_load.shape[1]
self.parallel_config.num_redundant_experts = (
self.parallel_config.eplb_config.num_redundant_experts = (
num_local_physical_experts * new_ep_size - num_logical_experts)
assert old_global_expert_indices.shape[
1] % num_local_physical_experts == 0
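
A made-up worked example of the arithmetic above: with 8 local physical experts per rank, a new EP size of 9, and 64 logical experts, the scaled-up deployment has 72 physical expert slots, of which 8 are redundant copies.

```python
num_local_physical_experts = 8
new_ep_size = 9
num_logical_experts = 64
num_redundant_experts = (num_local_physical_experts * new_ep_size
                         - num_logical_experts)  # 72 - 64 = 8
```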
4 changes: 2 additions & 2 deletions vllm/v1/worker/gpu_worker.py
@@ -511,7 +511,7 @@ def _reconfigure_moe(self, old_ep_size: int,
assert self.model_runner.eplb_state is not None
new_physical_experts = \
self.model_runner.eplb_state.physical_to_logical_map.shape[1]
parallel_config.num_redundant_experts = (
parallel_config.eplb_config.num_redundant_experts = (
new_physical_experts -
self.model_runner.eplb_state.logical_replica_count.shape[1])
global_expert_load = None
@@ -527,7 +527,7 @@ def _reconfigure_moe(self, old_ep_size: int,
assert self.model_runner.eplb_state is not None
global_expert_load = self.model_runner.eplb_state.rearrange(
self.model_runner.model, execute_shuffle=False)
parallel_config.num_redundant_experts = (
parallel_config.eplb_config.num_redundant_experts = (
new_physical_experts - global_expert_load.shape[1])
prepare_communication_buffer_for_model(self.model_runner.model)
self.model_runner.model.update_physical_experts_metadata(