@@ -5,18 +5,17 @@
 import torch
 import torch.nn.functional as F
 
+import vllm.envs as envs
 from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
+from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
+    OCP_MX_BLOCK_SIZE)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            PackedvLLMParameter)
 from vllm.platforms import current_platform
 
-from vllm.model_executor.layers.quantization.utils.mxfp4_utils import OCP_MX_BLOCK_SIZE
-
 __all__ = ["QuarkW4A4MXFP4"]
 
 
-
-
 class QuarkW4A4MXFP4(QuarkScheme):
 
     def __init__(self, weight_quant_spec: Dict[str, Any],
@@ -48,7 +47,6 @@ def __init__(self, weight_quant_spec: Dict[str, Any],
             float_dtype=self.out_dtype,
         )
 
-
     @classmethod
     def get_min_capability(cls) -> int:
         # lovelace and up
@@ -74,7 +72,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         weight_quant_spec = QuantizationSpec.from_dict(
             self.weight_quant_spec)
 
-        weight_quantizer = realquantizer.get_real_quantizer(
+        self.weight_quantizer = realquantizer.get_real_quantizer(
             qspec=weight_quant_spec,
             quantizer=None,
             real_quantized=True,
@@ -83,12 +81,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             scale_shape=layer.weight_scale.shape,
             zero_point_shape=None,
         )
-        weight_quantizer.scale.data = layer.weight_scale.data
+        self.weight_quantizer.scale.data = layer.weight_scale.data
 
-        layer.weight = torch.nn.Parameter(
-            weight_quantizer(layer.weight.data).to(self.out_dtype),
-            requires_grad=False,
-        )
+        if not envs.VLLM_QUARK_EMU_MEM_OPT:
+            layer.weight = torch.nn.Parameter(
+                self.weight_quantizer(layer.weight.data).to(
+                    self.out_dtype),
+                requires_grad=False,
+            )
 
     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
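Net effect of the two hunks above: unless VLLM_QUARK_EMU_MEM_OPT is set, the weight is still dequantized once at load time as before; with it set, layer.weight stays in its packed MXFP4 form and the quantizer now kept on self is applied on every forward instead. A minimal sketch of that trade-off (hypothetical helper names, not part of this diff; the weight_quantizer callable stands in for Quark's real quantizer):

import torch

def forward_eager(w_dequant: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Default path: the weight was dequantized once in
    # process_weights_after_loading, so each forward is a plain matmul.
    return x @ w_dequant.t()

def forward_mem_opt(weight_quantizer, w_packed: torch.Tensor,
                    x: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
    # VLLM_QUARK_EMU_MEM_OPT path: keep only the packed 4-bit weight
    # resident and dequantize on the fly, paying extra compute per
    # forward for a roughly 4x smaller weight footprint vs. bf16.
    return x @ weight_quantizer(w_packed).to(out_dtype).t()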
@@ -132,7 +132,11 @@ def apply_weights(self,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
         if self.emulate:
+            if envs.VLLM_QUARK_EMU_MEM_OPT:
+                dq_w = self.weight_quantizer(layer.weight).to(self.out_dtype)
+            else:
+                dq_w = layer.weight
             qdq_x = self.input_quantizer(x)
-            return F.linear(qdq_x, layer.weight, bias)
+            return F.linear(qdq_x, dq_w, bias)
         else:
             raise NotImplementedError()
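Assuming the flag is read through vllm.envs like other vLLM environment toggles, enabling the memory-optimized emulation is a per-process environment setting, e.g.:

import os

# Illustrative usage: set the flag before vLLM reads its environment
# variables (the variable name is taken from the diff above).
os.environ["VLLM_QUARK_EMU_MEM_OPT"] = "1"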