123 changes: 101 additions & 22 deletions vllm/model_executor/models/qwen3_moe.py
@@ -22,17 +22,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
from collections.abc import Iterable
from typing import Any, Optional, Union
import typing
from collections.abc import Callable, Iterable
from typing import Any, Optional, Union

import torch
from torch import nn
from torch import nn, Tensor
from transformers import PretrainedConfig

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (get_ep_group, get_pp_group,
get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -50,8 +52,9 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, extract_layer_index,

from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@@ -101,6 +104,7 @@ def __init__(
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
@@ -110,14 +114,29 @@
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}.")

self.ep_group = get_ep_group().device_group
self.ep_size = self.ep_group.size()

vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
self.n_routed_experts = config.num_experts
self.n_redundant_experts = parallel_config.num_redundant_experts
self.n_logical_experts = self.n_routed_experts
self.n_physical_experts = (self.n_logical_experts +
self.n_redundant_experts)
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
self.enable_eplb = enable_eplb
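# Illustrative arithmetic (example values, not from this config): with
# 128 logical experts, 2 redundant experts, and an EP group of size 2,
# each rank holds (128 + 2) // 2 = 65 physical expert slots; EPLB
# decides at runtime which logical experts occupy the redundant slots.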

self.experts = FusedMoE(num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts")
prefix=f"{prefix}.experts",
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts)

self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts,
@@ -246,6 +265,7 @@ def __init__(
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -277,7 +297,8 @@ def __init__(
(layer_idx + 1) % config.decoder_sparse_step == 0):
self.mlp = Qwen3MoeSparseMoeBlock(config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
prefix=f"{prefix}.mlp",
enable_eplb=enable_eplb)
else:
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
@@ -323,6 +344,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
enable_eplb = vllm_config.parallel_config.enable_eplb
self.num_redundant_experts = (
vllm_config.parallel_config.num_redundant_experts)

self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
@@ -336,7 +360,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
lambda prefix: Qwen3MoeDecoderLayer(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
prefix=prefix,
enable_eplb=enable_eplb),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -382,7 +407,8 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts)
num_experts=self.config.num_experts,
num_redundant_experts=self.num_redundant_experts)

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
@@ -433,27 +459,38 @@ def load_weights(self, weights: Iterable[tuple[str,
weight_loader(param, loaded_weight, shard_id)
break
else:
is_expert_weight = False
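# Try every expert mapping for this checkpoint tensor. Under EPLB a
# logical expert can have several physical replicas, some or all on
# other ranks, so the loader is asked to report success rather than
# assumed to own the weight; tensors that match an expert pattern but
# load nowhere locally are skipped below using this flag.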
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)

is_expert_weight = True

name_mapped = name.replace(weight_name, param_name)

# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
if is_pp_missing_parameter(name_mapped, self):
continue
# Skip loading extra parameters for GPTQ/modelopt models.
if name.endswith(
ignore_suffixes) and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
param = params_dict[name_mapped]
weight_loader = typing.cast(Callable[..., bool],
param.weight_loader)
success = weight_loader(param,
loaded_weight,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True)
if success:
name = name_mapped
break
else:
if is_expert_weight:
continue
# Skip loading extra parameters for GPTQ/modelopt models.
if name.endswith(
ignore_suffixes) and name not in params_dict:
@@ -482,7 +519,8 @@ def load_weights(self, weights: Iterable[tuple[str,
return loaded_params


class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
class Qwen3MoeForCausalLM(nn.Module, SupportsPP,
SupportsLoRA, MixtureOfExperts):
Contributor review comment (critical):

The Qwen3MoeForCausalLM class needs to implement the update_physical_experts_metadata method as part of the MixtureOfExperts protocol. This method is called by the EPLB scheduler during expert rebalancing, and its absence will lead to a runtime AttributeError.
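A sketch of both methods (the existing set_eplb_state shown for context), using the attributes this PR introduces: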

    def set_eplb_state(
        self,
        expert_load_view: Tensor,
        logical_to_physical_map: Tensor,
        logical_replica_count: Tensor,
    ) -> None:
        for layer_idx, layer in enumerate(self.moe_layers):
            self.expert_weights.append(layer.get_expert_weights())
            layer.set_eplb_state(
                moe_layer_idx=layer_idx,
                expert_load_view=expert_load_view,
                logical_to_physical_map=logical_to_physical_map,
                logical_replica_count=logical_replica_count,
            )

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        for layer in self.model.layers:
            if isinstance(layer, PPMissingLayer):
                continue

            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
                layer.mlp.n_physical_experts = num_physical_experts
                layer.mlp.n_local_physical_experts = num_local_physical_experts
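
Refreshing n_physical_experts and n_local_physical_experts in place keeps each sparse block's cached counts consistent after the EPLB scheduler rebalances the physical expert layout.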

packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -513,6 +551,47 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

# Implement the MixtureOfExperts protocol.
self.expert_weights = []

self.moe_layers: list[FusedMoE] = []
example_layer = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue

assert isinstance(layer, Qwen3MoeDecoderLayer)
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
example_layer = layer.mlp
self.moe_layers.append(layer.mlp.experts)
self.num_moe_layers = len(self.moe_layers)

if example_layer is None:
raise RuntimeError("No Qwen3MoE layer found in model.layers.")

self.num_expert_groups = 1
self.num_logical_experts = example_layer.n_logical_experts
self.num_physical_experts = example_layer.n_physical_experts
self.num_local_physical_experts = example_layer.n_local_physical_experts
self.num_routed_experts = example_layer.n_routed_experts
self.num_shared_experts = 0
self.num_redundant_experts = example_layer.n_redundant_experts

def set_eplb_state(
self,
expert_load_view: Tensor,
logical_to_physical_map: Tensor,
logical_replica_count: Tensor,
) -> None:
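# Shared EPLB state, per vLLM's EPLB design: expert_load_view
# accumulates per-physical-expert load statistics,
# logical_to_physical_map maps each logical expert id to its physical
# replica slots, and logical_replica_count records how many replicas
# each logical expert currently has.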
for layer_idx, layer in enumerate(self.moe_layers):
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)