
Commit 7e5b7c1

lint

Signed-off-by: Bowen Bao <[email protected]>

1 parent: c1843c7

4 files changed: +40, -36 lines

vllm/attention/layer.py

Lines changed: 0 additions & 1 deletion
@@ -75,7 +75,6 @@ def __init__(
             block_size = 16
             is_attention_free = False
             calculate_kv_scales = False
-
         if num_kv_heads is None:
             num_kv_heads = num_heads

vllm/model_executor/layers/quantization/quark/quark.py

Lines changed: 22 additions & 13 deletions
@@ -6,6 +6,7 @@
 
 import torch
 
+from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)

@@ -15,16 +16,16 @@
 from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501
     QuarkMoEMethod)
 from vllm.model_executor.layers.quantization.quark.schemes import (
-    QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8, QuarkW4A4MXFP4)
+    QuarkScheme, QuarkW4A4MXFP4, QuarkW8A8Fp8, QuarkW8A8Int8)
 from vllm.model_executor.layers.quantization.quark.utils import (
     deep_compare, should_ignore_layer)
 from vllm.platforms import current_platform
-from vllm.logger import init_logger
 
 __all__ = ["QuarkLinearMethod"]
 
 logger = init_logger(__name__)
 
+
 class QuarkConfig(QuantizationConfig):
 
     def __init__(self,

@@ -201,45 +202,53 @@ def _is_static_tensor_w8a8(self, weight_quant: Optional[Dict[str, Any]],
         return is_int8_dtype and is_tensor and is_weight_symmetric and is_static
 
     def _is_mx_fp4(self, weight_quant: Optional[Dict[str, Any]],
-                  input_quant: Optional[Dict[str, Any]]) -> bool:
+                   input_quant: Optional[Dict[str, Any]]) -> bool:
         # Confirm weights and input quantized.
         if weight_quant is None or input_quant is None:
-            logger.debug("Quark model is not in MX-FP4 format: weight_quant or input_quant not set")
+            logger.debug("Quark model is not in MX-FP4 format: "
+                         "weight_quant or input_quant not set")
             return False
 
         # Input and weight dtype needs to be fp4.
-        if weight_quant.get("dtype") != "fp4" or input_quant.get("dtype") != "fp4":
+        if weight_quant.get("dtype") != "fp4" or input_quant.get(
+                "dtype") != "fp4":
             logger.debug("Quark model is not in MX-FP4 format: dtype not fp4")
             return False
 
         # Input and weight qscheme needs to be per group.
-        if weight_quant.get("qscheme") != "per_group" or input_quant.get("qscheme") != "per_group":
+        if weight_quant.get("qscheme") != "per_group" or input_quant.get(
+                "qscheme") != "per_group":
             logger.debug("Quark model is not in MX-FP4 format: not per_group")
             return False
 
         # Input and weight group size needs to be 32.
-        if weight_quant.get("group_size") != 32 or input_quant.get("group_size") != 32:
-            logger.debug("Quark model is not in MX-FP4 format: not group_size=32")
+        if weight_quant.get("group_size") != 32 or input_quant.get(
+                "group_size") != 32:
+            logger.debug(
+                "Quark model is not in MX-FP4 format: not group_size=32")
             return False
 
         # Weights need to use static quantization.
         if weight_quant.get("is_dynamic") is True:
-            logger.debug("Quark model is not in MX-FP4 format: not weight static")
+            logger.debug(
+                "Quark model is not in MX-FP4 format: not weight static")
             return False
 
         # Activations need to use dynamic quantization.
         if input_quant.get("is_dynamic") is False:
-            logger.debug("Quark model is not in MX-FP4 format: not activation dynamic")
+            logger.debug(
+                "Quark model is not in MX-FP4 format: not activation dynamic")
             return False
 
         # Activations and weight scales need to be in e8m0 format.
-        if weight_quant.get("scale_format") != "e8m0" or input_quant.get("scale_format") != "e8m0":
-            logger.debug("Quark model is not in MX-FP4 format: not scale_format e8m0")
+        if weight_quant.get("scale_format") != "e8m0" or input_quant.get(
+                "scale_format") != "e8m0":
+            logger.debug(
+                "Quark model is not in MX-FP4 format: not scale_format e8m0")
             return False
 
         return True
 
-
     def _find_matched_config(self, layer_name: str,
                              module: torch.nn.Module) -> Dict[str, Any]:

vllm/model_executor/layers/quantization/quark/schemes/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,8 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from .quark_scheme import QuarkScheme
+from .quark_w4a4_mxfp4 import QuarkW4A4MXFP4
 from .quark_w8a8_fp8 import QuarkW8A8Fp8
 from .quark_w8a8_int8 import QuarkW8A8Int8
-from .quark_w4a4_mxfp4 import QuarkW4A4MXFP4
 
 __all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkW4A4MXFP4"]

vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py

Lines changed: 17 additions & 21 deletions
@@ -1,28 +1,23 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Callable, List, Optional, Dict, Any
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
-from torch.nn import Parameter
+import torch.nn.functional as F
 
 from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    Fp8LinearOp, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
-from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
-                                            ModelWeightParameter,
-                                            PerTensorScaleParameter)
-from vllm.platforms import current_platform
-from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter
-
-import torch.nn.functional as F
+from vllm.model_executor.parameter import (GroupQuantScaleParameter,
+                                            PackedvLLMParameter)
 
-__all__ = ["QuarkW8A8Fp8"]
+__all__ = ["QuarkW4A4MXFP4"]
 
 OCP_MX_BLOCK_SIZE = 32
 
+
 class QuarkW4A4MXFP4(QuarkScheme):
 
-    def __init__(self, weight_quant_spec: Dict[str, Any], input_quant_spec: Dict[str, Any]):
+    def __init__(self, weight_quant_spec: Dict[str, Any],
+                 input_quant_spec: Dict[str, Any]):
         self.out_dtype = torch.get_default_dtype()
         self.qscheme = "per_group"
         self.weight_quant_spec = weight_quant_spec

@@ -35,17 +30,18 @@ def get_min_capability(cls) -> int:
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.weight = torch.nn.Parameter(layer.weight.data,
-                                           requires_grad=False)
-        layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                           requires_grad=False)
+        layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
+                                                requires_grad=False)
 
         try:
             from quark.torch.export.nn.modules import realquantizer
             from quark.torch.quantization.config.config import QuantizationSpec
         except ImportError as err:
             raise ImportError(
-                f"The package `amd-quark` is required to use AMD Quark MX-FP4 models. Please install it with `pip install amd-quark`. Error: {err}"
-            )
+                "The package `amd-quark` is required to use AMD Quark MX-FP4 "
+                "models. Please install it with `pip install amd-quark`."
+            ) from err
 
         weight_quant_spec = QuantizationSpec.from_dict(self.weight_quant_spec)
         input_quant_spec = QuantizationSpec.from_dict(self.input_quant_spec)

@@ -60,9 +56,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             zero_point_shape=None,
         )
         self.weight_quantizer.scale.data = layer.weight_scale.data
-        layer.weight = torch.nn.Parameter(
-            self.weight_quantizer(layer.weight.data).to(self.out_dtype), requires_grad=False
-        )
+        layer.weight = torch.nn.Parameter(self.weight_quantizer(
+            layer.weight.data).to(self.out_dtype),
+                                          requires_grad=False)
 
         self.input_quantizer = realquantizer.get_real_quantizer(
             qspec=input_quant_spec,

@@ -90,7 +86,7 @@ def create_weights(self, layer: torch.nn.Module,
             output_dim=0,
             packed_dim=1,
             packed_factor=2,
-            weight_loader=weight_loader
+            weight_loader=weight_loader,
         )
         layer.register_parameter("weight", weight)
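
As a rough sanity check of the layout implied by create_weights above (output_dim=0, packed_dim=1, packed_factor=2) and the OCP_MX_BLOCK_SIZE = 32 group size, the per-layer tensor shapes work out as sketched below. The scale shape is an assumption based on the per_group/group_size=32 spec, since the scale parameter itself is not shown in this hunk; the layer dimensions are hypothetical.

# Hypothetical layer dimensions, for illustration only.
output_size, input_size = 4096, 4096
OCP_MX_BLOCK_SIZE = 32

# packed_dim=1, packed_factor=2: two 4-bit weights share one byte along
# the input dimension.
packed_weight_shape = (output_size, input_size // 2)                  # (4096, 2048)

# Assumed: one e8m0 scale per group of 32 input elements, per output row.
weight_scale_shape = (output_size, input_size // OCP_MX_BLOCK_SIZE)   # (4096, 128)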
