diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 22b3c477f420..fa254030a271 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -1921,9 +1921,20 @@ def apply(
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `CompressedTensorsWNA16MoEMethod` yet."
-            )
+            if expert_load_view is None:
+                raise ValueError("enable_eplb=True requires expert_load_view != None")
+            if logical_to_physical_map is None:
+                raise ValueError(
+                    "enable_eplb=True requires logical_to_physical_map != None"
+                )
+            if logical_replica_count is None:
+                raise ValueError(
+                    "enable_eplb=True requires logical_replica_count != None"
+                )
+            if not isinstance(layer, FusedMoE):
+                raise TypeError(
+                    "EPLB is only supported when `layer` is an instance of FusedMoE."
+                )
 
         from vllm.model_executor.layers.fused_moe import fused_experts
 
@@ -1940,6 +1951,12 @@ def apply(
             routed_scaling_factor=routed_scaling_factor,
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype,
+            num_fused_shared_experts=getattr(layer, "num_fused_shared_experts", 0),
+            enable_eplb=enable_eplb,
+            expert_map=expert_map,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count,
         )
 
         return fused_experts(
@@ -1956,6 +1973,10 @@ def apply(
             quant_config=self.moe_quant_config,
         )
 
+    @property
+    def supports_eplb(self) -> bool:
+        return True
+
 
 class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
     """
diff --git a/vllm/model_executor/models/qwen3_vl_moe.py b/vllm/model_executor/models/qwen3_vl_moe.py
index 5c3205faf9c2..e2c129120b1a 100644
--- a/vllm/model_executor/models/qwen3_vl_moe.py
+++ b/vllm/model_executor/models/qwen3_vl_moe.py
@@ -15,7 +15,7 @@
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-# http://www.apache.org/licenses/LICENSE-2.0
+#     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@@ -29,7 +29,9 @@
 from itertools import islice
 
 import torch
-from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeConfig
+from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import (
+    Qwen3VLMoeConfig,
+)
 
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
@@ -44,7 +46,12 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors
 
-from .qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
+from .interfaces import MixtureOfExperts
+from .qwen3_moe import (
+    Qwen3MoeForCausalLM,
+    Qwen3MoeModel,
+    Qwen3MoeSparseMoeBlock,
+)
 from .qwen3_vl import (
     Qwen3_VisionTransformer,
     Qwen3VLDummyInputsBuilder,
@@ -344,12 +351,56 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         )
 
 
+class Qwen3VLMoeMixtureOfExperts(MixtureOfExperts):
+    def update_physical_experts_metadata(
+        self,
+        num_physical_experts: int,
+        num_local_physical_experts: int,
+    ) -> None:
+        assert self.num_local_physical_experts == num_local_physical_experts
+        self.num_physical_experts = num_physical_experts
+        self.num_local_physical_experts = num_local_physical_experts
+        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
+        for layer in self.language_model.model.layers:
+            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
+                moe = layer.mlp
+                moe.n_local_physical_experts = num_local_physical_experts
+                moe.n_physical_experts = num_physical_experts
+                moe.n_redundant_experts = self.num_redundant_experts
+                moe.experts.update_expert_map()
+
+    def set_moe_parameters(self):
+        self.expert_weights = []
+
+        self.moe_layers = []
+        example_moe = None
+        for layer in self.language_model.model.layers:
+            if hasattr(layer, "mlp") and isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
+                example_moe = layer.mlp
+                self.moe_layers.append(layer.mlp.experts)
+
+        if example_moe is None:
+            raise RuntimeError("No Qwen3Moe layer found in the language_model.")
+
+        # Set MoE hyperparameters
+        self.num_moe_layers = len(self.moe_layers)
+        self.num_expert_groups = 1
+        self.num_shared_experts = 0
+        self.num_logical_experts = example_moe.n_logical_experts
+        self.num_physical_experts = example_moe.n_physical_experts
+        self.num_local_physical_experts = example_moe.n_local_physical_experts
+        self.num_routed_experts = example_moe.n_routed_experts
+        self.num_redundant_experts = example_moe.n_redundant_experts
+
+
 @MULTIMODAL_REGISTRY.register_processor(
     Qwen3VLMultiModalProcessor,
     info=Qwen3VLMoeProcessingInfo,
     dummy_inputs=Qwen3VLDummyInputsBuilder,
 )
-class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
+class Qwen3VLMoeForConditionalGeneration(
+    Qwen3VLForConditionalGeneration, Qwen3VLMoeMixtureOfExperts
+):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -413,3 +464,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.deepstack_input_embeds = None
         self.visual_dim = config.vision_config.out_hidden_size
         self.multiscale_dim = self.visual_dim * self.deepstack_num_level
+
+        # Set MoE hyperparameters
+        self.set_moe_parameters()
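
Note on the first hunk above: the EPLB branch now validates its inputs
instead of raising NotImplementedError, and all three routing tensors must
be supplied together whenever enable_eplb is set. A minimal standalone
sketch of that check pattern follows (validate_eplb_args is a hypothetical
helper for illustration only, not part of vLLM):

    import torch

    def validate_eplb_args(
        enable_eplb: bool,
        expert_load_view: torch.Tensor | None,
        logical_to_physical_map: torch.Tensor | None,
        logical_replica_count: torch.Tensor | None,
    ) -> None:
        # Mirrors the checks added in CompressedTensorsWNA16MoEMethod.apply():
        # EPLB needs all three tensors, so fail fast on the first missing one.
        if not enable_eplb:
            return
        for name, value in (
            ("expert_load_view", expert_load_view),
            ("logical_to_physical_map", logical_to_physical_map),
            ("logical_replica_count", logical_replica_count),
        ):
            if value is None:
                raise ValueError(f"enable_eplb=True requires {name} != None")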