6 | 6 |
7 | 7 | import torch
8 | 8 |
| 9 | +from vllm.logger import init_logger
9 | 10 | from vllm.model_executor.layers.fused_moe import FusedMoE
10 | 11 | from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
11 | 12 |                                                UnquantizedLinearMethod)

15 | 16 | from vllm.model_executor.layers.quantization.quark.quark_moe import (  # noqa: E501
16 | 17 |     QuarkMoEMethod)
17 | 18 | from vllm.model_executor.layers.quantization.quark.schemes import (
18 | | -    QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8, QuarkW4A4MXFP4)
| 19 | +    QuarkScheme, QuarkW4A4MXFP4, QuarkW8A8Fp8, QuarkW8A8Int8)
19 | 20 | from vllm.model_executor.layers.quantization.quark.utils import (
20 | 21 |     deep_compare, should_ignore_layer)
21 | 22 | from vllm.platforms import current_platform
22 | | -from vllm.logger import init_logger
23 | 23 |
24 | 24 | __all__ = ["QuarkLinearMethod"]
25 | 25 |
26 | 26 | logger = init_logger(__name__)
27 | 27 |
| 28 | +
28 | 29 | class QuarkConfig(QuantizationConfig):
29 | 30 |
30 | 31 |     def __init__(self,
@@ -201,45 +202,53 @@ def _is_static_tensor_w8a8(self, weight_quant: Optional[Dict[str, Any]], |
201 | 202 |         return is_int8_dtype and is_tensor and is_weight_symmetric and is_static
202 | 203 |
203 | 204 |     def _is_mx_fp4(self, weight_quant: Optional[Dict[str, Any]],
204 | | -                  input_quant: Optional[Dict[str, Any]]) -> bool:
| 205 | +                   input_quant: Optional[Dict[str, Any]]) -> bool:
205 | 206 |         # Confirm weights and input quantized.
206 | 207 |         if weight_quant is None or input_quant is None:
207 | | -            logger.debug("Quark model is not in MX-FP4 format: weight_quant or input_quant not set")
| 208 | +            logger.debug("Quark model is not in MX-FP4 format: "
| 209 | +                         "weight_quant or input_quant not set")
208 | 210 |             return False
209 | 211 |
210 | 212 |         # Input and weight dtype needs to be fp4.
211 | | -        if weight_quant.get("dtype") != "fp4" or input_quant.get("dtype") != "fp4":
| 213 | +        if weight_quant.get("dtype") != "fp4" or input_quant.get(
| 214 | +                "dtype") != "fp4":
212 | 215 |             logger.debug("Quark model is not in MX-FP4 format: dtype not fp4")
213 | 216 |             return False
214 | 217 |
215 | 218 |         # Input and weight qscheme needs to be per group.
216 | | -        if weight_quant.get("qscheme") != "per_group" or input_quant.get("qscheme") != "per_group":
| 219 | +        if weight_quant.get("qscheme") != "per_group" or input_quant.get(
| 220 | +                "qscheme") != "per_group":
217 | 221 |             logger.debug("Quark model is not in MX-FP4 format: not per_group")
218 | 222 |             return False
219 | 223 |
220 | 224 |         # Input and weight group size needs to be 32.
221 | | -        if weight_quant.get("group_size") != 32 or input_quant.get("group_size") != 32:
222 | | -            logger.debug("Quark model is not in MX-FP4 format: not group_size=32")
| 225 | +        if weight_quant.get("group_size") != 32 or input_quant.get(
| 226 | +                "group_size") != 32:
| 227 | +            logger.debug(
| 228 | +                "Quark model is not in MX-FP4 format: not group_size=32")
223 | 229 |             return False
224 | 230 |
225 | 231 |         # Weights need to use static quantization.
226 | 232 |         if weight_quant.get("is_dynamic") is True:
227 | | -            logger.debug("Quark model is not in MX-FP4 format: not weight static")
| 233 | +            logger.debug(
| 234 | +                "Quark model is not in MX-FP4 format: not weight static")
228 | 235 |             return False
229 | 236 |
230 | 237 |         # Activations need to use dynamic quantization.
231 | 238 |         if input_quant.get("is_dynamic") is False:
232 | | -            logger.debug("Quark model is not in MX-FP4 format: not activation dynamic")
| 239 | +            logger.debug(
| 240 | +                "Quark model is not in MX-FP4 format: not activation dynamic")
233 | 241 |             return False
234 | 242 |
235 | 243 |         # Activations and weight scales need to be in e8m0 format.
236 | | -        if weight_quant.get("scale_format") != "e8m0" or input_quant.get("scale_format") != "e8m0":
237 | | -            logger.debug("Quark model is not in MX-FP4 format: not scale_format e8m0")
| 244 | +        if weight_quant.get("scale_format") != "e8m0" or input_quant.get(
| 245 | +                "scale_format") != "e8m0":
| 246 | +            logger.debug(
| 247 | +                "Quark model is not in MX-FP4 format: not scale_format e8m0")
238 | 248 |             return False
239 | 249 |
240 | 250 |         return True
241 | 251 |
242 | | -
243 | 252 |     def _find_matched_config(self, layer_name: str,
244 | 253 |                              module: torch.nn.Module) -> Dict[str, Any]:
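For reference, the conditions that `_is_mx_fp4` enforces can be condensed into a standalone predicate. The sketch below is illustrative only, not the vLLM implementation; the sample config dicts are hypothetical and carry just the keys the method inspects.

```python
# Minimal sketch of the MX-FP4 detection logic above (not vLLM code).
from typing import Any, Dict, Optional


def is_mx_fp4(weight_quant: Optional[Dict[str, Any]],
              input_quant: Optional[Dict[str, Any]]) -> bool:
    """Return True iff both configs describe MX-FP4 quantization."""
    if weight_quant is None or input_quant is None:
        return False
    for cfg in (weight_quant, input_quant):
        # Both sides need fp4 dtype, per-group scheme, group size 32,
        # and e8m0-formatted scales.
        if (cfg.get("dtype") != "fp4" or cfg.get("qscheme") != "per_group"
                or cfg.get("group_size") != 32
                or cfg.get("scale_format") != "e8m0"):
            return False
    # Weights must be statically quantized, activations dynamically.
    return (weight_quant.get("is_dynamic") is not True
            and input_quant.get("is_dynamic") is not False)


# Hypothetical configs that satisfy every check.
mx_fp4 = {"dtype": "fp4", "qscheme": "per_group", "group_size": 32,
          "scale_format": "e8m0"}
weight = {**mx_fp4, "is_dynamic": False}      # static weights
activation = {**mx_fp4, "is_dynamic": True}   # dynamic activations
assert is_mx_fp4(weight, activation)
assert not is_mx_fp4(weight, {**mx_fp4, "is_dynamic": False})
```

A config qualifies only when every check passes; the per-condition `logger.debug` calls in the diff exist so that a near-miss config reports which check rejected it.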