diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 4a898df8f2a3..88e557f1dfb3 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -232,8 +232,8 @@ steps: commands: - pytest -v -s distributed/test_eplb_algo.py -- label: EPLB Execution Test # 5min - timeout_in_minutes: 15 +- label: EPLB Execution Test # 10min + timeout_in_minutes: 20 working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -241,6 +241,7 @@ steps: - tests/distributed/test_eplb_execute.py commands: - pytest -v -s distributed/test_eplb_execute.py + - pytest -v -s distributed/test_eplb_spec_decode.py - label: Metrics, Tracing Test # 12min timeout_in_minutes: 20 diff --git a/tests/distributed/test_eplb_spec_decode.py b/tests/distributed/test_eplb_spec_decode.py new file mode 100644 index 000000000000..11e23f128f33 --- /dev/null +++ b/tests/distributed/test_eplb_spec_decode.py @@ -0,0 +1,96 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + +import lm_eval +import pytest + +from tests.utils import large_gpu_mark + + +def get_model_args( + model_name: str, + spec_model_name: str, + spec_method: str, + tp_size: int, + model_max_len: int, +) -> dict: + speculative_config = { + "method": spec_method, + "model": spec_model_name, + "num_speculative_tokens": 1, + "max_model_len": model_max_len, + } + + model_args = { + "pretrained": model_name, + "dtype": "auto", + "add_bos_token": True, + "tensor_parallel_size": tp_size, + "gpu_memory_utilization": 0.7, + "speculative_config": speculative_config, + "enable_expert_parallel": True, + "num_redundant_experts": tp_size, + "eplb_window_size": 128, + "eplb_step_interval": 1024, + "eplb_log_balancedness": False, + "enable_eplb": True, + "max_model_len": model_max_len, + } + return model_args + + +@pytest.mark.parametrize( + "model_setup", + [ + pytest.param( + ("mtp", "Qwen/Qwen3-Next-80B-A3B-Instruct", None, 4, 0.86), + marks=large_gpu_mark(min_gb=80), + ), + pytest.param( + ( + "eagle", + "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", + 4, + 0.92, + ), + marks=pytest.mark.skip(reason="Skipping due to CI OOM issues"), + ), + ], + ids=["qwen3_next_mtp", "llama4_eagle"], +) +def test_eplb_spec_decode( + monkeypatch: pytest.MonkeyPatch, + model_setup: tuple[str, str, str, int, float], +): + """ + Test the correctness of EPLB speculative decoding with GSM8K dataset. + Applicable to MoE models with mtp or eagle spec decode. + """ + method, model_name, spec_model_name, tp_size, expected_gsm8k_value = model_setup + + TASK = "gsm8k" + FILTER = "exact_match,strict-match" + RTOL = 0.03 + + model_args = get_model_args( + model_name=model_name, + spec_model_name=spec_model_name, + spec_method=method, + tp_size=tp_size, + model_max_len=4096, + ) + + results = lm_eval.simple_evaluate( + model="vllm", + model_args=model_args, + tasks=TASK, + batch_size=64, + num_fewshot=8, + ) + measured_value = results["results"][TASK][FILTER] + assert ( + measured_value - RTOL < expected_gsm8k_value + and measured_value + RTOL > expected_gsm8k_value + ), f"Expected: {expected_gsm8k_value} | Measured: {measured_value}" diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 17716e8a07ac..526d3ceac7b8 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -33,7 +33,7 @@ import torch from torch.distributed import ProcessGroup, all_reduce -from vllm.config import ParallelConfig +from vllm.config import ModelConfig, ParallelConfig from vllm.distributed.parallel_state import ( get_ep_group, get_node_count, @@ -50,7 +50,7 @@ @dataclass -class EplbState: +class EplbModelState: """EPLB metrics.""" physical_to_logical_map: torch.Tensor @@ -130,34 +130,46 @@ class EplbState: See: https://github.com/vllm-project/vllm/pull/22167#pullrequestreview-3086143856 """ - expert_load_window_step: int = 0 - """ - Current step in the sliding window. + model_name: str + model: MixtureOfExperts - Different from `expert_rearrangement_step`, each EP rank may have its own - `expert_load_window_step`. - """ - expert_load_window_size: int = 0 + +class EplbState: """ - Size of the expert load sliding window. - This is a constant and is taken from the config. + EplbState of each expert parallel model. Key is the model config hash. """ - expert_rearrangement_step: int = 0 - """ - Steps after last rearrangement. - Will trigger a rearrangement if it exceeds the threshold. + def __init__(self, parallel_config: ParallelConfig, device: torch.device): + self.parallel_config = parallel_config + self.device = device + self.model_states: dict[str, EplbModelState] = {} + """ + Current step in the sliding window. - NOTE: Keep in mind that all EP ranks need to have the same - `expert_rearrangement_step` value to ensure synchronization. - Otherwise, the rearrangement will hang at collective - communication calls. - """ - expert_rearrangement_step_interval: int = 0 - """ - Interval for expert rearrangement steps. - This is a constant and is taken from the config. - """ + Different from `expert_rearrangement_step`, + each EP rank may have its own `expert_load_window_step`. + """ + self.expert_load_window_step: int = 0 + """ + Size of the expert load sliding window. + This is a constant and is taken from the config. + """ + self.expert_load_window_size: int = 0 + """ + Steps after last rearrangement. + Will trigger a rearrangement if it exceeds the threshold. + + NOTE: Keep in mind that all EP ranks need to have the same + `expert_rearrangement_step` value to ensure synchronization. + Otherwise, the rearrangement will hang at collective + communication calls. + """ + self.expert_rearrangement_step: int = 0 + """ + Interval for expert rearrangement steps. + This is a constant and is taken from the config. + """ + self.expert_rearrangement_step_interval: int = 0 @staticmethod def build_initial_global_physical_to_logical_map( @@ -179,26 +191,63 @@ def build_initial_global_physical_to_logical_map( ] return global_physical_to_logical_map - @classmethod - def build( - cls, + def validate_ep_configuration(self, new_model: MixtureOfExperts): + """ + Validate that the expert parallel configuration of + the new model is the same as the existing models. + """ + if len(self.model_states) > 0: + model = next(iter(self.model_states.values())).model + if ( + model.num_routed_experts != new_model.num_routed_experts + or model.num_redundant_experts != new_model.num_redundant_experts + or model.num_physical_experts != new_model.num_physical_experts + or model.num_logical_experts != new_model.num_logical_experts + or model.num_expert_groups != new_model.num_expert_groups + ): + raise RuntimeError( + "Model: {} " + "with config {} " + "{} {} {} {} " + "mismatch with new model {} " + "with config {} " + "{} {} {} {}".format( + type(model), + model.num_routed_experts, + model.num_redundant_experts, + model.num_physical_experts, + model.num_logical_experts, + model.num_expert_groups, + type(new_model), + new_model.num_routed_experts, + new_model.num_redundant_experts, + new_model.num_physical_experts, + new_model.num_logical_experts, + new_model.num_expert_groups, + ) + ) + + def add_model( + self, model: MixtureOfExperts, - device: torch.device, - parallel_config: ParallelConfig, + model_config: ModelConfig, global_expert_load: torch.Tensor | None = None, old_global_expert_indices: torch.Tensor | None = None, rank_mapping: dict[int, int] | None = None, - ) -> "EplbState": + ): """ Build the initial EPLB state. """ - physical_to_logical_map_list = cls.build_initial_global_physical_to_logical_map( - model.num_routed_experts, - model.num_redundant_experts, + self.validate_ep_configuration(model) + physical_to_logical_map_list = ( + EplbState.build_initial_global_physical_to_logical_map( + model.num_routed_experts, + model.num_redundant_experts, + ) ) physical_to_logical_map = torch.tensor( physical_to_logical_map_list, - device=device, + device=self.device, ) # Assuming 8 GPUs per node, this supports up to # (1023 + 1) / 8 = 128 nodes for now. @@ -212,11 +261,11 @@ def build( logical_to_physical_map = torch.full( (model.num_logical_experts, max_slots_per_logical_expert), -1, - device=device, + device=self.device, ) logical_replica_count = torch.zeros( (model.num_logical_experts,), - device=device, + device=self.device, dtype=torch.long, ) @@ -255,18 +304,25 @@ def build( expert_load_pass = torch.zeros( (model.num_moe_layers, model.num_physical_experts), dtype=torch.int32, - device=device, + device=self.device, ) - expert_load_window_size = parallel_config.eplb_config.window_size + self.expert_load_window_size = self.parallel_config.eplb_config.window_size expert_load_window = torch.zeros( - (expert_load_window_size, model.num_moe_layers, model.num_physical_experts), + ( + self.expert_load_window_size, + model.num_moe_layers, + model.num_physical_experts, + ), dtype=torch.int32, - device=device, + device=self.device, ) # Set the initial progress of rearrangement to 3/4 - eplb_step_interval = parallel_config.eplb_config.step_interval - expert_rearrangement_step = max(0, eplb_step_interval - eplb_step_interval // 4) + eplb_step_interval = self.parallel_config.eplb_config.step_interval + self.expert_rearrangement_step = max( + 0, eplb_step_interval - eplb_step_interval // 4 + ) + self.expert_rearrangement_step_interval = eplb_step_interval if global_expert_load is not None: ep_group = get_ep_group().device_group @@ -309,7 +365,7 @@ def build( (0, logical_to_physical_map.shape[-1] - max_physical_slots), value=-1, ) - physical_to_logical_map = new_physical_to_logical_map.to(device) + physical_to_logical_map = new_physical_to_logical_map.to(self.device) logical_to_physical_map.copy_(new_logical_to_physical_map) logical_replica_count.copy_(new_logical_replica_count) @@ -327,22 +383,20 @@ def build( False, rank_mapping, ) - expert_rearrangement_step = 0 + self.expert_rearrangement_step = 0 - return cls( + self.model_states[model_config.compute_hash()] = EplbModelState( physical_to_logical_map, logical_to_physical_map, logical_replica_count, expert_load_pass, expert_load_window, - expert_load_window_size=expert_load_window_size, - expert_rearrangement_step=expert_rearrangement_step, - expert_rearrangement_step_interval=eplb_step_interval, + model_config.model, + model, ) def step( self, - model: MixtureOfExperts, is_dummy: bool = False, is_profile: bool = False, log_stats: bool = False, @@ -351,7 +405,6 @@ def step( Step the EPLB state. Args: - model (MixtureOfExperts): The MoE model. is_dummy (bool): If `True`, this is a dummy step and the load metrics recorded in this forward pass will not count. Defaults to `False`. @@ -369,60 +422,66 @@ def step( """ if is_profile: - self.rearrange(model, is_profile=True) + self.rearrange(is_profile=True) return if is_dummy: # Do not record load metrics for dummy steps - self.expert_load_pass.zero_() + for eplb_model_state in self.model_states.values(): + eplb_model_state.expert_load_pass.zero_() if log_stats: - # total_expert_load_pass: (num_moe_layers, num_physical_experts) - total_expert_load_pass = self.expert_load_pass.clone() - - # Collect load metrics from all ranks + # Sync the expert load pass for each model (main and drafter). + # expert_load_pass: (num_moe_layers, num_physical_experts) + expert_load_pass_list = self._sync_load_pass() ep_group = get_ep_group().device_group - all_reduce(total_expert_load_pass, group=ep_group) - - # num_tokens_per_rank: (num_moe_layers, num_ranks) - num_tokens_per_rank = ( - total_expert_load_pass.reshape( - total_expert_load_pass.shape[0], ep_group.size(), -1 + for expert_load_pass, eplb_model_state in zip( + expert_load_pass_list, self.model_states.values() + ): + # num_tokens_per_rank: (num_moe_layers, num_ranks) + num_tokens_per_rank = ( + expert_load_pass.reshape( + expert_load_pass.shape[0], ep_group.size(), -1 + ) + .sum(dim=-1) + .float() ) - .sum(dim=-1) - .float() - ) - # Compute balancedness ratio: - # for each layer: - # (mean load across ranks) / (max load across ranks) - avg_tokens_tensor = num_tokens_per_rank.mean(dim=0).sum(dim=0) - max_tokens_tensor = num_tokens_per_rank.max(dim=0).values.sum(dim=0) - - # Just to make type checker happy - tokens_tensors: list[float] = torch.stack( - [avg_tokens_tensor, max_tokens_tensor] - ).tolist() - avg_tokens, max_tokens = tokens_tensors - balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0 - - if ep_group.rank() == 0: - logger.info( - "EPLB step: avg_tokens=%.2f, max_tokens=%d, balancedness=%.4f", - avg_tokens, - max_tokens, - balancedness, - ) + # Compute balancedness ratio: + # for each layer: + # (mean load across ranks) / (max load across ranks) + avg_tokens_tensor = num_tokens_per_rank.mean(dim=0).sum(dim=0) + max_tokens_tensor = num_tokens_per_rank.max(dim=0).values.sum(dim=0) + + # Just to make type checker happy + tokens_tensors: list[float] = torch.stack( + [avg_tokens_tensor, max_tokens_tensor] + ).tolist() + avg_tokens, max_tokens = tokens_tensors + balancedness = avg_tokens / max_tokens if max_tokens > 0 else 0.0 + + if ep_group.rank() == 0: + logger.info( + "EPLB step: %d for model %s: avg_tokens=%.2f, " + "max_tokens=%d, balancedness=%.4f", + self.expert_rearrangement_step, + eplb_model_state.model_name, + avg_tokens, + max_tokens, + balancedness, + ) # Update the expert load sliding window if not is_dummy: - self.expert_load_window[self.expert_load_window_step] = ( - self.expert_load_pass.clone() - ) + for eplb_model_state in self.model_states.values(): + eplb_model_state.expert_load_window[self.expert_load_window_step] = ( + eplb_model_state.expert_load_pass.clone() + ) + eplb_model_state.expert_load_pass.zero_() + self.expert_load_window_step += 1 if self.expert_load_window_step >= self.expert_load_window_size: self.expert_load_window_step = 0 - self.expert_load_pass.zero_() # Step the expert rearrangement step # Note that even if this is a dummy step, we still increment the @@ -431,18 +490,30 @@ def step( self.expert_rearrangement_step += 1 if self.expert_rearrangement_step >= self.expert_rearrangement_step_interval: self.expert_rearrangement_step = 0 - self.rearrange(model) + self.rearrange() def rearrange( self, - model: MixtureOfExperts, is_profile: bool = False, execute_shuffle: bool = True, - global_expert_load: torch.Tensor | None = None, + global_expert_loads: list[torch.Tensor] | None = None, rank_mapping: dict[int, int] | None = None, ) -> torch.Tensor | None: """ Rearrange the experts according to the current load. + + Args: + is_profile (bool): If `True`, perform a dummy rearrangement. + This is used in `profile_run` to reserve enough memory, + no memory movement will be performed. Default is False. + execute_shuffle (bool): If `True`, execute the shuffle + in elastic expert parallel (EEP). Default is True. + global_expert_loads (list[torch.Tensor] | None): The global expert + loads when scaling is done in EEP. + List of expert loads for the main and drafter + (when spec decode is used) models. + rank_mapping (dict[int, int] | None): The rank mapping + when scaling is done in EEP. """ ep_group = get_ep_group().device_group @@ -455,53 +526,71 @@ def rearrange( time_start = time.perf_counter() logger.info("Rearranging experts %s...", "(profile)" if is_profile else "") - if global_expert_load is None: + if global_expert_loads is None: # Map the physical expert load to global logical experts - logical_expert_load_window = torch.zeros( - self.expert_load_window_size, - model.num_moe_layers, - model.num_logical_experts, - dtype=self.expert_load_window.dtype, - device=self.expert_load_window.device, - ) - logical_expert_load_window.scatter_add_( - dim=-1, - index=self.physical_to_logical_map.unsqueeze(0) - .expand_as(self.expert_load_window) - .long(), - src=self.expert_load_window, - ) - + global_expert_load_windows = [] if not execute_shuffle: - metadata = torch.tensor( - [ - model.num_moe_layers, - model.num_logical_experts, - self.physical_to_logical_map.shape[1], - ], - dtype=torch.int32, - device="cpu", + num_models = torch.tensor( + [len(self.model_states)], dtype=torch.int32, device="cpu" ) torch.distributed.broadcast( - metadata, group=get_ep_group().cpu_group, group_src=0 + num_models, group=get_ep_group().cpu_group, group_src=0 ) - # Perform all-reduce to get the expert load across all ranks - global_expert_load_window = logical_expert_load_window.sum(dim=0) - all_reduce(global_expert_load_window, group=ep_group) + for eplb_model_state in self.model_states.values(): + logical_expert_load_window = torch.zeros( + self.expert_load_window_size, + eplb_model_state.model.num_moe_layers, + eplb_model_state.model.num_logical_experts, + dtype=eplb_model_state.expert_load_window.dtype, + device=eplb_model_state.expert_load_window.device, + ) + logical_expert_load_window.scatter_add_( + dim=-1, + index=eplb_model_state.physical_to_logical_map.unsqueeze(0) + .expand_as(eplb_model_state.expert_load_window) + .long(), + src=eplb_model_state.expert_load_window, + ) + if not execute_shuffle: + metadata = torch.tensor( + [ + eplb_model_state.model.num_moe_layers, + eplb_model_state.model.num_logical_experts, + eplb_model_state.physical_to_logical_map.shape[1], + ], + dtype=torch.int32, + device="cpu", + ) + torch.distributed.broadcast( + metadata, group=get_ep_group().cpu_group, group_src=0 + ) + + global_expert_load_window = logical_expert_load_window.sum(dim=0) + global_expert_load_windows.append(global_expert_load_window) + # Perform all-reduce to get the expert load across all ranks for each model + global_expert_load_windows = self._allreduce_list( + global_expert_load_windows + ) if not execute_shuffle: - # (num_moe_layers, old_num_physical_experts) - old_global_expert_indices = self.physical_to_logical_map - torch.distributed.broadcast( - old_global_expert_indices, group=ep_group, group_src=0 - ) - return global_expert_load_window + for eplb_model_state, global_expert_load_window in zip( + self.model_states.values(), global_expert_load_windows + ): + # (num_moe_layers, old_num_physical_experts) + old_global_expert_indices = eplb_model_state.physical_to_logical_map + torch.distributed.broadcast( + old_global_expert_indices, group=ep_group, group_src=0 + ) + if not execute_shuffle: + return global_expert_load_windows else: assert execute_shuffle - global_expert_load_window = global_expert_load + global_expert_load_windows = global_expert_loads # TODO(bowen): Treat differently for prefill and decode nodes + eplb_model_state = next(iter(self.model_states.values())) + model = eplb_model_state.model num_replicas = model.num_physical_experts num_groups = model.num_expert_groups if rank_mapping is not None and len(rank_mapping) == ep_group.size(): @@ -526,48 +615,64 @@ def rearrange( f"{num_gpus=}, {num_nodes=}" ) - # Get new expert mappings - ( - new_physical_to_logical_map, - new_logical_to_physical_map, - new_logical_replica_count, - ) = rebalance_experts( - global_expert_load_window, - num_replicas, - num_groups, - num_nodes, - num_gpus, - ) + for eplb_model_state, global_expert_load_window in zip( + self.model_states.values(), global_expert_load_windows + ): + # Get new expert mappings for the model + ( + new_physical_to_logical_map, + new_logical_to_physical_map, + new_logical_replica_count, + ) = rebalance_experts( + global_expert_load_window, + num_replicas, + num_groups, + num_nodes, + num_gpus, + ) - # Update expert weights - rearrange_expert_weights_inplace( - self.physical_to_logical_map, - new_physical_to_logical_map, - model.expert_weights, - ep_group, - is_profile, - rank_mapping, - ) + # Update expert weights + rearrange_expert_weights_inplace( + eplb_model_state.physical_to_logical_map, + new_physical_to_logical_map, + eplb_model_state.model.expert_weights, + ep_group, + is_profile, + rank_mapping, + ) - if not is_profile: - if ( - self.physical_to_logical_map.shape[1] - != new_physical_to_logical_map.shape[1] - ): - self.physical_to_logical_map = new_physical_to_logical_map.to( - self.physical_to_logical_map.device + if not is_profile: + if ( + eplb_model_state.physical_to_logical_map.shape[1] + != new_physical_to_logical_map.shape[1] + ): + eplb_model_state.physical_to_logical_map = ( + new_physical_to_logical_map.to( + eplb_model_state.physical_to_logical_map.device + ) + ) + else: + eplb_model_state.physical_to_logical_map.copy_( + new_physical_to_logical_map + ) + max_physical_slots = new_logical_to_physical_map.shape[-1] + assert ( + max_physical_slots + <= eplb_model_state.logical_to_physical_map.shape[-1] ) - else: - self.physical_to_logical_map.copy_(new_physical_to_logical_map) - max_physical_slots = new_logical_to_physical_map.shape[-1] - assert max_physical_slots <= self.logical_to_physical_map.shape[-1] - new_logical_to_physical_map = torch.nn.functional.pad( - new_logical_to_physical_map, - (0, self.logical_to_physical_map.shape[-1] - max_physical_slots), - value=-1, - ) - self.logical_to_physical_map.copy_(new_logical_to_physical_map) - self.logical_replica_count.copy_(new_logical_replica_count) + new_logical_to_physical_map = torch.nn.functional.pad( + new_logical_to_physical_map, + ( + 0, + eplb_model_state.logical_to_physical_map.shape[-1] + - max_physical_slots, + ), + value=-1, + ) + eplb_model_state.logical_to_physical_map.copy_( + new_logical_to_physical_map + ) + eplb_model_state.logical_replica_count.copy_(new_logical_replica_count) if is_main_rank: assert time_start is not None @@ -581,32 +686,118 @@ def rearrange( return None @staticmethod - def recv_state() -> tuple[torch.Tensor, torch.Tensor]: + def recv_state() -> tuple[list[torch.Tensor], list[torch.Tensor]]: """ Receive the expert load and old placement from the master rank. """ ep_group = get_ep_group() - metadata = torch.empty(3, dtype=torch.int32, device="cpu") - torch.distributed.broadcast(metadata, group=ep_group.cpu_group, group_src=0) - num_moe_layers, num_logical_experts, num_old_physical_experts = ( - metadata.tolist() + num_models = torch.empty(1, dtype=torch.int32, device="cpu") + torch.distributed.broadcast(num_models, group=ep_group.cpu_group, group_src=0) + num_models = num_models.item() + global_expert_loads = [] + old_global_expert_indices_per_model = [] + for _ in range(num_models): + metadata = torch.empty(3, dtype=torch.int32, device="cpu") + torch.distributed.broadcast(metadata, group=ep_group.cpu_group, group_src=0) + num_moe_layers, num_logical_experts, num_old_physical_experts = ( + metadata.tolist() + ) + global_expert_load = torch.zeros( + (num_moe_layers, num_logical_experts), + dtype=torch.int64, + device=ep_group.device, + ) + all_reduce(global_expert_load, group=ep_group.device_group) + old_global_expert_indices = torch.empty( + (num_moe_layers, num_old_physical_experts), + dtype=torch.int64, + device=ep_group.device, + ) + torch.distributed.broadcast( + old_global_expert_indices, + group=ep_group.device_group, + group_src=0, + ) + global_expert_loads.append(global_expert_load) + old_global_expert_indices_per_model.append(old_global_expert_indices) + return global_expert_loads, old_global_expert_indices_per_model + + @classmethod + def get_eep_state( + cls, parallel_config: ParallelConfig + ) -> tuple[ + list[torch.Tensor] | None, + list[torch.Tensor] | None, + dict[int, int] | None, + ]: + num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu") + torch.distributed.broadcast( + num_local_physical_experts, + group=get_ep_group().cpu_group, + group_src=0, ) - global_expert_load = torch.zeros( - (num_moe_layers, num_logical_experts), - dtype=torch.int64, - device=ep_group.device, + num_local_physical_experts = int(num_local_physical_experts.item()) + new_ep_size = get_ep_group().world_size + global_expert_loads, old_global_expert_indices_per_model = ( + EplbState.recv_state() ) - all_reduce(global_expert_load, group=ep_group.device_group) - old_global_expert_indices = torch.empty( - (num_moe_layers, num_old_physical_experts), - dtype=torch.int64, - device=ep_group.device, + + # EP configuration for all models has to be the same so as eplb config + num_logical_experts = global_expert_loads[0].shape[1] + parallel_config.eplb_config.num_redundant_experts = ( + num_local_physical_experts * new_ep_size - num_logical_experts ) - torch.distributed.broadcast( - old_global_expert_indices, group=ep_group.device_group, group_src=0 + assert ( + old_global_expert_indices_per_model[0].shape[1] % num_local_physical_experts + == 0 + ) + old_ep_size = ( + old_global_expert_indices_per_model[0].shape[1] + // num_local_physical_experts + ) + rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)} + return ( + global_expert_loads, + old_global_expert_indices_per_model, + rank_mapping, ) - return global_expert_load, old_global_expert_indices + def _allreduce_list(self, tensor_list: list[torch.Tensor]) -> list[torch.Tensor]: + """ + All-reduce a list of tensors. + """ + if len(tensor_list) == 1: + all_reduce(tensor_list[0], group=get_ep_group().device_group) + return tensor_list + assert all(t.dim() == 2 for t in tensor_list), "All tensors must be 2D." + assert all(t.shape[1] == tensor_list[0].shape[1] for t in tensor_list), ( + "All tensors must have the same shape[1]." + ) + # Concatenate, all_reduce, then unpack to original shapes. + # We assume all tensors are 2D and shape[1] (num_physical_experts) + # is the same across all models. + shapes = [t.shape for t in tensor_list] + concat_tensor = torch.cat(tensor_list, dim=0) + + ep_group = get_ep_group().device_group + all_reduce(concat_tensor, group=ep_group) + + all_reduce_list = [] + offset = 0 + for shape in shapes: + all_reduce_list.append(concat_tensor[offset : offset + shape[0], :]) + offset += shape[0] + return all_reduce_list + + def _sync_load_pass(self) -> list[torch.Tensor]: + """ + Sync the expert load pass across all ranks for log stats. + Doesn't update the expert load pass in eplb_model_state. + """ + load_pass_list = [] + for eplb_model_state in self.model_states.values(): + load_pass_list.append(eplb_model_state.expert_load_pass.clone()) + return self._allreduce_list(load_pass_list) def _node_count_with_rank_mapping( diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py index 8d520f5bf8ef..950139c69c29 100644 --- a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -226,7 +226,7 @@ def _decorator(obj: type[ToolParser]) -> type[ToolParser]: if isinstance(name, str): names = [name] - elif is_list_of(name, str): + elif name is not None and is_list_of(name, str): names = name else: names = [class_name] diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py index 107b1e1a0582..fd2f20ea501d 100644 --- a/vllm/model_executor/models/deepseek_eagle.py +++ b/vllm/model_executor/models/deepseek_eagle.py @@ -24,9 +24,12 @@ DeepseekV2DecoderLayer, DeepseekV3ForCausalLM, ) +from vllm.utils import init_logger from .utils import AutoWeightsLoader, maybe_prefix +logger = init_logger(__name__) + @support_torch_compile class DeepseekV2Model(nn.Module): @@ -215,6 +218,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config.vocab_size, scale=logit_scale ) + # Set MoE hyperparameters + self.num_moe_layers = self.config.num_hidden_layers + self.set_moe_parameters() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 3984d23970ac..26b9c25e6bdb 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -8,6 +8,7 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig +from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -25,11 +26,15 @@ from .deepseek_v2 import ( DeepseekV2DecoderLayer, + DeepseekV2MixtureOfExperts, + DeepseekV2MoE, get_spec_layer_idx_from_weight_name, ) from .interfaces import SupportsPP from .utils import maybe_prefix +logger = init_logger(__name__) + class SharedHead(nn.Module): def __init__( @@ -119,6 +124,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.mtp_start_layer_idx = config.num_hidden_layers self.num_mtp_layers = config.num_nextn_predict_layers # to map the exact layer index from weights + self.layers = torch.nn.ModuleDict( { str(idx): DeepSeekMultiTokenPredictorLayer( @@ -172,13 +178,33 @@ def compute_logits( @support_torch_compile -class DeepSeekMTP(nn.Module, SupportsPP): +class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config self.model = DeepSeekMultiTokenPredictor( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) + # Set MoE hyperparameters + self.set_moe_parameters() + + def set_moe_parameters(self): + self.expert_weights = [] + self.num_moe_layers = self.config.num_nextn_predict_layers + self.num_expert_groups = self.config.n_group + + self.moe_layers = [] + self.moe_mlp_layers = [] + example_moe = None + for layer in self.model.layers.values(): + assert isinstance(layer, DeepSeekMultiTokenPredictorLayer) + layer = layer.mtp_block + assert isinstance(layer, DeepseekV2DecoderLayer) + if isinstance(layer.mlp, DeepseekV2MoE): + example_moe = layer.mlp + self.moe_mlp_layers.append(layer.mlp) + self.moe_layers.append(layer.mlp.experts) + self.extract_moe_parameters(example_moe) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index db7b86ffaf96..a253cdffd901 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -166,7 +166,7 @@ def __init__( self.routed_scaling_factor = config.routed_scaling_factor self.ep_group = get_ep_group().device_group - self.ep_rank = self.ep_group.rank() + self.ep_rank = get_ep_group().rank_in_group self.ep_size = self.ep_group.size() self.n_routed_experts: int = config.n_routed_experts self.n_shared_experts: int = config.n_shared_experts @@ -1122,7 +1122,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) else: self.embed_tokens = PPMissingLayer() - self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: DeepseekV2DecoderLayer( @@ -1172,7 +1171,50 @@ def forward( return hidden_states -class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoRA): +class DeepseekV2MixtureOfExperts(MixtureOfExperts): + moe_mlp_layers: list[DeepseekV2MoE] + """ + List of MoE MLP layers in the model. + """ + + def extract_moe_parameters(self, example_moe: DeepseekV2MoE | None): + if example_moe is None: + self.num_moe_layers = 0 + self.num_expert_groups = 0 + self.num_logical_experts = 0 + self.num_physical_experts = 0 + self.num_local_physical_experts = 0 + self.num_routed_experts = 0 + self.num_shared_experts = 0 + self.num_redundant_experts = 0 + logger.warning("DeepSeekV2: No DeepseekV2MoE layer found in model.layers.") + else: + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + for moe in self.moe_mlp_layers: + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + + +class DeepseekV2ForCausalLM( + nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA +): packed_modules_mapping = { "gate_up_proj": ["gate_proj", "up_proj"], } @@ -1213,13 +1255,19 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) + # Set MoE hyperparameters + self.num_moe_layers = ( + self.config.num_hidden_layers - self.config.first_k_dense_replace + ) + self.set_moe_parameters() + + def set_moe_parameters(self): self.expert_weights = [] - # Set MoE hyperparameters - self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace - self.num_expert_groups = config.n_group + self.num_expert_groups = self.config.n_group - self.moe_layers: list[SharedFusedMoE] = [] + self.moe_layers = [] + self.moe_mlp_layers = [] example_moe = None for layer in self.model.layers: if isinstance(layer, PPMissingLayer): @@ -1229,50 +1277,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if isinstance(layer.mlp, DeepseekV2MoE): # Pick last one layer since the first ones may be dense layers. example_moe = layer.mlp + self.moe_mlp_layers.append(layer.mlp) self.moe_layers.append(layer.mlp.experts) - if example_moe is None: - raise RuntimeError("No DeepseekV2MoE layer found in model.layers.") - - self.num_logical_experts = example_moe.n_logical_experts - self.num_physical_experts = example_moe.n_physical_experts - self.num_local_physical_experts = example_moe.n_local_physical_experts - self.num_routed_experts = example_moe.n_routed_experts - self.num_shared_experts = example_moe.n_shared_experts - self.num_redundant_experts = example_moe.n_redundant_experts - - def set_eplb_state( - self, - expert_load_view: torch.Tensor, - logical_to_physical_map: torch.Tensor, - logical_replica_count: torch.Tensor, - ) -> None: - for layer_idx, layer in enumerate(self.moe_layers): - # Register the expert weights. - self.expert_weights.append(layer.get_expert_weights()) - layer.set_eplb_state( - moe_layer_idx=layer_idx, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - ) - - def update_physical_experts_metadata( - self, - num_physical_experts: int, - num_local_physical_experts: int, - ) -> None: - assert self.num_local_physical_experts == num_local_physical_experts - self.num_physical_experts = num_physical_experts - self.num_local_physical_experts = num_local_physical_experts - self.num_redundant_experts = num_physical_experts - self.num_logical_experts - for layer in self.model.layers: - if isinstance(layer.mlp, DeepseekV2MoE): - moe = layer.mlp - moe.n_local_physical_experts = num_local_physical_experts - moe.n_physical_experts = num_physical_experts - moe.n_redundant_experts = self.num_redundant_experts - moe.experts.update_expert_map() + self.extract_moe_parameters(example_moe) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) diff --git a/vllm/model_executor/models/ernie45_moe.py b/vllm/model_executor/models/ernie45_moe.py index 192ca0585230..b35666175ea7 100644 --- a/vllm/model_executor/models/ernie45_moe.py +++ b/vllm/model_executor/models/ernie45_moe.py @@ -133,7 +133,7 @@ def __init__( self.moe_num_shared_experts = getattr(config, "moe_num_shared_experts", None) self.ep_group = get_ep_group().device_group - self.ep_rank = self.ep_group.rank() + self.ep_rank = get_ep_group().rank_in_group self.ep_size = self.ep_group.size() self.n_routed_experts: int = config.moe_num_experts self.n_shared_experts: int = self.moe_num_shared_experts @@ -709,22 +709,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.num_shared_experts = example_moe.n_shared_experts self.num_redundant_experts = example_moe.n_redundant_experts - def set_eplb_state( - self, - expert_load_view: torch.Tensor, - logical_to_physical_map: torch.Tensor, - logical_replica_count: torch.Tensor, - ) -> None: - for layer_idx, layer in enumerate(self.moe_layers): - # Register the expert weights. - self.expert_weights.append(layer.get_expert_weights()) - layer.set_eplb_state( - moe_layer_idx=layer_idx, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - ) - def update_physical_experts_metadata( self, num_physical_experts: int, diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index a53f52852c6a..b30bd66161da 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -62,7 +62,7 @@ ) from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP +from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP from .utils import ( AutoWeightsLoader, PPMissingLayer, @@ -127,7 +127,7 @@ def __init__( self.routed_scaling_factor = config.routed_scaling_factor self.ep_group = get_ep_group().device_group - self.ep_rank = self.ep_group.rank() + self.ep_rank = get_ep_group().rank_in_group self.ep_size = self.ep_group.size() self.n_routed_experts: int = config.n_routed_experts self.n_shared_experts: int = config.n_shared_experts @@ -616,7 +616,35 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: return loaded_params -class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): +class Glm4MixtureOfExperts(MixtureOfExperts): + def extract_moe_parameters(self, example_moe: Glm4MoE | None) -> None: + if example_moe is None: + raise RuntimeError("No Glm4MoE layer found in model.layers.") + else: + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + for moe in self.moe_mlp_layers: + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + + +class Glm4MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, Glm4MixtureOfExperts): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -659,7 +687,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace self.num_expert_groups = config.n_group - self.moe_layers: list[SharedFusedMoE] = [] + self.moe_layers = [] + self.moe_mlp_layers: list[Glm4MoE] = [] + example_moe = None for layer in self.model.layers: if isinstance(layer, PPMissingLayer): @@ -669,33 +699,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if isinstance(layer.mlp, Glm4MoE): # Pick last one layer since the first ones may be dense layers. example_moe = layer.mlp + self.moe_mlp_layers.append(layer.mlp) self.moe_layers.append(layer.mlp.experts) - if example_moe is None: - raise RuntimeError("No Glm4MoE layer found in model.layers.") - - self.num_logical_experts = example_moe.n_logical_experts - self.num_physical_experts = example_moe.n_physical_experts - self.num_local_physical_experts = example_moe.n_local_physical_experts - self.num_routed_experts = example_moe.n_routed_experts - self.num_shared_experts = example_moe.n_shared_experts - self.num_redundant_experts = example_moe.n_redundant_experts - - def set_eplb_state( - self, - expert_load_view: torch.Tensor, - logical_to_physical_map: torch.Tensor, - logical_replica_count: torch.Tensor, - ) -> None: - for layer_idx, layer in enumerate(self.moe_layers): - # Register the expert weights. - self.expert_weights.append(layer.get_expert_weights()) - layer.set_eplb_state( - moe_layer_idx=layer_idx, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - ) + self.extract_moe_parameters(example_moe) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) diff --git a/vllm/model_executor/models/glm4_moe_mtp.py b/vllm/model_executor/models/glm4_moe_mtp.py index 9fb1be7ba45c..9a2ae3c476f0 100644 --- a/vllm/model_executor/models/glm4_moe_mtp.py +++ b/vllm/model_executor/models/glm4_moe_mtp.py @@ -29,7 +29,7 @@ import torch.nn as nn from transformers import PretrainedConfig -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, ParallelConfig, VllmConfig from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -41,7 +41,12 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors -from .glm4_moe import Glm4MoeDecoderLayer, get_spec_layer_idx_from_weight_name +from .glm4_moe import ( + Glm4MixtureOfExperts, + Glm4MoE, + Glm4MoeDecoderLayer, + get_spec_layer_idx_from_weight_name, +) from .interfaces import SupportsPP from .utils import maybe_prefix @@ -73,6 +78,7 @@ def __init__( prefix: str, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, + parallel_config: ParallelConfig | None = None, ) -> None: super().__init__() self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -81,11 +87,13 @@ def __init__( self.shared_head = SharedHead( config=config, prefix=prefix, quant_config=quant_config ) + self.enable_eplb = parallel_config.enable_eplb self.mtp_block = Glm4MoeDecoderLayer( config=config, cache_config=cache_config, quant_config=quant_config, prefix=prefix, + enable_eplb=self.enable_eplb, ) def forward( @@ -127,6 +135,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): f"{prefix}.layers.{idx}", cache_config=vllm_config.cache_config, quant_config=vllm_config.quant_config, + parallel_config=vllm_config.parallel_config, ) for idx in range( self.mtp_start_layer_idx, @@ -175,7 +184,7 @@ def compute_logits( return logits -class Glm4MoeMTP(nn.Module, SupportsPP): +class Glm4MoeMTP(nn.Module, SupportsPP, Glm4MixtureOfExperts): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config @@ -183,6 +192,25 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) + self.expert_weights = [] + + # Set MoE hyperparameters + self.num_moe_layers = self.config.num_nextn_predict_layers + self.num_expert_groups = self.config.n_group + + self.moe_layers: list[FusedMoE] = [] + self.moe_mlp_layers: list[Glm4MoE] = [] + example_moe = None + for layer in self.model.layers.values(): + assert isinstance(layer, Glm4MoeMultiTokenPredictorLayer) + layer = layer.mtp_block + assert isinstance(layer, Glm4MoeDecoderLayer) + if isinstance(layer.mlp, Glm4MoE): + example_moe = layer.mlp + self.moe_mlp_layers.append(layer.mlp) + self.moe_layers.append(layer.mlp.experts) + self.extract_moe_parameters(example_moe) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py index 901f29310872..8fa9776bd018 100644 --- a/vllm/model_executor/models/hunyuan_v1.py +++ b/vllm/model_executor/models/hunyuan_v1.py @@ -374,7 +374,7 @@ def __init__( self.tp_size = get_tensor_model_parallel_world_size() self.ep_group = get_ep_group().device_group - self.ep_rank = self.ep_group.rank() + self.ep_rank = get_ep_group().rank_in_group self.ep_size = self.ep_group.size() self.n_routed_experts = config.num_experts @@ -1007,7 +1007,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Set MoE hyperparameters self.expert_weights = [] self.num_expert_groups = 1 - self.moe_layers: list[SharedFusedMoE] = [] + self.moe_layers = [] example_layer = None for layer in self.model.layers: if isinstance(layer, PPMissingLayer): @@ -1028,22 +1028,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.num_routed_experts = example_layer.n_routed_experts self.num_redundant_experts = example_layer.n_redundant_experts - def set_eplb_state( - self, - expert_load_view: torch.Tensor, - logical_to_physical_map: torch.Tensor, - logical_replica_count: torch.Tensor, - ) -> None: - for layer_idx, layer in enumerate(self.moe_layers): - self.expert_weights.append(layer.get_expert_weights()) - # Register the expert weights. - layer.set_eplb_state( - moe_layer_idx=layer_idx, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - ) - def update_physical_experts_metadata( self, num_physical_experts: int, diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index e133206c27a8..33c9043405ca 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -14,6 +14,7 @@ import numpy as np import torch +import torch.nn as nn from torch import Tensor from transformers import PretrainedConfig from transformers.models.whisper.tokenization_whisper import LANGUAGES @@ -641,6 +642,9 @@ class MixtureOfExperts(Protocol): num_redundant_experts: int """Number of redundant experts in this model.""" + moe_layers: Iterable[nn.Module] + """List of MoE layers in this model.""" + def set_eplb_state( self, expert_load_view: Tensor, @@ -663,7 +667,15 @@ def set_eplb_state( logical_to_physical_map: Mapping from logical to physical experts. logical_replica_count: Count of replicas for each logical expert. """ - ... + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. + self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) def update_physical_experts_metadata( self, diff --git a/vllm/model_executor/models/lfm2_moe.py b/vllm/model_executor/models/lfm2_moe.py index bb7926a9cfa9..02a490e9c7fd 100644 --- a/vllm/model_executor/models/lfm2_moe.py +++ b/vllm/model_executor/models/lfm2_moe.py @@ -105,7 +105,7 @@ def __init__( self.routed_scaling_factor = config.routed_scaling_factor self.ep_group = get_ep_group().device_group - self.ep_rank = self.ep_group.rank() + self.ep_rank = get_ep_group().rank_in_group self.ep_size = self.ep_group.size() self.n_routed_experts = config.num_experts @@ -707,7 +707,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: # Set MoE hyperparameters self.expert_weights = [] - self.moe_layers: list[FusedMoE] = [] + self.moe_layers = [] example_layer = None for layer in self.model.layers: if isinstance(layer, PPMissingLayer): @@ -737,22 +737,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) - def set_eplb_state( - self, - expert_load_view: torch.Tensor, - logical_to_physical_map: torch.Tensor, - logical_replica_count: torch.Tensor, - ) -> None: - for layer_idx, layer in enumerate(self.moe_layers): - # Register the expert weights. - self.expert_weights.append(layer.get_expert_weights()) - layer.set_eplb_state( - moe_layer_idx=layer_idx, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - ) - def update_physical_experts_metadata( self, num_physical_experts: int, diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 33badb13fc9f..a7e0732ec71e 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -30,9 +30,11 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import ( + get_ep_group, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, ) +from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( @@ -46,6 +48,7 @@ default_weight_loader, maybe_remap_kv_scale_name, ) +from vllm.model_executor.models.interfaces import MixtureOfExperts from vllm.model_executor.models.utils import sequence_parallel_chunk from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel @@ -56,6 +59,8 @@ is_pp_missing_parameter, ) +logger = init_logger(__name__) + class Llama4MoE(nn.Module): @staticmethod @@ -80,6 +85,9 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): self.tp_size = get_tensor_model_parallel_world_size() self.top_k = config.num_experts_per_tok self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe + self.ep_group = get_ep_group().device_group + self.ep_rank = get_ep_group().rank_in_group + self.ep_size = self.ep_group.size() intermediate_size_moe = config.intermediate_size self.router = ReplicatedLinear( @@ -101,6 +109,20 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): disable_tp=self.is_sequence_parallel, ) + # Load balancing settings. + eplb_config = parallel_config.eplb_config if parallel_config else None + self.enable_eplb = parallel_config.enable_eplb if parallel_config else False + self.n_redundant_experts = ( + eplb_config.num_redundant_experts if eplb_config else 0 + ) + + self.n_routed_experts: int = config.num_local_experts + self.n_logical_experts = self.n_routed_experts + self.n_shared_experts: int = 1 + self.n_local_experts: int = config.num_local_experts + self.n_physical_experts = self.n_local_experts + self.n_redundant_experts + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + self.experts = SharedFusedMoE( shared_experts=self.shared_expert, num_experts=config.num_local_experts, @@ -114,6 +136,8 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): quant_config=quant_config, prefix=f"{prefix}.experts", is_sequence_parallel=self.is_sequence_parallel, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, ) def forward(self, hidden_states): @@ -378,6 +402,9 @@ def __init__( layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer, ): self.num_experts = vllm_config.model_config.hf_config.num_local_experts + self.n_redundant_experts = ( + vllm_config.parallel_config.eplb_config.num_redundant_experts + ) super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type) def load_moe_expert_weights( @@ -499,7 +526,6 @@ def load_moe_expert_weights( shard_id=shard_id, expert_id=expert_id, ) - loaded_params.add(full_param_name) expert_param_loaded = True @@ -526,6 +552,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.num_experts, + num_redundant_experts=self.n_redundant_experts, ) # Expert parameter mapping for the case where the expert weights are # fused into a single weight tensor. @@ -683,7 +710,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: return loaded_params -class Llama4ForCausalLM(LlamaForCausalLM): +class Llama4ForCausalLM(LlamaForCausalLM, MixtureOfExperts): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], @@ -702,6 +729,57 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__( vllm_config=vllm_config, prefix=prefix, layer_type=Llama4DecoderLayer ) + # Set MoE hyperparameters + self.set_moe_parameters() + + def set_moe_parameters(self): + self.expert_weights = [] + + self.moe_layers = [] + example_moe = None + for layer in self.model.layers: + assert isinstance(layer, Llama4DecoderLayer) + if isinstance(layer.feed_forward, Llama4MoE): + # Pick last one layer since the first ones may be dense layers. + example_moe = layer.feed_forward + self.moe_layers.append(layer.feed_forward.experts) + + if example_moe is None: + self.num_moe_layers = 0 + self.num_expert_groups = 0 + self.num_logical_experts = 0 + self.num_physical_experts = 0 + self.num_local_physical_experts = 0 + self.num_routed_experts = 0 + self.num_shared_experts = 0 + self.num_redundant_experts = 0 + logger.warning("No Llama4MoE layer found in model.layers.") + else: + self.num_moe_layers = len(self.moe_layers) + self.num_expert_groups = 1 + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + for layer in self.model.layers: + if isinstance(layer.feed_forward, Llama4MoE): + moe = layer.feed_forward + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() def _init_model( self, diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py index 90273463d64e..b59176191e7a 100644 --- a/vllm/model_executor/models/llama4_eagle.py +++ b/vllm/model_executor/models/llama4_eagle.py @@ -189,6 +189,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config.vocab_size, scale=logit_scale ) + # Set MoE hyperparameters + self.set_moe_parameters() + def get_language_model(self) -> torch.nn.Module: return self.model diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 09328b472248..95097a6f832c 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -578,6 +578,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config + parallel_config = vllm_config.parallel_config self.prefix = prefix self.vllm_config = vllm_config @@ -613,6 +614,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) + if parallel_config.enable_eplb and getattr(config, "num_experts", 0) > 0: + raise NotImplementedError("EPLB is not supported for MiniCPM yet.") def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): return MiniCPMModel(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index bc56481820a9..c1f411b6cd2a 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -98,7 +98,7 @@ def __init__( self.hidden_size = hidden_size self.ep_group = get_ep_group().device_group - self.ep_rank = self.ep_group.rank() + self.ep_rank = get_ep_group().rank_in_group self.ep_size = self.ep_group.size() # Expert Parallelism Load balancing settings. @@ -546,7 +546,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.expert_weights = [] - self.moe_layers: list[FusedMoE] = [] + self.moe_layers = [] example_moe = None for layer in self.model.layers: @@ -572,22 +572,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.num_expert_groups = 1 self.num_shared_experts = 0 - def set_eplb_state( - self, - expert_load_view: torch.Tensor, - logical_to_physical_map: torch.Tensor, - logical_replica_count: torch.Tensor, - ) -> None: - for layer_idx, layer in enumerate(self.moe_layers): - # Register the expert weights. - self.expert_weights.append(layer.get_expert_weights()) - layer.set_eplb_state( - moe_layer_idx=layer_idx, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - ) - def update_physical_experts_metadata( self, num_physical_experts: int, diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 81be1135dfd9..4548abde77d5 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -65,6 +65,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import ( + MixtureOfExperts, MultiModalEmbeddings, SupportsEagle3, SupportsMultiModal, @@ -723,7 +724,7 @@ def get_dummy_mm_data( dummy_inputs=Mllama4DummyInputsBuilder, ) class Llama4ForConditionalGeneration( - nn.Module, SupportsMultiModal, SupportsPP, SupportsEagle3 + nn.Module, SupportsMultiModal, SupportsPP, MixtureOfExperts, SupportsEagle3 ): merge_by_field_config = True @@ -776,6 +777,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.language_model.make_empty_intermediate_tensors ) + # Set MoE hyperparameters + self.num_expert_groups = 1 + self.num_logical_experts = self.language_model.num_logical_experts + self.num_physical_experts = self.language_model.num_physical_experts + self.num_local_physical_experts = self.language_model.num_local_physical_experts + self.num_routed_experts = self.language_model.num_routed_experts + self.num_shared_experts = self.language_model.num_shared_experts + self.num_redundant_experts = self.language_model.num_redundant_experts + self.moe_layers = self.language_model.moe_layers + self.num_moe_layers = len(self.moe_layers) + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: """Set which layers should output auxiliary hidden states for EAGLE3.""" # Delegate to underlying language model (Llama4ForCausalLM) @@ -792,6 +804,24 @@ def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: assert hasattr(self.language_model, "get_eagle3_aux_hidden_state_layers") return self.language_model.get_eagle3_aux_hidden_state_layers() + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ): + self.language_model.set_eplb_state( + expert_load_view, logical_to_physical_map, logical_replica_count + ) + self.expert_weights = self.language_model.expert_weights + + def update_physical_experts_metadata( + self, num_physical_experts: int, num_local_physical_experts: int + ): + self.language_model.update_physical_experts_metadata( + num_physical_experts, num_local_physical_experts + ) + def _parse_and_validate_image_input( self, **kwargs: object ) -> Llama4ImagePatchInputs | None: diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 324b63c1732f..fb58d01be7ba 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -807,7 +807,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.expert_weights = [] self.num_expert_groups = config.n_group - self.moe_layers: list[SharedFusedMoE] = [] + self.moe_layers = [] example_moe = None for layer in self.model.layers: if isinstance(layer, NemotronHMoEDecoderLayer): @@ -824,22 +824,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.num_shared_experts = example_moe.n_shared_experts self.num_redundant_experts = example_moe.n_redundant_experts - def set_eplb_state( - self, - expert_load_view: torch.Tensor, - logical_to_physical_map: torch.Tensor, - logical_replica_count: torch.Tensor, - ) -> None: - for layer_idx, layer in enumerate(self.moe_layers): - # Register the expert weights. - self.expert_weights.append(layer.get_expert_weights()) - layer.set_eplb_state( - moe_layer_idx=layer_idx, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - ) - def update_physical_experts_metadata( self, num_physical_experts: int, diff --git a/vllm/model_executor/models/openpangu.py b/vllm/model_executor/models/openpangu.py index 457498d995f8..bf1b7570a882 100644 --- a/vllm/model_executor/models/openpangu.py +++ b/vllm/model_executor/models/openpangu.py @@ -1009,7 +1009,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace self.num_expert_groups = 1 - self.moe_layers: list[SharedFusedMoE] = [] + self.moe_layers = [] example_moe = None for layer in self.model.layers: if isinstance(layer, PPMissingLayer): @@ -1031,22 +1031,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.n_shared_experts = example_moe.n_shared_experts self.num_redundant_experts = example_moe.n_redundant_experts - def set_eplb_state( - self, - expert_load_view: torch.Tensor, - logical_to_physical_map: torch.Tensor, - logical_replica_count: torch.Tensor, - ) -> None: - for layer_idx, layer in enumerate(self.moe_layers): - # Register the expert weights. - self.expert_weights.append(layer.get_expert_weights()) - layer.set_eplb_state( - moe_layer_idx=layer_idx, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - ) - def update_physical_experts_metadata( self, num_physical_experts: int, diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 8452d7b04f5c..a7e6772bb708 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -132,7 +132,7 @@ def __init__( self.tp_size = get_tensor_model_parallel_world_size() self.ep_group = get_ep_group().device_group - self.ep_rank = self.ep_group.rank() + self.ep_rank = get_ep_group().rank_in_group self.ep_size = self.ep_group.size() self.n_routed_experts = config.num_experts @@ -665,7 +665,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Set MoE hyperparameters self.expert_weights = [] - self.moe_layers: list[FusedMoE] = [] + self.moe_layers = [] example_layer = None for layer in self.model.layers: if isinstance(layer, PPMissingLayer): @@ -688,22 +688,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.num_routed_experts = example_layer.n_routed_experts self.num_redundant_experts = example_layer.n_redundant_experts - def set_eplb_state( - self, - expert_load_view: torch.Tensor, - logical_to_physical_map: torch.Tensor, - logical_replica_count: torch.Tensor, - ) -> None: - for layer_idx, layer in enumerate(self.moe_layers): - # Register the expert weights. - self.expert_weights.append(layer.get_expert_weights()) - layer.set_eplb_state( - moe_layer_idx=layer_idx, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - ) - def update_physical_experts_metadata( self, num_physical_experts: int, diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index f452ba871582..e4cd9df2c8dc 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -107,7 +107,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): self.tp_size = get_tensor_model_parallel_world_size() self.ep_group = get_ep_group().device_group - self.ep_rank = self.ep_group.rank() + self.ep_rank = get_ep_group().rank_in_group self.ep_size = self.ep_group.size() self.n_routed_experts = config.num_experts @@ -1095,8 +1095,57 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: return loaded_params +class QwenNextMixtureOfExperts(MixtureOfExperts): + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + for layer in self.model.layers: + if isinstance(layer.mlp, Qwen3NextSparseMoeBlock): + moe = layer.mlp + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + + def set_moe_parameters(self): + self.expert_weights = [] + + self.moe_layers = [] + example_moe = None + for layer in self.model.layers: + if isinstance(layer, Qwen3NextDecoderLayer) and isinstance( + layer.mlp, Qwen3NextSparseMoeBlock + ): + example_moe = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_moe is None: + raise RuntimeError("No Qwen3Next layer found in the model.layers.") + + # Set MoE hyperparameters + self.num_moe_layers = len(self.moe_layers) + self.num_expert_groups = 1 + self.num_shared_experts = 0 + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + class Qwen3NextForCausalLM( - nn.Module, HasInnerState, SupportsLoRA, SupportsPP, MixtureOfExperts, IsHybrid + nn.Module, + HasInnerState, + SupportsLoRA, + SupportsPP, + QwenNextMixtureOfExperts, + IsHybrid, ): packed_modules_mapping = { "qkv_proj": [ @@ -1147,63 +1196,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) # Set MoE hyperparameters - self.expert_weights = [] - - self.moe_layers: list[SharedFusedMoE] = [] - example_layer = None - for layer in self.model.layers: - if isinstance(layer, PPMissingLayer): - continue - - assert isinstance(layer, Qwen3NextDecoderLayer) - if isinstance(layer.mlp, Qwen3NextSparseMoeBlock): - example_layer = layer.mlp - self.moe_layers.append(layer.mlp.experts) - - if example_layer is None: - raise RuntimeError("No Qwen3Next layer found in the model.layers.") - - self.num_moe_layers = len(self.moe_layers) - self.num_expert_groups = 1 - self.num_shared_experts = 0 - self.num_logical_experts = example_layer.n_logical_experts - self.num_physical_experts = example_layer.n_physical_experts - self.num_local_physical_experts = example_layer.n_local_physical_experts - self.num_routed_experts = example_layer.n_routed_experts - self.num_redundant_experts = example_layer.n_redundant_experts - - def set_eplb_state( - self, - expert_load_view: torch.Tensor, - logical_to_physical_map: torch.Tensor, - logical_replica_count: torch.Tensor, - ) -> None: - for layer_idx, layer in enumerate(self.moe_layers): - # Register the expert weights. - self.expert_weights.append(layer.get_expert_weights()) - layer.set_eplb_state( - moe_layer_idx=layer_idx, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - ) - - def update_physical_experts_metadata( - self, - num_physical_experts: int, - num_local_physical_experts: int, - ) -> None: - assert self.num_local_physical_experts == num_local_physical_experts - self.num_physical_experts = num_physical_experts - self.num_local_physical_experts = num_local_physical_experts - self.num_redundant_experts = num_physical_experts - self.num_logical_experts - for layer in self.model.layers: - if isinstance(layer.mlp, Qwen3NextSparseMoeBlock): - moe = layer.mlp - moe.n_local_physical_experts = num_local_physical_experts - moe.n_physical_experts = num_physical_experts - moe.n_redundant_experts = self.num_redundant_experts - moe.experts.update_expert_map() + self.set_moe_parameters() def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) diff --git a/vllm/model_executor/models/qwen3_next_mtp.py b/vllm/model_executor/models/qwen3_next_mtp.py index a447484ae82a..271b76adcff7 100644 --- a/vllm/model_executor/models/qwen3_next_mtp.py +++ b/vllm/model_executor/models/qwen3_next_mtp.py @@ -23,6 +23,7 @@ from vllm.model_executor.models.qwen3_next import ( Qwen3NextDecoderLayer, Qwen3NextRMSNorm, + QwenNextMixtureOfExperts, ) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import Qwen3NextConfig @@ -226,7 +227,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: @support_torch_compile -class Qwen3NextMTP(nn.Module, SupportsPP): +class Qwen3NextMTP(nn.Module, SupportsPP, QwenNextMixtureOfExperts): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -265,6 +266,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) + self.set_moe_parameters() def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) diff --git a/vllm/model_executor/models/transformers/moe.py b/vllm/model_executor/models/transformers/moe.py index 2056ebeb1086..8e39eb0b9902 100644 --- a/vllm/model_executor/models/transformers/moe.py +++ b/vllm/model_executor/models/transformers/moe.py @@ -125,7 +125,7 @@ def set_eplb_state( logical_to_physical_map: torch.Tensor, logical_replica_count: torch.Tensor, ): - for moe_layer_idx, mlp_layer in enumerate(self.mlp_layers): + for moe_layer_idx, mlp_layer in enumerate(self.mlp_moe_layers): mlp_layer.experts.set_eplb_state( moe_layer_idx=moe_layer_idx, expert_load_view=expert_load_view, @@ -142,7 +142,7 @@ def update_physical_experts_metadata( self.num_physical_experts = num_physical_experts self.num_local_physical_experts = num_local_physical_experts self.num_redundant_experts = num_physical_experts - self.num_logical_experts - for mlp in self.mlp_layers: + for mlp in self.mlp_moe_layers: mlp.n_local_physical_experts = num_local_physical_experts mlp.n_physical_experts = num_physical_experts mlp.n_redundant_experts = self.num_redundant_experts @@ -240,7 +240,8 @@ def forward(self, *args, **kwargs): # MixtureOfExperts mixin settings ep_size = get_ep_group().world_size - self.mlp_layers = [] # Used for MixtureOfExperts methods + self.mlp_moe_layers = [] # Used for MixtureOfExperts methods + self.moe_layers = [] self.expert_weights = [] self.num_moe_layers = 0 self.num_expert_groups = 1 if num_expert_group is None else num_expert_group @@ -298,7 +299,8 @@ def _recursive_replace(module: nn.Module, prefix: str): mlp.experts = fused_experts log_replacement(qual_name, experts, fused_experts) # Update MixtureOfExperts mixin state - self.mlp_layers.append(mlp) + self.mlp_moe_layers.append(mlp) + self.moe_layers.append(fused_experts) self.expert_weights.append(fused_experts.get_expert_weights()) self.num_moe_layers += 1 # If results are not all-reduced in FusedMoE, ensure they diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py index 150dde177ce8..12b903ccaca9 100644 --- a/vllm/v1/spec_decode/medusa.py +++ b/vllm/v1/spec_decode/medusa.py @@ -8,6 +8,7 @@ from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model +from vllm.model_executor.models.interfaces import is_mixture_of_experts from vllm.v1.sample.metadata import SamplingMetadata # Initialize logger @@ -56,6 +57,10 @@ def load_model(self, target_model: nn.Module) -> None: vllm_config=self.vllm_config, model_config=self.vllm_config.speculative_config.draft_model_config, ) + assert not ( + is_mixture_of_experts(self.model) + and self.vllm_config.parallel_config.enable_eplb + ), "EPLB for Medusa is not supported" @torch.inference_mode() def dummy_run(self, num_tokens: int) -> None: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 177542ed96c8..d002f6f5b1a8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2046,7 +2046,6 @@ def eplb_step(self, is_dummy: bool = False, is_profile: bool = False) -> None: model = self.get_model() assert is_mixture_of_experts(model) self.eplb_state.step( - model, is_dummy, is_profile, log_stats=self.parallel_config.eplb_config.log_balancedness, @@ -2803,7 +2802,9 @@ def propose_draft_token_ids( else: indices = [] offset = 0 - assert spec_decode_metadata is not None + assert spec_decode_metadata is not None, ( + "No spec decode metadata for medusa" + ) for num_draft, tokens in zip( spec_decode_metadata.num_draft_tokens, sampled_token_ids ): @@ -2934,32 +2935,15 @@ def load_model(self, eep_scale_up: bool = False) -> None: self.model_config.model, scope="global", ) - if eep_scale_up: - from vllm.distributed.parallel_state import get_ep_group - - num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu") - torch.distributed.broadcast( - num_local_physical_experts, group=get_ep_group().cpu_group, group_src=0 - ) - num_local_physical_experts = int(num_local_physical_experts.item()) - new_ep_size = get_ep_group().world_size - global_expert_load, old_global_expert_indices = EplbState.recv_state() - num_logical_experts = global_expert_load.shape[1] - self.parallel_config.eplb_config.num_redundant_experts = ( - num_local_physical_experts * new_ep_size - num_logical_experts - ) - assert old_global_expert_indices.shape[1] % num_local_physical_experts == 0 - old_ep_size = ( - old_global_expert_indices.shape[1] // num_local_physical_experts - ) - rank_mapping = { - old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size) - } - else: - global_expert_load = None - old_global_expert_indices = None - rank_mapping = None + global_expert_loads, old_global_expert_indices_per_model, rank_mapping = ( + EplbState.get_eep_state(self.parallel_config) + if eep_scale_up + else (None, None, None) + ) + if self.parallel_config.enable_eplb: + self.eplb_state = EplbState(self.parallel_config, self.device) + eplb_models = 0 with DeviceMemoryProfiler() as m: time_before_load = time.perf_counter() model_loader = get_model_loader(self.load_config) @@ -2971,8 +2955,39 @@ def load_model(self, eep_scale_up: bool = False) -> None: self.model, self.vllm_config, self.device ) if hasattr(self, "drafter"): - logger.info("Loading drafter model...") + logger.info_once("Loading drafter model...") self.drafter.load_model(self.model) + if ( + hasattr(self.drafter, "model") + and is_mixture_of_experts(self.drafter.model) + and self.parallel_config.enable_eplb + ): + logger.info_once( + "EPLB is enabled for drafter model %s.", + self.vllm_config.speculative_config.draft_model_config.model, + ) + + global_expert_load = ( + global_expert_loads[eplb_models] + if global_expert_loads + else None + ) + old_global_expert_indices = ( + old_global_expert_indices_per_model[eplb_models] + if old_global_expert_indices_per_model + else None + ) + if self.eplb_state is None: + self.eplb_state = EplbState(self.parallel_config, self.device) + self.eplb_state.add_model( + self.drafter.model, + self.vllm_config.speculative_config.draft_model_config, + global_expert_load, + old_global_expert_indices, + rank_mapping, + ) + eplb_models += 1 + if self.use_aux_hidden_state_outputs: if not supports_eagle3(self.get_model()): raise RuntimeError( @@ -3001,18 +3016,25 @@ def load_model(self, eep_scale_up: bool = False) -> None: scope="local", ) prepare_communication_buffer_for_model(self.model) - self.is_multimodal_pruning_enabled = ( supports_multimodal_pruning(self.get_model()) and self.model_config.multimodal_config.is_multimodal_pruning_enabled() ) if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb: - logger.info("EPLB is enabled for model %s.", self.model_config.model) - self.eplb_state = EplbState.build( + logger.info_once("EPLB is enabled for model %s.", self.model_config.model) + global_expert_load = ( + global_expert_loads[eplb_models] if global_expert_loads else None + ) + old_global_expert_indices = ( + old_global_expert_indices_per_model[eplb_models] + if old_global_expert_indices_per_model + else None + ) + assert self.eplb_state is not None + self.eplb_state.add_model( self.model, - self.device, - self.parallel_config, + self.model_config, global_expert_load, old_global_expert_indices, rank_mapping, diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 3cc8f90a3e19..9178d929111c 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -32,6 +32,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed +from vllm.model_executor.models.interfaces import is_mixture_of_experts from vllm.model_executor.warmup.kernel_warmup import kernel_warmup from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -613,7 +614,6 @@ def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None: } assert self.model_runner.eplb_state is not None self.model_runner.eplb_state.rearrange( - self.model_runner.model, execute_shuffle=True, global_expert_load=None, rank_mapping=rank_mapping, @@ -626,7 +626,7 @@ def _eplb_after_scale_up( self, old_ep_size: int, new_ep_size: int, - global_expert_load: torch.Tensor | None, + global_expert_loads: list[torch.Tensor] | None, ) -> None: from vllm.distributed.parallel_state import get_ep_group @@ -635,9 +635,8 @@ def _eplb_after_scale_up( rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)} assert self.model_runner.eplb_state is not None self.model_runner.eplb_state.rearrange( - self.model_runner.model, execute_shuffle=True, - global_expert_load=global_expert_load, + global_expert_loads=global_expert_loads, rank_mapping=rank_mapping, ) if get_ep_group().rank == 0: @@ -684,31 +683,56 @@ def _reconfigure_moe( get_ep_group, prepare_communication_buffer_for_model, ) - from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig + from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, + FusedMoEParallelConfig, + ) parallel_config = self.vllm_config.parallel_config - moe_modules = [ - module - for module in self.model_runner.model.modules() - if ( - module.__class__.__name__ == "FusedMoE" - or module.__class__.__name__ == "SharedFusedMoE" - ) - ] - num_local_experts = moe_modules[0].moe_config.num_local_experts - assert all( - module.moe_config.num_local_experts == num_local_experts - for module in moe_modules - ), "All MoE modules must have the same number of experts" - for module in moe_modules: - module.moe_config.num_experts = num_local_experts * new_ep_size - module.global_num_experts = module.moe_config.num_experts - module.moe_parallel_config = FusedMoEParallelConfig.make( - tp_size_=get_tp_group().world_size, - dp_size_=get_dp_group().world_size, - vllm_parallel_config=parallel_config, - ) - module.moe_config.moe_parallel_config = module.moe_parallel_config + + def get_moe_modules(model: torch.nn.Module) -> list[FusedMoE]: + return [ + module + for module in model.modules() + if ( + module.__class__.__name__ == "FusedMoE" + or module.__class__.__name__ == "SharedFusedMoE" + ) + ] + + def update_moe_modules(moe_modules: list[FusedMoE], num_local_experts: int): + assert all( + module.moe_config.num_local_experts == num_local_experts + for module in moe_modules + ), "All MoE modules must have the same number of experts" + for module in moe_modules: + module.moe_config.num_experts = num_local_experts * new_ep_size + module.global_num_experts = module.moe_config.num_experts + module.moe_parallel_config = FusedMoEParallelConfig.make( + tp_size_=get_tp_group().world_size, + dp_size_=get_dp_group().world_size, + vllm_parallel_config=parallel_config, + ) + module.moe_config.moe_parallel_config = module.moe_parallel_config + return moe_modules + + model_moe_modules = get_moe_modules(self.model_runner.model) + num_local_experts = model_moe_modules[0].moe_config.num_local_experts + + update_moe_modules(model_moe_modules, num_local_experts) + drafter_model = None + if hasattr(self.model_runner, "drafter") and hasattr( + self.model_runner.drafter, "model" + ): + drafter_model = self.model_runner.drafter.model + if drafter_model is not None and is_mixture_of_experts(drafter_model): + drafter_moe_modules = get_moe_modules(drafter_model) + # Check if drafter and model have matching configs + assert ( + drafter_moe_modules[0].moe_config.num_local_experts == num_local_experts + ), "Drafter and model configs should be the same" + update_moe_modules(drafter_moe_modules, num_local_experts) + if new_ep_size < old_ep_size: num_local_physical_experts = num_local_experts assert self.model_runner.eplb_state is not None @@ -719,7 +743,7 @@ def _reconfigure_moe( new_physical_experts - self.model_runner.eplb_state.logical_replica_count.shape[1] ) - global_expert_load = None + global_expert_loads = None else: num_local_physical_experts = torch.tensor( [num_local_experts], dtype=torch.int32, device="cpu" @@ -730,18 +754,20 @@ def _reconfigure_moe( num_local_physical_experts = num_local_physical_experts.item() new_physical_experts = num_local_physical_experts * new_ep_size assert self.model_runner.eplb_state is not None - global_expert_load = self.model_runner.eplb_state.rearrange( - self.model_runner.model, execute_shuffle=False + global_expert_loads = self.model_runner.eplb_state.rearrange( + execute_shuffle=False ) parallel_config.eplb_config.num_redundant_experts = ( - new_physical_experts - global_expert_load.shape[1] + new_physical_experts - global_expert_loads[0].shape[1] ) prepare_communication_buffer_for_model(self.model_runner.model) + if drafter_model is not None: + prepare_communication_buffer_for_model(drafter_model) self.model_runner.model.update_physical_experts_metadata( num_physical_experts=new_physical_experts, num_local_physical_experts=num_local_physical_experts, ) - return global_expert_load + return global_expert_loads def reinitialize_distributed( self, reconfig_request: ReconfigureDistributedRequest @@ -782,11 +808,11 @@ def reinitialize_distributed( self.local_rank, ) - global_expert_load = self._reconfigure_moe(old_ep_size, new_ep_size) + global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size) if new_ep_size > old_ep_size: - assert global_expert_load is not None - self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_load) + assert global_expert_loads is not None + self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads) def save_sharded_state( self,