
Commit 0eb0757

[Misc] Add ignored layers for fp8 quantization (#6657)
1 parent 38c4b7e commit 0eb0757

File tree

4 files changed, +57 -47 lines changed


vllm/model_executor/layers/quantization/compressed_tensors/utils.py

Lines changed: 5 additions & 9 deletions
@@ -5,6 +5,9 @@
 from pydantic import BaseModel, Field
 from torch.nn import Module
 
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    FUSED_LAYER_NAME_MAPPING)
+
 
 class CompressionFormat(Enum):
     dense = "dense"
@@ -86,13 +89,6 @@ def is_activation_quantization_format(format: str) -> bool:
     return format in _ACTIVATION_QUANTIZATION_FORMATS
 
 
-# fused_name: List[shard_name]
-_FUSED_LAYER_NAME_MAPPING = {
-    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
-    "gate_up_proj": ["gate_proj", "up_proj"]
-}
-
-
 def should_ignore_layer(layer_name: Optional[str],
                         ignore: Iterable[str]) -> bool:
     if layer_name is None:
@@ -106,8 +102,8 @@ def should_ignore_layer(layer_name: Optional[str],
     # in the safetensors checkpoint. So, we convert the name
     # from the fused version to unfused + check to make sure that
     # each shard of the fused layer has the same scheme.
-    if proj_name in _FUSED_LAYER_NAME_MAPPING:
-        shard_proj_names = _FUSED_LAYER_NAME_MAPPING[proj_name]
+    if proj_name in FUSED_LAYER_NAME_MAPPING:
+        shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
 
         # Convert fused_name --> [shard_names]
         shard_names = [
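
For reference, a minimal sketch of the fused-name expansion that should_ignore_layer performs, now driven by the shared FUSED_LAYER_NAME_MAPPING (assumes a vLLM checkout that includes this commit; the scheme-matching part of the function is elided):

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    FUSED_LAYER_NAME_MAPPING)

# Expand a fused layer name into its per-shard names, as the function above does.
layer_name = "model.layers.0.self_attn.qkv_proj"
proj_name = layer_name.split(".")[-1]
if proj_name in FUSED_LAYER_NAME_MAPPING:
    shard_names = [
        layer_name.replace(proj_name, shard)
        for shard in FUSED_LAYER_NAME_MAPPING[proj_name]
    ]
    print(shard_names)
    # ['model.layers.0.self_attn.q_proj', 'model.layers.0.self_attn.k_proj',
    #  'model.layers.0.self_attn.v_proj']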

vllm/model_executor/layers/quantization/fbgemm_fp8.py

Lines changed: 3 additions & 36 deletions
@@ -11,21 +11,15 @@
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_fp8_linear, create_per_channel_scale_param)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
 
-# Note: this is a hack. We should update each model to register the
-# stacked params and get it from there instead in a future PR.
-# fused_name: List[shard_name]
-_FUSED_LAYER_NAME_MAPPING = {
-    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
-    "gate_up_proj": ["gate_proj", "up_proj"]
-}
-
 
 class FBGEMMFp8Config(QuantizationConfig):
     """Config class for FBGEMM Fp8."""
@@ -62,37 +56,10 @@ def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config":
         input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
         return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)
 
-    def _is_layer_skipped(self, prefix: str) -> bool:
-        # prefix: model.layers.0.self_attn.q_proj
-        # proj_name: q_proj
-        proj_name = prefix.split(".")[-1]
-        if proj_name in _FUSED_LAYER_NAME_MAPPING:
-            shard_prefixes = [
-                prefix.replace(proj_name, shard_proj_name)
-                for shard_proj_name in _FUSED_LAYER_NAME_MAPPING[proj_name]
-            ]
-
-            is_skipped = None
-            for shard_prefix in shard_prefixes:
-                is_shard_skipped = shard_prefix in self.ignore_list
-
-                if is_skipped is None:
-                    is_skipped = is_shard_skipped
-                elif is_shard_skipped != is_skipped:
-                    raise ValueError(
-                        f"Detected some but not all shards of {prefix} "
-                        "are quantized. All shards of fused layers "
-                        "to have the same precision.")
-        else:
-            is_skipped = prefix in self.ignore_list
-
-        assert is_skipped is not None
-        return is_skipped
-
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, LinearBase):
-            if self._is_layer_skipped(prefix):
+            if is_layer_skipped(prefix, self.ignore_list):
                 return UnquantizedLinearMethod()
             return FBGEMMFp8LinearMethod(self)
         return None
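
The shard-consistency check from the removed _is_layer_skipped method is preserved by the shared helper. A small sketch of that behavior, assuming vLLM at this revision and a made-up ignore list that covers only one shard of a fused layer:

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)

# Only q_proj is ignored, so the shards of the fused qkv_proj disagree and the
# helper raises, matching the check FBGEMMFp8Config previously did itself.
try:
    is_layer_skipped("model.layers.0.self_attn.qkv_proj",
                     ["model.layers.0.self_attn.q_proj"])
except ValueError as err:
    print(err)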

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 11 additions & 2 deletions
@@ -8,12 +8,15 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   fused_moe)
-from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, create_per_tensor_scale_param,
     cutlass_fp8_supported, per_tensor_dequantize, requantize_with_max_scale)
@@ -33,6 +36,7 @@ def __init__(
         self,
         is_checkpoint_fp8_serialized: bool = False,
         activation_scheme: str = "dynamic",
+        ignored_layers: Optional[List[str]] = None,
     ) -> None:
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
         if is_checkpoint_fp8_serialized:
@@ -42,6 +46,7 @@ def __init__(
             raise ValueError(
                 f"Unsupported activation scheme {activation_scheme}")
         self.activation_scheme = activation_scheme
+        self.ignored_layers = ignored_layers or []
 
     @classmethod
     def get_name(cls) -> str:
@@ -64,14 +69,18 @@ def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
         quant_method = cls.get_from_keys(config, ["quant_method"])
         is_checkpoint_fp8_serialized = ("fp8" in quant_method)
         activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
+        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
         return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
-                   activation_scheme=activation_scheme)
+                   activation_scheme=activation_scheme,
+                   ignored_layers=ignored_layers)
 
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
 
         if isinstance(layer, LinearBase):
+            if is_layer_skipped(prefix, self.ignored_layers):
+                return UnquantizedLinearMethod()
             return Fp8LinearMethod(self)
         elif isinstance(layer, FusedMoE):
             return Fp8MoEMethod(self)
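
To illustrate the new option, a hypothetical checkpoint quantization_config carrying ignored_layers could be parsed as below (a sketch assuming a working vLLM install at this revision; the layer names are invented):

from vllm.model_executor.layers.quantization.fp8 import Fp8Config

hf_quant_config = {
    "quant_method": "fp8",
    "activation_scheme": "dynamic",
    "ignored_layers": ["lm_head", "model.layers.0.self_attn.qkv_proj"],
}

fp8_config = Fp8Config.from_config(hf_quant_config)
# Layers listed here receive UnquantizedLinearMethod from get_quant_method.
print(fp8_config.ignored_layers)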

vllm/model_executor/layers/quantization/utils/quant_utils.py

Lines changed: 38 additions & 0 deletions
@@ -1,10 +1,48 @@
 """This file is used for /tests and /benchmarks"""
+from typing import List
+
 import numpy
 import torch
 
 SUPPORTED_NUM_BITS = [4, 8]
 SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
 
+# Note: this is a hack. We should update each model to register the
+# stacked params and get it from there instead in a future PR.
+# fused_name: List[shard_name]
+FUSED_LAYER_NAME_MAPPING = {
+    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+    "gate_up_proj": ["gate_proj", "up_proj"]
+}
+
+
+def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
+    # prefix: model.layers.0.self_attn.q_proj
+    # proj_name: q_proj
+    proj_name = prefix.split(".")[-1]
+    if proj_name in FUSED_LAYER_NAME_MAPPING:
+        shard_prefixes = [
+            prefix.replace(proj_name, shard_proj_name)
+            for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
+        ]
+
+        is_skipped = None
+        for shard_prefix in shard_prefixes:
+            is_shard_skipped = shard_prefix in ignored_layers
+
+            if is_skipped is None:
+                is_skipped = is_shard_skipped
+            elif is_shard_skipped != is_skipped:
+                raise ValueError(
+                    f"Detected some but not all shards of {prefix} "
+                    "are quantized. All shards of fused layers "
+                    "to have the same precision.")
+    else:
+        is_skipped = prefix in ignored_layers
+
+    assert is_skipped is not None
+    return is_skipped
+
 
 def get_pack_factor(num_bits):
     assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
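
A usage sketch for the new helper, assuming vLLM at this revision; the layer names are illustrative:

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)

ignored = [
    "lm_head",
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
]

# All shards of the fused qkv_proj are ignored, so the fused layer is skipped.
assert is_layer_skipped("model.layers.0.self_attn.qkv_proj", ignored)
# Non-fused names are matched against the list directly.
assert is_layer_skipped("lm_head", ignored)
assert not is_layer_skipped("model.layers.0.mlp.gate_up_proj", ignored)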
