From 4318ff24854d352869df92e419e7694769619917 Mon Sep 17 00:00:00 2001 From: Felix Marty Date: Fri, 18 Apr 2025 10:18:51 +0200 Subject: [PATCH 1/6] MXFP4 wip wip & debug update cleanup use quark realquantizer for pack/quant/dequant comment on cudagraph issue; remove prints Keep only 1 place importing quark cudagraph issue resolved; dq weight at load time for efficiency Signed-off-by: Bowen Bao lint Signed-off-by: Bowen Bao turn on emulation based on platform Signed-off-by: Bowen Bao add fused moe support - ugly wip running version Add envar if dequant weight at load time Signed-off-by: Bowen Bao Mxfp4 memory leak fixes (#2) Signed-off-by: Felix Marty --- vllm/envs.py | 9 + .../layers/fused_moe/fused_moe.py | 24 ++- vllm/model_executor/layers/fused_moe/layer.py | 2 + .../layers/quantization/quark/quark.py | 60 +++++- .../layers/quantization/quark/quark_moe.py | 188 +++++++++++++++++- .../quantization/quark/schemes/__init__.py | 3 +- .../quark/schemes/quark_w4a4_mxfp4.py | 136 +++++++++++++ .../layers/quantization/utils/mxfp4_utils.py | 38 ++++ vllm/model_executor/model_loader/utils.py | 2 +- vllm/platforms/interface.py | 7 + vllm/platforms/rocm.py | 5 + vllm/worker/model_runner.py | 2 +- 12 files changed, 467 insertions(+), 9 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py create mode 100644 vllm/model_executor/layers/quantization/utils/mxfp4_utils.py diff --git a/vllm/envs.py b/vllm/envs.py index ea40bfff11b5..c8bb39ceb7b2 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -84,6 +84,7 @@ VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True + VLLM_QUARK_EMU_MEM_OPT: bool = False VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_DISABLE_COMPILE_CACHE: bool = False @@ -583,6 +584,14 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in ("true", "1")), + # If set, when running in Quark emulation mode, do not dequantize the + # weights at load time. Instead, dequantize weights on-the-fly during + # kernel execution. + # This allows running larger models at the cost of slower inference. + # This flag has no effect when not running in Quark emulation mode. 
+ "VLLM_QUARK_EMU_MEM_OPT": + lambda: bool(int(os.getenv("VLLM_QUARK_EMU_MEM_OPT", "0"))), + # Divisor for dynamic query scale factor calculation for FP8 KV Cache "Q_SCALE_CONSTANT": lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")), diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index a209715ede77..a888b5beff0d 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -20,9 +20,11 @@ per_token_group_quant_fp8) from vllm.model_executor.layers.quantization.utils.int8_utils import ( per_token_group_quant_int8, per_token_quant_int8) +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import per_token_group_quant_mxfp4 from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import OCP_MX_BLOCK_SIZE from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled logger = init_logger(__name__) @@ -973,6 +975,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -985,7 +988,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, block_shape: Optional[List[int]] = None) -> None: fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, activation, apply_router_weight_on_input, use_fp8_w8a8, - use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, + use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4, per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) @@ -1003,6 +1006,7 @@ def inplace_fused_experts_fake( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -1037,6 +1041,7 @@ def outplace_fused_experts( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -1050,7 +1055,7 @@ def outplace_fused_experts( return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, False, activation, apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, - use_int4_w4a16, per_channel_quant, + use_int4_w4a16, use_mxfp4_w4a4, per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) @@ -1067,6 +1072,7 @@ def outplace_fused_experts_fake( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -1120,6 +1126,7 @@ def fused_experts(hidden_states: torch.Tensor, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -1162,6 +1169,7 @@ def fused_experts(hidden_states: torch.Tensor, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + 
use_mxfp4_w4a4=use_mxfp4_w4a4, per_channel_quant=per_channel_quant, global_num_experts=global_num_experts, expert_map=expert_map, @@ -1183,6 +1191,7 @@ def moe_kernel_prepare_input( use_int8_w8a8: bool, use_int8_w8a16: bool, use_int4_w4a16: bool, + use_mxfp4_w4a4: bool, per_channel_quant: bool, block_shape: Optional[List[int]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -1220,6 +1229,11 @@ def moe_kernel_prepare_input( elif use_int8_w8a16 or use_int4_w4a16: assert B_scale is not None assert block_shape is None or block_shape[0] == 0 + elif use_mxfp4_w4a4: + # We assume B (the weight) to be fake quantized - so only handling the activation here. + assert block_shape is None + A, A_scale = per_token_group_quant_mxfp4(A, OCP_MX_BLOCK_SIZE) + else: assert A_scale is None assert B_scale is None @@ -1239,6 +1253,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -1350,13 +1365,14 @@ def fused_experts_impl(hidden_states: torch.Tensor, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, per_channel_quant=per_channel_quant, block_shape=block_shape) sorted_token_ids, expert_ids, num_tokens_post_padded = ( moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], global_num_experts, expert_map)) - + invoke_fused_moe_kernel(qcurr_hidden_states, w1, intermediate_cache1, @@ -1396,6 +1412,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, per_channel_quant=per_channel_quant, block_shape=block_shape) @@ -1443,6 +1460,7 @@ def fused_moe( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 3cdf3c97a7d3..0dcd147295e5 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -581,6 +581,7 @@ def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, # Index the loaded weight for tp sharding. 
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim + shard_size = expert_data.shape[shard_dim] // 2 loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size) @@ -592,6 +593,7 @@ def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, else: assert shard_id == "w3" expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) + expert_data.copy_(loaded_weight) def _load_w2(self, diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index da2312190084..e59d93852e06 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -5,6 +5,7 @@ import torch +from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) @@ -15,13 +16,15 @@ from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501 QuarkMoEMethod) from vllm.model_executor.layers.quantization.quark.schemes import ( - QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8) + QuarkScheme, QuarkW4A4MXFP4, QuarkW8A8Fp8, QuarkW8A8Int8) from vllm.model_executor.layers.quantization.quark.utils import ( deep_compare, should_ignore_layer) from vllm.platforms import current_platform __all__ = ["QuarkLinearMethod"] +logger = init_logger(__name__) + class QuarkConfig(QuantizationConfig): @@ -67,6 +70,9 @@ def get_quant_method(self, layer: torch.nn.Module, return QuarkLinearMethod(self) if isinstance(layer, Attention): return QuarkKVCacheMethod(self) + + # TODO: mixtral defined in mixtral_quant.py does not use FusedMoE, so probably + # `QuarkMoEMethod` was never actually used? if isinstance(layer, FusedMoE): return QuarkMoEMethod.get_moe_method(self, module=layer, @@ -170,7 +176,7 @@ def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]], is_static_weight = not weight_quant.get("is_dynamic") is_per_tensor_or_channel_weight = (weight_quant.get("qscheme") in ["per_tensor", "per_channel"]) - + if not (is_fp8_dtype and is_static_weight and is_per_tensor_or_channel_weight): return False @@ -205,6 +211,54 @@ def _is_static_tensor_w8a8(self, weight_quant: Optional[Dict[str, Any]], # Only symmetric weight quantization supported. return is_int8_dtype and is_tensor and is_weight_symmetric and is_static + def _is_mx_fp4(self, weight_quant: Optional[Dict[str, Any]], + input_quant: Optional[Dict[str, Any]]) -> bool: + # Confirm weights and input quantized. + if weight_quant is None or input_quant is None: + logger.debug("Quark model is not in MX-FP4 format: " + "weight_quant or input_quant not set") + return False + + # Input and weight dtype needs to be fp4. + if weight_quant.get("dtype") != "fp4" or input_quant.get( + "dtype") != "fp4": + logger.debug("Quark model is not in MX-FP4 format: dtype not fp4") + return False + + # Input and weight qscheme needs to be per group. + if weight_quant.get("qscheme") != "per_group" or input_quant.get( + "qscheme") != "per_group": + logger.debug("Quark model is not in MX-FP4 format: not per_group") + return False + + # Input and weight group size needs to be 32. + if weight_quant.get("group_size") != 32 or input_quant.get( + "group_size") != 32: + logger.debug( + "Quark model is not in MX-FP4 format: not group_size=32") + return False + + # Weights need to use static quantization. 
+ if weight_quant.get("is_dynamic") is True: + logger.debug( + "Quark model is not in MX-FP4 format: not weight static") + return False + + # Activations need to use dynamic quantization. + if input_quant.get("is_dynamic") is False: + logger.debug( + "Quark model is not in MX-FP4 format: not activation dynamic") + return False + + # Activations and weight scales need to be in e8m0 format. + if weight_quant.get("scale_format") != "e8m0" or input_quant.get( + "scale_format") != "e8m0": + logger.debug( + "Quark model is not in MX-FP4 format: not scale_format e8m0") + return False + + return True + def _find_matched_config(self, layer_name: str, module: torch.nn.Module) -> Dict[str, Any]: @@ -269,6 +323,8 @@ def _get_scheme_from_config(self, config: Dict[str, Any]) -> "QuarkScheme": return QuarkW8A8Int8(qscheme=weight_qscheme, is_static_input_scheme=True, input_symmetric=input_config.get("symmetric")) + elif self._is_mx_fp4(weight_config, input_config): + return QuarkW4A4MXFP4(weight_config, input_config) raise NotImplementedError("No quark compatible scheme was found. " f"Weight config: {weight_config}, " diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index d1146c0f039d..345e08271914 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -13,10 +13,11 @@ all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import OCP_MX_BLOCK_SIZE logger = init_logger(__name__) -__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod"] +__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkW4A4MXFp4MoEMethod"] class QuarkMoEMethod(FusedMoEMethodBase): @@ -39,6 +40,8 @@ def get_moe_method( if quant_config._is_fp8_w8a8(weight_config, input_config): return QuarkW8A8Fp8MoEMethod(weight_config, input_config) + elif quant_config._is_mx_fp4(weight_config, input_config): + return QuarkW4A4MXFp4MoEMethod(weight_config, input_config) else: raise RuntimeError("Unsupported FusedMoe scheme") @@ -234,3 +237,186 @@ def apply( w2_scale=layer.w2_weight_scale, a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale) + + +class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): + + def __init__(self, weight_config: Dict[str, Any], input_config: Dict[str, + Any]): + self.weight_quant = weight_config + self.input_quant = input_config + + weight_qscheme = self.weight_quant.get("qscheme") + input_qscheme = self.input_quant.get("qscheme") + if not (weight_qscheme == "per_group" + and input_qscheme == "per_group"): + raise ValueError( + "For MX(FP4) Fused MoE layers, only per-group scales " + "for weights and activations are supported. 
Found " + f"{weight_qscheme}, {input_qscheme}") # noqa E501 + + self.static_input_scales = not self.input_quant.get("is_dynamic") + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update({"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}) + + params_dtype = torch.uint8 + + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // 2, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + + print("set w13_weight", w13_weight.shape, w13_weight.dtype) + + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // 2, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + + print("set w2_weight", w2_weight.shape, w2_weight.dtype) + + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // OCP_MX_BLOCK_SIZE, + dtype=params_dtype, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + hidden_size, + intermediate_size_per_partition // OCP_MX_BLOCK_SIZE, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + print("set w2_weight_scale", w2_weight_scale.shape) + print("set w13_weight_scale", w13_weight_scale.shape) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + float_dtype = torch.get_default_dtype() + + try: + from quark.torch.export.nn.modules import realquantizer + from quark.torch.quantization.config.config import ( + QuantizationSpec) + except ImportError as err: + raise ImportError( + "The package `amd-quark` is required to use AMD Quark " + "MX-FP4 models. Please install it with `pip install " + "amd-quark`.") from err + + weight_quant_spec = QuantizationSpec.from_dict(self.weight_quant) + + # Unpack and dequantize the weights (the operators are in high-precision, with simulated quantization). 
+ w13_quantizer = realquantizer.get_real_quantizer( + qspec=weight_quant_spec, + quantizer=None, + real_quantized=True, + reorder=False, # TODO: load from config + float_dtype=float_dtype, + scale_shape=layer.w13_weight_scale.shape, + zero_point_shape=None, + ) + w13_quantizer.scale.data = layer.w13_weight_scale.data + + layer.w13_weight = torch.nn.Parameter( + w13_quantizer(layer.w13_weight.data).to(float_dtype), + requires_grad=False, + ) + layer.w13_weight_scale = None + + w2_quantizer = realquantizer.get_real_quantizer( + qspec=weight_quant_spec, + quantizer=None, + real_quantized=True, + reorder=False, # TODO: load from config + float_dtype=float_dtype, + scale_shape=layer.w2_weight_scale.shape, + zero_point_shape=None, + ) + w2_quantizer.scale.data = layer.w2_weight_scale.data + + layer.w2_weight = torch.nn.Parameter( + w2_quantizer(layer.w2_weight.data).to(float_dtype), + requires_grad=False, + ) + layer.w2_weight_scale = None + + # This call is necessary to release the scales memory. + torch.cuda.empty_cache() + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe import fused_experts + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_mxfp4_w4a4=True, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + w1_scale=None, + w2_scale=None, + a1_scale=None, + a2_scale=None, + block_shape=None) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/__init__.py b/vllm/model_executor/layers/quantization/quark/schemes/__init__.py index 9069b5a0d515..d7dac17574ff 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/__init__.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 from .quark_scheme import QuarkScheme +from .quark_w4a4_mxfp4 import QuarkW4A4MXFP4 from .quark_w8a8_fp8 import QuarkW8A8Fp8 from .quark_w8a8_int8 import QuarkW8A8Int8 -__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8"] +__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkW4A4MXFP4"] diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py new file mode 100644 index 000000000000..c7f12888b3bf --- /dev/null +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Dict, List, Optional + +import torch +import 
torch.nn.functional as F + +import vllm.envs as envs +from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme +from vllm.model_executor.parameter import (GroupQuantScaleParameter, + PackedvLLMParameter) +from vllm.platforms import current_platform + +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import OCP_MX_BLOCK_SIZE, per_token_group_quant_mxfp4 + +__all__ = ["QuarkW4A4MXFP4"] + + +class QuarkW4A4MXFP4(QuarkScheme): + + def __init__(self, weight_quant_spec: Dict[str, Any], + input_quant_spec: Dict[str, Any]): + self.out_dtype = torch.get_default_dtype() + self.qscheme = "per_group" + self.weight_quant_spec = weight_quant_spec + self.input_quant_spec = input_quant_spec + self.emulate = not current_platform.supports_mx() + + if self.emulate: + try: + from quark.torch.export.nn.modules import realquantizer + from quark.torch.quantization.config.config import ( + QuantizationSpec) + except ImportError as err: + raise ImportError( + "The package `amd-quark` is required to use AMD Quark " + "MX-FP4 models. Please install it with `pip install " + "amd-quark`.") from err + + @classmethod + def get_min_capability(cls) -> int: + # lovelace and up + return 89 + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.weight = torch.nn.Parameter(layer.weight.data, + requires_grad=False) + layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, + requires_grad=False) + + if self.emulate: + try: + from quark.torch.export.nn.modules import realquantizer + from quark.torch.quantization.config.config import ( + QuantizationSpec) + except ImportError as err: + raise ImportError( + "The package `amd-quark` is required to use AMD Quark " + "MX-FP4 models. Please install it with `pip install " + "amd-quark`.") from err + + weight_quant_spec = QuantizationSpec.from_dict( + self.weight_quant_spec) + + weight_quantizer = realquantizer.get_real_quantizer( + qspec=weight_quant_spec, + quantizer=None, + real_quantized=True, + reorder=False, # TODO: load from config + float_dtype=self.out_dtype, + scale_shape=layer.weight_scale.shape, + zero_point_shape=None, + ) + weight_quantizer.scale.data = layer.weight_scale.data + + if not envs.VLLM_QUARK_EMU_MEM_OPT: + layer.weight = torch.nn.Parameter( + weight_quantizer(layer.weight.data).to( + self.out_dtype), + requires_grad=False, + ) + layer.weight_scale = None + + # This call is necessary to release the scales memory. 
+ torch.cuda.empty_cache() + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + + # WEIGHT + weight = PackedvLLMParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // 2, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + packed_dim=1, + packed_factor=2, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // OCP_MX_BLOCK_SIZE, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + if self.emulate: + if envs.VLLM_QUARK_EMU_MEM_OPT: + dq_w = self.weight_quantizer(layer.weight).to(self.out_dtype) + else: + dq_w = layer.weight + qdq_x, _ = per_token_group_quant_mxfp4(x, 32) + return F.linear(qdq_x, dq_w, bias) + else: + raise NotImplementedError() diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py new file mode 100644 index 000000000000..d091b6f84f84 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -0,0 +1,38 @@ +import torch + +OCP_MX_BLOCK_SIZE = 32 + +def per_token_group_quant_mxfp4(x: torch.Tensor, block_k: int): + try: + from quark.torch.quantization.utils import even_round + from quark.torch.kernel import scaled_fake_quantize + from quark.torch.quantization.utils import reshape_to_blocks + except ImportError as e: + raise ImportError(f"The package `amd-quark` is required to use " + "MX-FP4 models. Please install it with `pip install " + "amd-quark`. Error: {e}") + + axis = -1 + block_x = reshape_to_blocks(x, block_k, axis) + amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True) + amax = amax.squeeze(-1) + + # TODO: there are other rounding strategies supported in quark and in the config.json that we do not check for here! + scale = even_round(amax, "fp4") + + # Apply dequantize(quantize(x)). + x = scaled_fake_quantize( + "fp4", + x, + scale.to(x.device), + None, + axis, + block_k, + -1., # TODO: useless, to make cleaner + 1., # TODO: useless, to make cleaner + 0, # TODO: useless, to make cleaner + "per_group", + 'None', # must be a string in quark hw_emulation_interface.py, why? + ) + + return x, scale \ No newline at end of file diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 0ca6b6fd88b6..2dbafe18a6a3 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -91,7 +91,7 @@ def get_model_architecture( # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. 
mixtral_supported = [ - "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin" + "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin", "quark" ] if (model_config.quantization is not None diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index c5555aba1a3e..3f7adc05a85c 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -339,6 +339,13 @@ def get_device_communicator_cls(cls) -> str: """ return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase" # noqa + @classmethod + def supports_mx(cls) -> bool: + """ + Returns whether the current platform supports MX types. + """ + return False + @classmethod def supports_fp8(cls) -> bool: """ diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index de097ab9af1b..6fb20fa01e47 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -314,6 +314,11 @@ def get_current_memory_usage(cls, def get_device_communicator_cls(cls) -> str: return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa + @classmethod + def supports_mx(cls) -> bool: + gcn_arch = torch.cuda.get_device_properties(0).gcnArchName + return any(gfx in gcn_arch for gfx in ["gfx95"]) + @classmethod def supports_fp8(cls) -> bool: gcn_arch = torch.cuda.get_device_properties(0).gcnArchName diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 73e0eff9a8b7..79dea3d4977c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1725,7 +1725,7 @@ def execute_model( ]) else: model_executable = self.model - + # Receive KV cache in distributed KV cache transfer setting # In disagg prefill setting, it will also recv hidden states and bypass # model forwarding From c439a139790cafaecb54e8dea123b78049f1442b Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Wed, 30 Apr 2025 20:30:24 +0000 Subject: [PATCH 2/6] Separate moe to another PR Signed-off-by: Bowen Bao --- .../layers/fused_moe/fused_moe.py | 24 +-- vllm/model_executor/layers/fused_moe/layer.py | 2 - .../layers/quantization/quark/quark_moe.py | 188 +----------------- vllm/worker/model_runner.py | 2 +- 4 files changed, 5 insertions(+), 211 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index a888b5beff0d..a209715ede77 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -20,11 +20,9 @@ per_token_group_quant_fp8) from vllm.model_executor.layers.quantization.utils.int8_utils import ( per_token_group_quant_int8, per_token_quant_int8) -from vllm.model_executor.layers.quantization.utils.mxfp4_utils import per_token_group_quant_mxfp4 from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op -from vllm.model_executor.layers.quantization.utils.mxfp4_utils import OCP_MX_BLOCK_SIZE from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled logger = init_logger(__name__) @@ -975,7 +973,6 @@ def inplace_fused_experts(hidden_states: torch.Tensor, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -988,7 +985,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, block_shape: Optional[List[int]] = None) -> None: fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, activation, apply_router_weight_on_input, 
use_fp8_w8a8, - use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4, + use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) @@ -1006,7 +1003,6 @@ def inplace_fused_experts_fake( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -1041,7 +1037,6 @@ def outplace_fused_experts( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -1055,7 +1050,7 @@ def outplace_fused_experts( return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, False, activation, apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, - use_int4_w4a16, use_mxfp4_w4a4, per_channel_quant, + use_int4_w4a16, per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) @@ -1072,7 +1067,6 @@ def outplace_fused_experts_fake( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -1126,7 +1120,6 @@ def fused_experts(hidden_states: torch.Tensor, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -1169,7 +1162,6 @@ def fused_experts(hidden_states: torch.Tensor, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, per_channel_quant=per_channel_quant, global_num_experts=global_num_experts, expert_map=expert_map, @@ -1191,7 +1183,6 @@ def moe_kernel_prepare_input( use_int8_w8a8: bool, use_int8_w8a16: bool, use_int4_w4a16: bool, - use_mxfp4_w4a4: bool, per_channel_quant: bool, block_shape: Optional[List[int]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -1229,11 +1220,6 @@ def moe_kernel_prepare_input( elif use_int8_w8a16 or use_int4_w4a16: assert B_scale is not None assert block_shape is None or block_shape[0] == 0 - elif use_mxfp4_w4a4: - # We assume B (the weight) to be fake quantized - so only handling the activation here. 
- assert block_shape is None - A, A_scale = per_token_group_quant_mxfp4(A, OCP_MX_BLOCK_SIZE) - else: assert A_scale is None assert B_scale is None @@ -1253,7 +1239,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -1365,14 +1350,13 @@ def fused_experts_impl(hidden_states: torch.Tensor, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, per_channel_quant=per_channel_quant, block_shape=block_shape) sorted_token_ids, expert_ids, num_tokens_post_padded = ( moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], global_num_experts, expert_map)) - + invoke_fused_moe_kernel(qcurr_hidden_states, w1, intermediate_cache1, @@ -1412,7 +1396,6 @@ def fused_experts_impl(hidden_states: torch.Tensor, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, - use_mxfp4_w4a4=use_mxfp4_w4a4, per_channel_quant=per_channel_quant, block_shape=block_shape) @@ -1460,7 +1443,6 @@ def fused_moe( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, - use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 0dcd147295e5..3cdf3c97a7d3 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -581,7 +581,6 @@ def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, # Index the loaded weight for tp sharding. 
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim - shard_size = expert_data.shape[shard_dim] // 2 loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size) @@ -593,7 +592,6 @@ def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, else: assert shard_id == "w3" expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) - expert_data.copy_(loaded_weight) def _load_w2(self, diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 345e08271914..d1146c0f039d 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -13,11 +13,10 @@ all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.model_executor.layers.quantization.utils.mxfp4_utils import OCP_MX_BLOCK_SIZE logger = init_logger(__name__) -__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkW4A4MXFp4MoEMethod"] +__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod"] class QuarkMoEMethod(FusedMoEMethodBase): @@ -40,8 +39,6 @@ def get_moe_method( if quant_config._is_fp8_w8a8(weight_config, input_config): return QuarkW8A8Fp8MoEMethod(weight_config, input_config) - elif quant_config._is_mx_fp4(weight_config, input_config): - return QuarkW4A4MXFp4MoEMethod(weight_config, input_config) else: raise RuntimeError("Unsupported FusedMoe scheme") @@ -237,186 +234,3 @@ def apply( w2_scale=layer.w2_weight_scale, a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale) - - -class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): - - def __init__(self, weight_config: Dict[str, Any], input_config: Dict[str, - Any]): - self.weight_quant = weight_config - self.input_quant = input_config - - weight_qscheme = self.weight_quant.get("qscheme") - input_qscheme = self.input_quant.get("qscheme") - if not (weight_qscheme == "per_group" - and input_qscheme == "per_group"): - raise ValueError( - "For MX(FP4) Fused MoE layers, only per-group scales " - "for weights and activations are supported. 
Found " - f"{weight_qscheme}, {input_qscheme}") # noqa E501 - - self.static_input_scales = not self.input_quant.get("is_dynamic") - - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): - - # Add the quantization method used (per tensor/grouped/channel) - # to ensure the weight scales are loaded in properly - extra_weight_attrs.update({"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}) - - params_dtype = torch.uint8 - - # WEIGHTS - w13_weight = torch.nn.Parameter(torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size // 2, - dtype=params_dtype), - requires_grad=False) - layer.register_parameter("w13_weight", w13_weight) - - print("set w13_weight", w13_weight.shape, w13_weight.dtype) - - set_weight_attrs(w13_weight, extra_weight_attrs) - - w2_weight = torch.nn.Parameter(torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition // 2, - dtype=params_dtype), - requires_grad=False) - layer.register_parameter("w2_weight", w2_weight) - - print("set w2_weight", w2_weight.shape, w2_weight.dtype) - - set_weight_attrs(w2_weight, extra_weight_attrs) - - # WEIGHT_SCALES - w13_weight_scale = torch.nn.Parameter( - torch.ones( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size // OCP_MX_BLOCK_SIZE, - dtype=params_dtype, - ), - requires_grad=False, - ) - w2_weight_scale = torch.nn.Parameter( - torch.ones( - num_experts, - hidden_size, - intermediate_size_per_partition // OCP_MX_BLOCK_SIZE, - dtype=params_dtype, - ), - requires_grad=False, - ) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - - print("set w2_weight_scale", w2_weight_scale.shape) - print("set w13_weight_scale", w13_weight_scale.shape) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - float_dtype = torch.get_default_dtype() - - try: - from quark.torch.export.nn.modules import realquantizer - from quark.torch.quantization.config.config import ( - QuantizationSpec) - except ImportError as err: - raise ImportError( - "The package `amd-quark` is required to use AMD Quark " - "MX-FP4 models. Please install it with `pip install " - "amd-quark`.") from err - - weight_quant_spec = QuantizationSpec.from_dict(self.weight_quant) - - # Unpack and dequantize the weights (the operators are in high-precision, with simulated quantization). 
- w13_quantizer = realquantizer.get_real_quantizer( - qspec=weight_quant_spec, - quantizer=None, - real_quantized=True, - reorder=False, # TODO: load from config - float_dtype=float_dtype, - scale_shape=layer.w13_weight_scale.shape, - zero_point_shape=None, - ) - w13_quantizer.scale.data = layer.w13_weight_scale.data - - layer.w13_weight = torch.nn.Parameter( - w13_quantizer(layer.w13_weight.data).to(float_dtype), - requires_grad=False, - ) - layer.w13_weight_scale = None - - w2_quantizer = realquantizer.get_real_quantizer( - qspec=weight_quant_spec, - quantizer=None, - real_quantized=True, - reorder=False, # TODO: load from config - float_dtype=float_dtype, - scale_shape=layer.w2_weight_scale.shape, - zero_point_shape=None, - ) - w2_quantizer.scale.data = layer.w2_weight_scale.data - - layer.w2_weight = torch.nn.Parameter( - w2_quantizer(layer.w2_weight.data).to(float_dtype), - requires_grad=False, - ) - layer.w2_weight_scale = None - - # This call is necessary to release the scales memory. - torch.cuda.empty_cache() - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - ) -> torch.Tensor: - from vllm.model_executor.layers.fused_moe import fused_experts - - topk_weights, topk_ids = FusedMoE.select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) - - return fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - use_mxfp4_w4a4=True, - global_num_experts=global_num_experts, - apply_router_weight_on_input=apply_router_weight_on_input, - expert_map=expert_map, - w1_scale=None, - w2_scale=None, - a1_scale=None, - a2_scale=None, - block_shape=None) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 79dea3d4977c..73e0eff9a8b7 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1725,7 +1725,7 @@ def execute_model( ]) else: model_executable = self.model - + # Receive KV cache in distributed KV cache transfer setting # In disagg prefill setting, it will also recv hidden states and bypass # model forwarding From edc4980b9b83456dd0c8d8c3e2230a71e852f7a3 Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Wed, 30 Apr 2025 20:46:41 +0000 Subject: [PATCH 3/6] Fix VLLM_QUARK_EMU_MEM_OPT codepath Signed-off-by: Bowen Bao --- .../quark/schemes/quark_w4a4_mxfp4.py | 22 +++++-------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py index c7f12888b3bf..77d34c82bcf9 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py @@ -7,12 +7,12 @@ import vllm.envs as envs 
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( + OCP_MX_BLOCK_SIZE, per_token_group_quant_mxfp4) from vllm.model_executor.parameter import (GroupQuantScaleParameter, PackedvLLMParameter) from vllm.platforms import current_platform -from vllm.model_executor.layers.quantization.utils.mxfp4_utils import OCP_MX_BLOCK_SIZE, per_token_group_quant_mxfp4 - __all__ = ["QuarkW4A4MXFP4"] @@ -26,17 +26,6 @@ def __init__(self, weight_quant_spec: Dict[str, Any], self.input_quant_spec = input_quant_spec self.emulate = not current_platform.supports_mx() - if self.emulate: - try: - from quark.torch.export.nn.modules import realquantizer - from quark.torch.quantization.config.config import ( - QuantizationSpec) - except ImportError as err: - raise ImportError( - "The package `amd-quark` is required to use AMD Quark " - "MX-FP4 models. Please install it with `pip install " - "amd-quark`.") from err - @classmethod def get_min_capability(cls) -> int: # lovelace and up @@ -75,12 +64,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if not envs.VLLM_QUARK_EMU_MEM_OPT: layer.weight = torch.nn.Parameter( - weight_quantizer(layer.weight.data).to( - self.out_dtype), + weight_quantizer(layer.weight.data).to(self.out_dtype), requires_grad=False, ) + else: + self.weight_quantizer = weight_quantizer layer.weight_scale = None - + # This call is necessary to release the scales memory. torch.cuda.empty_cache() From 2f72aa956f941297a768ad61aa6a33b52e8c1d2e Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Wed, 30 Apr 2025 20:51:48 +0000 Subject: [PATCH 4/6] lint Signed-off-by: Bowen Bao --- .../layers/quantization/quark/quark.py | 4 +--- .../layers/quantization/utils/mxfp4_utils.py | 19 +++++++++++-------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index e59d93852e06..66e677f56ffd 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -71,8 +71,6 @@ def get_quant_method(self, layer: torch.nn.Module, if isinstance(layer, Attention): return QuarkKVCacheMethod(self) - # TODO: mixtral defined in mixtral_quant.py does not use FusedMoE, so probably - # `QuarkMoEMethod` was never actually used? 
if isinstance(layer, FusedMoE): return QuarkMoEMethod.get_moe_method(self, module=layer, @@ -176,7 +174,7 @@ def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]], is_static_weight = not weight_quant.get("is_dynamic") is_per_tensor_or_channel_weight = (weight_quant.get("qscheme") in ["per_tensor", "per_channel"]) - + if not (is_fp8_dtype and is_static_weight and is_per_tensor_or_channel_weight): return False diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index d091b6f84f84..6be14802a93c 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -1,23 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 import torch OCP_MX_BLOCK_SIZE = 32 + def per_token_group_quant_mxfp4(x: torch.Tensor, block_k: int): try: - from quark.torch.quantization.utils import even_round from quark.torch.kernel import scaled_fake_quantize - from quark.torch.quantization.utils import reshape_to_blocks - except ImportError as e: - raise ImportError(f"The package `amd-quark` is required to use " - "MX-FP4 models. Please install it with `pip install " - "amd-quark`. Error: {e}") + from quark.torch.quantization.utils import (even_round, + reshape_to_blocks) + except ImportError as err: + raise ImportError("The package `amd-quark` is required to use " + "MX-FP4 models. Please install it with `pip install " + "amd-quark`.") from err axis = -1 block_x = reshape_to_blocks(x, block_k, axis) amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True) amax = amax.squeeze(-1) - # TODO: there are other rounding strategies supported in quark and in the config.json that we do not check for here! + # TODO: there are other rounding strategies supported in quark and in the + # config.json that we do not check for here! scale = even_round(amax, "fp4") # Apply dequantize(quantize(x)). @@ -35,4 +38,4 @@ def per_token_group_quant_mxfp4(x: torch.Tensor, block_k: int): 'None', # must be a string in quark hw_emulation_interface.py, why? 
) - return x, scale \ No newline at end of file + return x, scale From e1a9b91a50a83c82d3b9738e4b1fcf869d2f9d17 Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Wed, 30 Apr 2025 20:54:28 +0000 Subject: [PATCH 5/6] Relax device requirement due to emulation Signed-off-by: Bowen Bao --- .../layers/quantization/quark/schemes/quark_w4a4_mxfp4.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py index 77d34c82bcf9..e27f0eaa0d56 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py @@ -28,8 +28,7 @@ def __init__(self, weight_quant_spec: Dict[str, Any], @classmethod def get_min_capability(cls) -> int: - # lovelace and up - return 89 + return 70 def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.weight = torch.nn.Parameter(layer.weight.data, From 108a802e0936f2b6226b2cebae3037c039d5d305 Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Thu, 1 May 2025 16:45:40 +0000 Subject: [PATCH 6/6] Update to comments Signed-off-by: Bowen Bao Add test Signed-off-by: Bowen Bao revert rope local fix Signed-off-by: Bowen Bao remove print Signed-off-by: Bowen Bao rename scale calculation mode Signed-off-by: Bowen Bao --- tests/models/quantization/test_mxfp4.py | 40 +++++++++++++++++++ .../quark/schemes/quark_w4a4_mxfp4.py | 4 +- .../layers/quantization/utils/mxfp4_utils.py | 28 +++++++------ 3 files changed, 58 insertions(+), 14 deletions(-) create mode 100644 tests/models/quantization/test_mxfp4.py diff --git a/tests/models/quantization/test_mxfp4.py b/tests/models/quantization/test_mxfp4.py new file mode 100644 index 000000000000..9a060829525e --- /dev/null +++ b/tests/models/quantization/test_mxfp4.py @@ -0,0 +1,40 @@ +# SPDX-License-Identifier: Apache-2.0 +# flake8: noqa +"""Tests Quark mxfp4 models against ground truth generation +""" +import pytest + +from vllm import LLM, SamplingParams + +MODELS = ["amd/Llama-2-7b-chat-hf-wmxfp4-amxfp4-kvfp8-scale-uint8"] + +EXPECTED_STRS_MAP = { + "amd/Llama-2-7b-chat-hf-wmxfp4-amxfp4-kvfp8-scale-uint8": [ + '\n### Key Features\n\n* **High-throughput Inference**: vLL', + '\nArtificial intelligence (AI) has evolved significantly since its inception in the 1', + 'Artificial intelligence (AI) and human intelligence (HI) are two distinct concepts that have been', + 'A neural network is a machine learning model inspired by the structure of the human brain. 
It consists of', + '\nTitle: The Dreaming Robot\n\nAs the sun set on the bustling metropol', + '\nThe COVID-19 pandemic has had a profound impact on global economic structures and business', + 'The Mona Lisa painting, created by Leonardo da Vinci in the early 16th', + " everybody knows this proverbial saying, but did you know that it's not entirely accurate?", + ] +} + + +@pytest.mark.skip(reason="Model to be released in the future") +@pytest.mark.quant_model +@pytest.mark.parametrize("model_name", MODELS) +def test_models(example_prompts, model_name) -> None: + sampling_params = SamplingParams(max_tokens=20, temperature=0) + llm = LLM( + model=model_name, + kv_cache_dtype="fp8", + quantization="quark", + ) + outputs = llm.generate(example_prompts, sampling_params) + for i, output in enumerate(outputs): + output_str = output.outputs[0].text + expected_str = EXPECTED_STRS_MAP[model_name][i] + assert expected_str == output_str, ( + f"Expected: {expected_str!r}\nvLLM: {output_str!r}") diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py index e27f0eaa0d56..9da52a732fc4 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py @@ -54,7 +54,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: qspec=weight_quant_spec, quantizer=None, real_quantized=True, - reorder=False, # TODO: load from config + reorder=False, float_dtype=self.out_dtype, scale_shape=layer.weight_scale.shape, zero_point_shape=None, @@ -119,7 +119,7 @@ def apply_weights(self, dq_w = self.weight_quantizer(layer.weight).to(self.out_dtype) else: dq_w = layer.weight - qdq_x, _ = per_token_group_quant_mxfp4(x, 32) + qdq_x, _ = per_token_group_quant_mxfp4(x, OCP_MX_BLOCK_SIZE) return F.linear(qdq_x, dq_w, bias) else: raise NotImplementedError() diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index 6be14802a93c..6312c3934fd4 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -1,12 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Tuple + import torch OCP_MX_BLOCK_SIZE = 32 -def per_token_group_quant_mxfp4(x: torch.Tensor, block_k: int): +def per_token_group_quant_mxfp4(x: torch.Tensor, + block_k: int, + scale_calculation_mode: str = "even" + ) -> Tuple[torch.Tensor, torch.Tensor]: try: - from quark.torch.kernel import scaled_fake_quantize + from quark.torch.kernel.hw_emulation.hw_emulation_interface import ( + fake_quantize_fp4_fp6_per_group_with_scale) from quark.torch.quantization.utils import (even_round, reshape_to_blocks) except ImportError as err: @@ -21,21 +27,19 @@ def per_token_group_quant_mxfp4(x: torch.Tensor, block_k: int): # TODO: there are other rounding strategies supported in quark and in the # config.json that we do not check for here! + if scale_calculation_mode != "even": + raise NotImplementedError( + f"Scale calculation mode {scale_calculation_mode} is not yet " + "supported in MX-FP4 quantization") scale = even_round(amax, "fp4") # Apply dequantize(quantize(x)). 
- x = scaled_fake_quantize( - "fp4", + x = fake_quantize_fp4_fp6_per_group_with_scale( x, scale.to(x.device), - None, - axis, - block_k, - -1., # TODO: useless, to make cleaner - 1., # TODO: useless, to make cleaner - 0, # TODO: useless, to make cleaner - "per_group", - 'None', # must be a string in quark hw_emulation_interface.py, why? + axis=axis, + group_size=block_k, + quant_dtype="fp4", ) return x, scale