@@ -1921,9 +1921,20 @@ def apply(
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `CompressedTensorsWNA16MoEMethod` yet."
            )
            if expert_load_view is None:
                raise ValueError("enable_eplb=True requires expert_load_view != None")
            if logical_to_physical_map is None:
                raise ValueError(
                    "enable_eplb=True requires logical_to_physical_map != None"
                )
            if logical_replica_count is None:
                raise ValueError(
                    "enable_eplb=True requires logical_replica_count != None"
                )
            if not isinstance(layer, FusedMoE):
                raise TypeError(
                    "EPLB is only supported when `layer` is an instance of FusedMoE."
                )

        from vllm.model_executor.layers.fused_moe import fused_experts

@@ -1940,6 +1951,12 @@ def apply(
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
            num_fused_shared_experts=getattr(layer, "num_fused_shared_experts", 0),
            enable_eplb=enable_eplb,
            expert_map=expert_map,
            expert_load_view=expert_load_view,
            logical_to_physical_map=logical_to_physical_map,
            logical_replica_count=logical_replica_count,
        )

        return fused_experts(
@@ -1956,6 +1973,10 @@ def apply(
            quant_config=self.moe_quant_config,
        )

    @property
    def supports_eplb(self) -> bool:
        return True


class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
"""
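Note: the change above replaces the blanket `NotImplementedError` with per-argument validation when EPLB is enabled. A minimal, self-contained sketch of the same guard pattern is below; the standalone `validate_eplb_args` helper and its consolidated error message are illustrative assumptions, not part of this PR.

```python
import torch


def validate_eplb_args(
    enable_eplb: bool,
    expert_load_view: torch.Tensor | None,
    logical_to_physical_map: torch.Tensor | None,
    logical_replica_count: torch.Tensor | None,
) -> None:
    """Hypothetical standalone version of the guard added in apply()."""
    if not enable_eplb:
        return
    # All EPLB bookkeeping tensors must be supplied when EPLB is enabled.
    missing = [
        name
        for name, value in (
            ("expert_load_view", expert_load_view),
            ("logical_to_physical_map", logical_to_physical_map),
            ("logical_replica_count", logical_replica_count),
        )
        if value is None
    ]
    if missing:
        raise ValueError(f"enable_eplb=True requires {', '.join(missing)} != None")
```

The actual `apply()` keeps one explicit check per argument (plus the `FusedMoE` isinstance check); the consolidated form here is only for illustration.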
62 changes: 58 additions & 4 deletions vllm/model_executor/models/qwen3_vl_moe.py
@@ -15,7 +15,7 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
@@ -29,7 +29,9 @@
from itertools import islice

import torch
from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeConfig
from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import (
    Qwen3VLMoeConfig,
)

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
@@ -44,7 +46,12 @@
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors

from .qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
from .interfaces import MixtureOfExperts
from .qwen3_moe import (
    Qwen3MoeForCausalLM,
    Qwen3MoeModel,
    Qwen3MoeSparseMoeBlock,
)
from .qwen3_vl import (
    Qwen3_VisionTransformer,
    Qwen3VLDummyInputsBuilder,
@@ -344,12 +351,56 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        )


class Qwen3VLMoeMixtureOfExperts(MixtureOfExperts):
    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.language_model.model.layers:
            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
                moe = layer.mlp
                moe.n_local_physical_experts = num_local_physical_experts
                moe.n_physical_experts = num_physical_experts
                moe.n_redundant_experts = self.num_redundant_experts
                moe.experts.update_expert_map()

    def set_moe_parameters(self):
        self.expert_weights = []

        self.moe_layers = []
        example_moe = None
        for layer in self.language_model.model.layers:
            if hasattr(layer, "mlp") and isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
                example_moe = layer.mlp
                self.moe_layers.append(layer.mlp.experts)

        if example_moe is None:
            raise RuntimeError("No Qwen3Moe layer found in the language_model.")

        # Set MoE hyperparameters
        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_shared_experts = 0
        self.num_logical_experts = example_moe.n_logical_experts
        self.num_physical_experts = example_moe.n_physical_experts
        self.num_local_physical_experts = example_moe.n_local_physical_experts
        self.num_routed_experts = example_moe.n_routed_experts
        self.num_redundant_experts = example_moe.n_redundant_experts


@MULTIMODAL_REGISTRY.register_processor(
    Qwen3VLMultiModalProcessor,
    info=Qwen3VLMoeProcessingInfo,
    dummy_inputs=Qwen3VLDummyInputsBuilder,
)
class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
class Qwen3VLMoeForConditionalGeneration(
    Qwen3VLForConditionalGeneration, Qwen3VLMoeMixtureOfExperts
):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
@@ -413,3 +464,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        self.deepstack_input_embeds = None
        self.visual_dim = config.vision_config.out_hidden_size
        self.multiscale_dim = self.visual_dim * self.deepstack_num_level

        # Set MoE hyperparameters
        self.set_moe_parameters()
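
For context, the expert-count bookkeeping in `update_physical_experts_metadata` and `set_moe_parameters` follows one relation: redundant experts are whatever EPLB replicates beyond the logical set. A self-contained sketch with illustrative numbers follows; the `ExpertCounts` dataclass and the values are assumptions, not taken from this PR.

```python
from dataclasses import dataclass


@dataclass
class ExpertCounts:
    num_logical_experts: int         # experts defined by the model architecture
    num_physical_experts: int        # logical experts plus EPLB replicas, cluster-wide
    num_local_physical_experts: int  # physical experts hosted on this rank

    @property
    def num_redundant_experts(self) -> int:
        # Same relation as update_physical_experts_metadata() above.
        return self.num_physical_experts - self.num_logical_experts


counts = ExpertCounts(
    num_logical_experts=128,
    num_physical_experts=144,       # e.g. 16 replicas spread across ranks
    num_local_physical_experts=18,  # 144 physical experts / 8 ranks
)
assert counts.num_redundant_experts == 16
```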