# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Dict, List, Optional

import torch
import torch.nn.functional as F

import vllm.envs as envs
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
    OCP_MX_BLOCK_SIZE, per_token_group_quant_mxfp4)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                           PackedvLLMParameter)
from vllm.platforms import current_platform

__all__ = ["QuarkW4A4MXFP4"]


class QuarkW4A4MXFP4(QuarkScheme):
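    """W4A4 OCP MX-FP4 scheme from AMD Quark: weights and activations are
    both quantized to MX-FP4, i.e. FP4 (E2M1) values in groups of
    OCP_MX_BLOCK_SIZE elements sharing one scale. On platforms without
    native MX support, W4A4 is emulated by dequantizing the weights and
    fake-quantizing the activations before a high-precision GEMM."""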

    def __init__(self, weight_quant_spec: Dict[str, Any],
                 input_quant_spec: Dict[str, Any]):
        self.out_dtype = torch.get_default_dtype()
        self.qscheme = "per_group"
        self.weight_quant_spec = weight_quant_spec
        self.input_quant_spec = input_quant_spec
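        # Fall back to software emulation (dequantized weights plus a
        # high-precision GEMM) when the platform has no native MX support.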
        self.emulate = not current_platform.supports_mx()

    @classmethod
    def get_min_capability(cls) -> int:
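        # 70 == compute capability 7.0 (Volta), the minimum for the emulated
        # path; native MX support is detected separately via
        # current_platform.supports_mx().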
        return 70

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.weight = torch.nn.Parameter(layer.weight.data,
                                          requires_grad=False)
        layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                requires_grad=False)

        if self.emulate:
            try:
                from quark.torch.export.nn.modules import realquantizer
                from quark.torch.quantization.config.config import (
                    QuantizationSpec)
            except ImportError as err:
                raise ImportError(
                    "The package `amd-quark` is required to use AMD Quark "
                    "MX-FP4 models. Please install it with `pip install "
                    "amd-quark`.") from err

            weight_quant_spec = QuantizationSpec.from_dict(
                self.weight_quant_spec)

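            # Quark's "real quantizer" wraps the dequantization logic for the
            # packed MX-FP4 weights. It is applied eagerly below, or kept
            # around for lazy dequantization at apply time when
            # VLLM_QUARK_EMU_MEM_OPT is set.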
            weight_quantizer = realquantizer.get_real_quantizer(
                qspec=weight_quant_spec,
                quantizer=None,
                real_quantized=True,
                reorder=False,
                float_dtype=self.out_dtype,
                scale_shape=layer.weight_scale.shape,
                zero_point_shape=None,
            )
            weight_quantizer.scale.data = layer.weight_scale.data

            if not envs.VLLM_QUARK_EMU_MEM_OPT:
                layer.weight = torch.nn.Parameter(
                    weight_quantizer(layer.weight.data).to(self.out_dtype),
                    requires_grad=False,
                )
            else:
                self.weight_quantizer = weight_quantizer
            layer.weight_scale = None

            # This call is necessary to release the scales memory.
            torch.cuda.empty_cache()

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes

        # WEIGHT
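        # Two FP4 (E2M1) values are packed into each uint8 byte, hence the
        # halved input dimension and packed_factor=2.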
        weight = PackedvLLMParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            packed_dim=1,
            packed_factor=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
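        # One shared scale per OCP_MX_BLOCK_SIZE-element group along the
        # input dimension, stored as a uint8 (E8M0) exponent.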
        weight_scale = GroupQuantScaleParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // OCP_MX_BLOCK_SIZE,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        if self.emulate:
            if envs.VLLM_QUARK_EMU_MEM_OPT:
                dq_w = self.weight_quantizer(layer.weight).to(self.out_dtype)
            else:
                dq_w = layer.weight
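            # Fake-quantize activations to MX-FP4 (quantize-dequantize) so
            # the high-precision GEMM reproduces W4A4 numerics.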
            qdq_x, _ = per_token_group_quant_mxfp4(x, OCP_MX_BLOCK_SIZE)
            return F.linear(qdq_x, dq_w, bias)
        else:
            raise NotImplementedError()
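

# A minimal usage sketch of the emulated path (hypothetical sizes; the bare
# module and `weight_loader` stand in for vLLM's quantized linear machinery
# and are not part of this file):
#
#     scheme = QuarkW4A4MXFP4(weight_quant_spec, input_quant_spec)
#     layer = torch.nn.Module()
#     scheme.create_weights(layer,
#                           output_partition_sizes=[4096],
#                           input_size_per_partition=4096,
#                           params_dtype=torch.bfloat16,
#                           weight_loader=weight_loader)
#     # ... checkpoint shards are loaded into layer.weight (packed uint8)
#     # and layer.weight_scale (uint8 exponents) ...
#     scheme.process_weights_after_loading(layer)
#     y = scheme.apply_weights(layer, x, bias=None)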