diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index b72f51aa52bf..88750adff9b2 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -884,32 +884,6 @@ def make_expert_params_mapping(
             ]
         ]
 
-    def _load_fp8_scale(self, param: torch.nn.Parameter,
-                        loaded_weight: torch.Tensor, weight_name: str,
-                        shard_id: str, expert_id: int) -> None:
-        param_data = param.data
-
-        # Input scales can be loaded directly and should be equal.
-        if "input_scale" in weight_name:
-            if param_data[expert_id] != 1 and (param_data[expert_id] -
-                                               loaded_weight).abs() > 1e-5:
-                raise ValueError(
-                    "input_scales of w1 and w3 of a layer "
-                    f"must be equal. But got {param_data[expert_id]} "
-                    f"vs. {loaded_weight}")
-            param_data[expert_id] = loaded_weight
-        # Weight scales
-        elif "weight_scale" in weight_name:
-            # If we are in merged column case (gate_up_proj)
-            if shard_id in ("w1", "w3"):
-                # We have to keep the weight scales of w1 and w3 because
-                # we need to re-quantize w1/w3 weights after weight loading.
-                idx = 0 if shard_id == "w1" else 1
-                param_data[expert_id][idx] = loaded_weight
-            # If we are in the row parallel case (down_proj)
-            else:
-                param_data[expert_id] = loaded_weight
-
     def extra_repr(self) -> str:
 
         s = (
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 2e14845ff2d6..bf32bee89e89 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -268,14 +268,23 @@ def __init__(
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
             "input_activations")
 
-        if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR
-                and self.input_quant.strategy == QuantizationStrategy.TENSOR):
+        per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
+                      and self.input_quant.strategy
+                      == QuantizationStrategy.TENSOR)
+        per_channel = (
+            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+            and self.input_quant.strategy == QuantizationStrategy.TOKEN)
+        if not (per_tensor or per_channel):
             raise ValueError(
-                "For FP8 Fused MoE layers, only per-tensor scales "
-                "for weights and activations are supported. Found "
+                "For FP8 Fused MoE layers, we require per tensor "
+                "or channelwise, dynamic per token quantization. Found "
                 f"{self.weight_quant}, {self.input_quant}")
 
         self.static_input_scales = not self.input_quant.dynamic
+        if self.static_input_scales and per_channel:
+            raise ValueError(
+                "For FP8 Fused MoE layer, we require either per tensor or "
+                "channelwise, dynamic per token quantization.")
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
@@ -303,24 +312,40 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         # WEIGHT_SCALES
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                         2,
-                                                         dtype=torch.float32),
-                                              requires_grad=False)
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
-
-        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                        dtype=torch.float32),
-                                             requires_grad=False)
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        # Add the quantization method used (per tensor/grouped/channel)
-        # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
-        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They are combined to a single scale after weight loading.
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, 2, dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-TENSOR quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                1,
+                dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, hidden_size, 1, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
         # INPUT_SCALES
         if self.static_input_scales:
@@ -362,6 +387,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # Fp8 moe kernels require a single activation scale.
         # We take the max of all the scales in case they differ.
         if self.static_input_scales:
+            assert self.input_quant.strategy == QuantizationStrategy.TENSOR
            if (layer.w13_input_scale is None or layer.w2_input_scale is None):
                 raise ValueError(
                     "QuantConfig has static quantization, but found "
@@ -377,24 +403,25 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             layer.w2_input_scale = torch.nn.Parameter(
                 layer.w2_input_scale.max(), requires_grad=False)
 
-        # Fp8 moe kernel needs single weight scale for w13 per expert.
-        # We take the max then dequant and requant each expert.
-        assert layer.w13_weight_scale is not None
-        shard_size = layer.intermediate_size_per_partition
-        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-        for expert_id in range(layer.local_num_experts):
-            start = 0
-            for shard_id in range(2):
-                dq_weight = per_tensor_dequantize(
-                    layer.w13_weight[expert_id][start:start + shard_size, :],
-                    layer.w13_weight_scale[expert_id][shard_id])
-                layer.w13_weight[expert_id][
-                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
-                        dq_weight, max_w13_scales[expert_id])
-                start += shard_size
-
-        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
-                                                    requires_grad=False)
+        # For Per-TENSOR case, Fp8 moe kernel needs single weight scale
+        # for w13 per expert. Use max then dequant and requant each expert.
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            assert layer.w13_weight_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.local_num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start:start +
+                                                    shard_size, :],
+                        layer.w13_weight_scale[expert_id][shard_id])
+                    layer.w13_weight[expert_id][
+                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
+                            dq_weight, max_w13_scales[expert_id])
+                    start += shard_size
+            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
+                                                        requires_grad=False)
 
     def apply(
         self,
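Note (illustration, not part of the diff): a minimal sketch of the scale tensors the new CHANNEL branch in create_weights allocates, assuming hypothetical sizes num_experts=8, intermediate_size_per_partition=1408, hidden_size=2048. Per-channel FP8 keeps one float32 scale per output row of each expert weight, with a trailing singleton dim so it broadcasts against the [num_experts, out_features, in_features] weight tensors; activations use dynamic per-token scales, so no static input scale is registered.

    import torch

    # Hypothetical sizes for illustration only.
    num_experts = 8
    intermediate_size_per_partition = 1408
    hidden_size = 2048

    # Shapes match the CHANNEL branch above: one fp32 scale per output
    # channel, with a trailing singleton dim that broadcasts against the
    # fp8 weights of shape [num_experts, out_features, in_features].
    w13_weight_scale = torch.ones(num_experts,
                                  2 * intermediate_size_per_partition,
                                  1,
                                  dtype=torch.float32)
    w2_weight_scale = torch.ones(num_experts, hidden_size, 1,
                                 dtype=torch.float32)

    assert w13_weight_scale.shape == (8, 2816, 1)
    assert w2_weight_scale.shape == (8, 2048, 1)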