143 changes: 121 additions & 22 deletions vllm/model_executor/models/qwen3_moe.py
@@ -22,17 +22,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
from collections.abc import Iterable
import typing
from collections.abc import Callable, Iterable
from typing import Any, Optional, Union

import torch
from torch import nn
from torch import Tensor, nn
from transformers import PretrainedConfig

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (get_ep_group, get_pp_group,
get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -50,8 +52,8 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, extract_layer_index,
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@@ -101,6 +103,7 @@ def __init__(
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
@@ -110,14 +113,29 @@ def __init__(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}.")

self.ep_group = get_ep_group().device_group
self.ep_size = self.ep_group.size()

vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
self.n_routed_experts = config.num_experts
self.n_redundant_experts = parallel_config.num_redundant_experts
self.n_logical_experts = self.n_routed_experts
self.n_physical_experts = (self.n_logical_experts +
self.n_redundant_experts)
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
self.enable_eplb = enable_eplb

self.experts = FusedMoE(num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts")
prefix=f"{prefix}.experts",
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts)

self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts,
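The counter fields added above boil down to simple arithmetic; a standalone sketch with made-up numbers (illustrative only, not values from this PR):

# Illustrative only: how the n_* fields introduced above relate.
num_logical_experts = 128        # config.num_experts
num_redundant_experts = 16       # parallel_config.num_redundant_experts
ep_size = 8                      # size of the expert-parallel process group

num_physical_experts = num_logical_experts + num_redundant_experts   # 144
num_local_physical_experts = num_physical_experts // ep_size         # 18 per EP rank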
@@ -246,6 +264,7 @@ def __init__(
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -277,7 +296,8 @@ def __init__(
(layer_idx + 1) % config.decoder_sparse_step == 0):
self.mlp = Qwen3MoeSparseMoeBlock(config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
prefix=f"{prefix}.mlp",
enable_eplb=enable_eplb)
else:
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
@@ -323,6 +343,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
enable_eplb = vllm_config.parallel_config.enable_eplb
self.num_redundant_experts = (
vllm_config.parallel_config.num_redundant_experts)

self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
@@ -336,7 +359,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
lambda prefix: Qwen3MoeDecoderLayer(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
prefix=prefix,
enable_eplb=enable_eplb),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -382,7 +406,8 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts)
num_experts=self.config.num_experts,
num_redundant_experts=self.num_redundant_experts)
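For orientation, load_weights below unpacks each entry of this mapping as (param_name, weight_name, expert_id, shard_id); a hypothetical entry, with strings that are guesses rather than the exact values produced by FusedMoE:

# Hypothetical example of one expert mapping entry (strings are guesses):
#   param_name  -> fused FusedMoE parameter the shard is written into
#   weight_name -> substring matched against the checkpoint tensor name
#   expert_id   -> logical expert the checkpoint tensor belongs to
#   shard_id    -> slot inside the fused weight ("w1"/"w2"/"w3")
example_entry = ("experts.w13_weight", "experts.7.gate_proj.", 7, "w1")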

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
@@ -433,27 +458,38 @@ def load_weights(self, weights: Iterable[tuple[str,
weight_loader(param, loaded_weight, shard_id)
break
else:
is_expert_weight = False
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)

is_expert_weight = True

name_mapped = name.replace(weight_name, param_name)

# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
if is_pp_missing_parameter(name_mapped, self):
continue
# Skip loading extra parameters for GPTQ/modelopt models.
if name.endswith(
if name_mapped.endswith(
ignore_suffixes) and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
param = params_dict[name_mapped]
weight_loader = typing.cast(Callable[..., bool],
param.weight_loader)
success = weight_loader(param,
loaded_weight,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True)
if success:
name = name_mapped
break
else:
if is_expert_weight:
continue
# Skip loading extra parameters for GPTQ/modelopt models.
if name.endswith(
ignore_suffixes) and name not in params_dict:
@@ -482,7 +518,8 @@ def load_weights(self, weights: Iterable[tuple[str,
return loaded_params
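The reworked expert branch above asks the weight loader to report success rather than assume it, because with redundant experts a checkpoint tensor may have no replica on the current rank; a toy sketch of that control flow (the loader is a stand-in, not the vLLM API):

# Toy stand-in for the success-flag pattern used in load_weights above.
def toy_expert_loader(expert_id: int, local_expert_ids: set[int]) -> bool:
    # A rank holding no replica of this expert simply reports False.
    return expert_id in local_expert_ids

local_expert_ids = {1, 3}          # experts physically present on this rank
candidate_expert_ids = [0, 1, 2]   # mapping entries matching the tensor name

for expert_id in candidate_expert_ids:
    if toy_expert_loader(expert_id, local_expert_ids):
        break                       # loaded locally
else:
    # Expert weight, but not needed on this rank: skip without erroring,
    # mirroring the `is_expert_weight` guard above.
    pass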


class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
MixtureOfExperts):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -514,6 +551,68 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

# Implement the MixtureOfExperts protocol.
self.expert_weights = []

self.moe_layers: list[FusedMoE] = []
example_layer = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue

assert isinstance(layer, Qwen3MoeDecoderLayer)
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
example_layer = layer.mlp
self.moe_layers.append(layer.mlp.experts)
self.num_moe_layers = len(self.moe_layers)

if example_layer is None:
raise RuntimeError("No Qwen3MoE layer found in model.layers.")

self.num_expert_groups = 1
self.num_logical_experts = example_layer.n_logical_experts
self.num_physical_experts = example_layer.n_physical_experts
self.num_local_physical_experts = example_layer.n_local_physical_experts
self.num_routed_experts = example_layer.n_routed_experts
self.num_shared_experts = 0
self.num_redundant_experts = example_layer.n_redundant_experts

def set_eplb_state(
self,
expert_load_view: Tensor,
logical_to_physical_map: Tensor,
logical_replica_count: Tensor,
) -> None:
for layer_idx, layer in enumerate(self.moe_layers):
self.expert_weights.append(layer.get_expert_weights())
layer.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)

def update_physical_experts_metadata(
self,
num_physical_experts: int,
num_local_physical_experts: int,
) -> None:
if self.num_local_physical_experts != num_local_physical_experts:
raise ValueError(f"Expected num_local_physical_experts to be "
f"{self.num_local_physical_experts}, got "
f"{num_local_physical_experts}")
self.num_physical_experts = num_physical_experts
self.num_local_physical_experts = num_local_physical_experts
self.num_redundant_experts = (num_physical_experts -
self.num_logical_experts)
for layer in self.model.layers:
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
moe = layer.mlp
moe.n_local_physical_experts = num_local_physical_experts
moe.n_physical_experts = num_physical_experts
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
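A rough sketch of how a caller might allocate the EPLB state consumed by set_eplb_state above; the tensor shapes and dtypes here are assumptions for illustration, not taken from this PR:

import torch

# Assumed bookkeeping sizes (illustrative values only).
num_moe_layers = 48
num_logical_experts = 128
num_physical_experts = 144
max_replicas = 1 + (num_physical_experts - num_logical_experts)

# Per-layer load counters and logical->physical routing tables.
expert_load_view = torch.zeros(num_moe_layers, num_physical_experts,
                               dtype=torch.int64)
logical_to_physical_map = torch.full(
    (num_moe_layers, num_logical_experts, max_replicas), -1, dtype=torch.int64)
logical_replica_count = torch.ones(num_moe_layers, num_logical_experts,
                                   dtype=torch.int64)

# model.set_eplb_state(expert_load_view,
#                      logical_to_physical_map,
#                      logical_replica_count)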

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
