@@ -7,12 +7,12 @@
 
 import vllm.envs as envs
 from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
-from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
-    OCP_MX_BLOCK_SIZE)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            PackedvLLMParameter)
 from vllm.platforms import current_platform
 
+from vllm.model_executor.layers.quantization.utils.mxfp4_utils import OCP_MX_BLOCK_SIZE, per_token_group_quant_mxfp4
+
 __all__ = ["QuarkW4A4MXFP4"]
 
 
@@ -37,16 +37,6 @@ def __init__(self, weight_quant_spec: Dict[str, Any],
                 "MX-FP4 models. Please install it with `pip install "
                 "amd-quark`.") from err
 
-        input_quant_spec = QuantizationSpec.from_dict(
-            self.input_quant_spec)
-
-        self.input_quantizer = realquantizer.get_real_quantizer(
-            qspec=input_quant_spec,
-            quantizer=None,
-            real_quantized=False,
-            float_dtype=self.out_dtype,
-        )
-
     @classmethod
     def get_min_capability(cls) -> int:
         # lovelace and up
@@ -72,7 +62,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         weight_quant_spec = QuantizationSpec.from_dict(
             self.weight_quant_spec)
 
-        self.weight_quantizer = realquantizer.get_real_quantizer(
+        weight_quantizer = realquantizer.get_real_quantizer(
             qspec=weight_quant_spec,
             quantizer=None,
             real_quantized=True,
@@ -81,14 +71,18 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             scale_shape=layer.weight_scale.shape,
             zero_point_shape=None,
         )
-        self.weight_quantizer.scale.data = layer.weight_scale.data
+        weight_quantizer.scale.data = layer.weight_scale.data
 
         if not envs.VLLM_QUARK_EMU_MEM_OPT:
             layer.weight = torch.nn.Parameter(
-                self.weight_quantizer(layer.weight.data).to(
+                weight_quantizer(layer.weight.data).to(
                     self.out_dtype),
                 requires_grad=False,
             )
+            layer.weight_scale = None
+
+            # This call is necessary to release the scales memory.
+            torch.cuda.empty_cache()
 
     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
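In the non-mem-opt path above, the dequantized weight is materialized once in `out_dtype`, so the separate `weight_scale` parameter can be dropped; because PyTorch's caching allocator keeps freed blocks reserved, the `torch.cuda.empty_cache()` call is what actually returns that memory to the device. A minimal sketch of the same pattern, where `dequantize_and_release` and the `quantizer` callable are hypothetical stand-ins for the Quark real quantizer used above:

```python
import torch

def dequantize_and_release(layer: torch.nn.Module, quantizer) -> None:
    # Materialize the dequantized weight once; `quantizer` is a
    # hypothetical stand-in for Quark's real quantizer.
    layer.weight = torch.nn.Parameter(quantizer(layer.weight.data),
                                      requires_grad=False)
    # Drop the now-redundant packed scales...
    layer.weight_scale = None
    # ...and hand the freed blocks back to the driver; without this,
    # PyTorch's caching allocator keeps them reserved.
    torch.cuda.empty_cache()
```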
@@ -136,7 +130,7 @@ def apply_weights(self,
                 dq_w = self.weight_quantizer(layer.weight).to(self.out_dtype)
             else:
                 dq_w = layer.weight
-            qdq_x = self.input_quantizer(x)
+            qdq_x, _ = per_token_group_quant_mxfp4(x, 32)
             return F.linear(qdq_x, dq_w, bias)
         else:
             raise NotImplementedError()
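`per_token_group_quant_mxfp4(x, 32)` replaces the per-layer Quark input quantizer with a quant-dequant over groups of 32 activation elements (the OCP MX block size); the second return value, discarded here, is presumably the per-group scales. As a rough illustration only, not vLLM's implementation, the sketch below assumes the OCP MXFP4 recipe: an E8M0 (power-of-two) scale shared per group of 32 and elements rounded to the nearest FP4 E2M1 value. All names in it are hypothetical.

```python
import torch

# Magnitudes representable by the FP4 E2M1 element format used by MXFP4.
_E2M1 = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def mxfp4_qdq_sketch(x: torch.Tensor, group_size: int = 32):
    """Simulated MXFP4 quantize-dequantize over groups of `group_size`
    along the last dim. Returns (qdq_x, per-group scales)."""
    assert x.shape[-1] % group_size == 0
    vals = _E2M1.to(device=x.device)
    g = x.float().reshape(-1, group_size)
    # E8M0 shared scale: 2**(floor(log2(amax)) - 2), so a group's max
    # magnitude lands near the top E2M1 code (6.0).
    amax = g.abs().amax(dim=-1, keepdim=True).clamp(min=2.0**-126)
    scale = torch.exp2(torch.floor(torch.log2(amax)) - 2)
    # Round each scaled element to the nearest representable magnitude.
    scaled = (g / scale).clamp(-6.0, 6.0)
    idx = (scaled.abs().unsqueeze(-1) - vals).abs().argmin(dim=-1)
    qdq = (vals[idx] * scaled.sign() * scale).reshape(x.shape).to(x.dtype)
    return qdq, scale.reshape(*x.shape[:-1], -1)
```

Only `qdq_x` feeds `F.linear`, so the activation path stays pure QDQ emulation in `out_dtype`; the non-emulated branch of `apply_weights` still raises `NotImplementedError`.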