diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 04d8e91b0d25..0dc6e46c15be 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -49,6 +49,9 @@
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+    is_flashinfer_supporting_global_sf,
+)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
@@ -1289,6 +1292,7 @@ def __init__(
             "intermediate_size_per_partition": self.intermediate_size_per_partition,
             "params_dtype": params_dtype,
             "weight_loader": self.weight_loader,
+            "global_num_experts": self.global_num_experts,
         }
         # need full intermediate size pre-sharding for WNA16 act order
         if self.quant_method.__class__.__name__ in (
@@ -1632,13 +1636,25 @@ def weight_loader(
             param.data[:, :dim1, :dim2].copy_(loaded_weight)
             return True if return_success else None
 
-        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
-        if expert_id == -1:
+        quant_method_name = self.quant_method.__class__.__name__
+        global_expert_id = expert_id
+        expert_id = self._map_global_expert_id_to_local_expert_id(global_expert_id)
+
+        allow_flashinfer = getattr(self.quant_method, "allow_flashinfer", False)
+        moe_backend = getattr(self.quant_method, "flashinfer_moe_backend", None)
+
+        use_global_sf = (
+            allow_flashinfer
+            and is_flashinfer_supporting_global_sf(moe_backend)
+            and "input_scale" in weight_name
+            and quant_method_name == "ModelOptNvFp4FusedMoE"
+        )
+
+        if expert_id == -1 and not use_global_sf:
             # Failed to load this param since it's not local to this rank
             return False if return_success else None
         # Hereafter, `expert_id` is local physical id
 
-        quant_method_name = self.quant_method.__class__.__name__
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality
@@ -1723,7 +1739,9 @@ def weight_loader(
                 )
 
             self._load_single_value(
-                param=param, loaded_weight=loaded_weight, expert_id=expert_id
+                param=param,
+                loaded_weight=loaded_weight,
+                expert_id=global_expert_id if use_global_sf else expert_id,
             )
             return True if return_success else None
 
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 9d496f72eb3f..3eeb42d22ae0 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -49,6 +49,7 @@
     build_flashinfer_fp8_cutlass_moe_prepare_finalize,
     flashinfer_cutlass_moe_fp8,
     get_flashinfer_moe_backend,
+    is_flashinfer_supporting_global_sf,
     register_moe_scaling_factors,
     rotate_flashinfer_fp8_moe_weights,
     select_cutlass_fp8_gemm_impl,
@@ -1217,6 +1218,7 @@ def create_weights(
         weight_dtype = torch.uint8
         weight_scale_dtype = torch.float8_e4m3fn
         weight_loader = extra_weight_attrs.get("weight_loader")
+        global_num_experts = extra_weight_attrs.get("global_num_experts")
         # GEMM 1
         w13_weight = ModelWeightParameter(
             data=torch.empty(
@@ -1295,14 +1297,19 @@ def create_weights(
             {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
         )
 
+        use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
+            self.flashinfer_moe_backend
+        )
+        global_scale_num_experts = global_num_experts if use_global_sf else num_experts
+
         w13_input_scale = PerTensorScaleParameter(
-            data=torch.empty(num_experts, 2, dtype=torch.float32),
+            data=torch.empty(global_scale_num_experts, 2, dtype=torch.float32),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w13_input_scale", w13_input_scale)
 
         w2_input_scale = PerTensorScaleParameter(
-            data=torch.empty(num_experts, dtype=torch.float32),
+            data=torch.empty(global_scale_num_experts, dtype=torch.float32),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w2_input_scale", w2_input_scale)
@@ -1457,7 +1464,17 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
 
         # Common processing for input scales and alphas
-        w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
+        use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf(
+            self.flashinfer_moe_backend
+        )
+        if use_global_sf:
+            # For backends provided by Flashinfer, the input global scales are
+            # shared across all experts.
+            w13_input_scale = (
+                layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts)
+            )
+        else:
+            w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
         layer.g1_alphas = Parameter(
             (w13_input_scale * w13_weight_scale_2).to(torch.float32),
             requires_grad=False,
@@ -1469,14 +1486,22 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         )
 
         # GEMM 2 processing
+        if use_global_sf:
+            # For backends provided by Flashinfer, the input global scales are
+            # shared across all experts.
+            w2_input_scale = (
+                layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts)
+            )
+        else:
+            w2_input_scale = layer.w2_input_scale
         layer.g2_alphas = Parameter(
-            (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
+            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
             requires_grad=False,
         )
 
         # This is for quantization, so we need to invert it.
         layer.w2_input_scale_quant = Parameter(
-            (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False
+            (1 / w2_input_scale).to(torch.float32), requires_grad=False
         )
 
         # TensorRT-LLM specific processing
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
index 8fce7235bdde..50ea049c3d5a 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
@@ -263,3 +263,9 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
         f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
         f" expected one of {allowed_backends}"
     )
+
+
+def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) -> bool:
+    # TODO(shuw@nvidia): Update when new backends are added.
+    backends_supporting_global_sf = (FlashinferMoeBackend.CUTLASS,)
+    return backend in backends_supporting_global_sf
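For context, a minimal standalone sketch (not vLLM code) of the scale handling this diff introduces: under expert parallelism a rank owns only a subset of experts, but the FlashInfer CUTLASS backend expects a single input global scale shared by all experts, so the per-expert input scales are reduced with `max()` over the global expert dimension and broadcast back to the local expert count. The tensor names and sizes below are illustrative assumptions, not the actual vLLM parameters.

```python
import torch

# Illustrative sizes, not taken from the diff.
global_num_experts = 8  # experts across all expert-parallel ranks
local_num_experts = 2   # experts owned by this rank

# One input scale per *global* expert, as the enlarged w2_input_scale
# parameter above would hold after loading.
w2_input_scale = torch.rand(global_num_experts, dtype=torch.float32)

# use_global_sf=True path: collapse to one shared scale, then expand to the
# local expert count (mirrors the .max().expand(...) calls in modelopt.py).
shared_scale = w2_input_scale.max().to(torch.float32).expand(local_num_experts)

# use_global_sf=False path: keep the per-(local-)expert scales unchanged.
per_expert_scale = w2_input_scale[:local_num_experts]

assert shared_scale.shape == (local_num_experts,)
# The shared scale is the max, so it upper-bounds every per-expert scale.
assert bool((shared_scale >= per_expert_scale).all())
```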