@@ -5,6 +5,9 @@
from pydantic import BaseModel, Field
from torch.nn import Module

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    FUSED_LAYER_NAME_MAPPING)


class CompressionFormat(Enum):
    dense = "dense"
@@ -86,13 +89,6 @@ def is_activation_quantization_format(format: str) -> bool:
    return format in _ACTIVATION_QUANTIZATION_FORMATS


# fused_name: List[shard_name]
_FUSED_LAYER_NAME_MAPPING = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"]
}


def should_ignore_layer(layer_name: Optional[str],
                        ignore: Iterable[str]) -> bool:
    if layer_name is None:
@@ -106,8 +102,8 @@ def should_ignore_layer(layer_name: Optional[str],
    # in the safetensors checkpoint. So, we convert the name
    # from the fused version to unfused + check to make sure that
    # each shard of the fused layer has the same scheme.
    if proj_name in _FUSED_LAYER_NAME_MAPPING:
        shard_proj_names = _FUSED_LAYER_NAME_MAPPING[proj_name]
    if proj_name in FUSED_LAYER_NAME_MAPPING:
        shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]

        # Convert fused_name --> [shard_names]
        shard_names = [
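The hunk above switches should_ignore_layer to the shared FUSED_LAYER_NAME_MAPPING. A minimal sketch of the fused-name expansion it performs, assuming a vLLM install with this change applied; the layer name is hypothetical and illustrative only, not part of the diff:

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    FUSED_LAYER_NAME_MAPPING)

layer_name = "model.layers.0.self_attn.qkv_proj"  # hypothetical checkpoint name
proj_name = layer_name.split(".")[-1]             # "qkv_proj"

# Convert the fused name into one name per shard, as should_ignore_layer does.
shard_names = [
    layer_name.replace(proj_name, shard_proj_name)
    for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
]
print(shard_names)
# ['model.layers.0.self_attn.q_proj', 'model.layers.0.self_attn.k_proj',
#  'model.layers.0.self_attn.v_proj']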
39 changes: 3 additions & 36 deletions vllm/model_executor/layers/quantization/fbgemm_fp8.py
@@ -11,21 +11,15 @@
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    apply_fp8_linear, create_per_channel_scale_param)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform

logger = init_logger(__name__)

# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
# fused_name: List[shard_name]
_FUSED_LAYER_NAME_MAPPING = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"]
}


class FBGEMMFp8Config(QuantizationConfig):
"""Config class for FBGEMM Fp8."""
@@ -62,37 +56,10 @@ def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config":
        input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
        return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)

    def _is_layer_skipped(self, prefix: str) -> bool:
        # prefix: model.layers.0.self_attn.q_proj
        # proj_name: q_proj
        proj_name = prefix.split(".")[-1]
        if proj_name in _FUSED_LAYER_NAME_MAPPING:
            shard_prefixes = [
                prefix.replace(proj_name, shard_proj_name)
                for shard_proj_name in _FUSED_LAYER_NAME_MAPPING[proj_name]
            ]

            is_skipped = None
            for shard_prefix in shard_prefixes:
                is_shard_skipped = shard_prefix in self.ignore_list

                if is_skipped is None:
                    is_skipped = is_shard_skipped
                elif is_shard_skipped != is_skipped:
                    raise ValueError(
                        f"Detected some but not all shards of {prefix} "
                        "are quantized. All shards of fused layers "
                        "to have the same precision.")
        else:
            is_skipped = prefix in self.ignore_list

        assert is_skipped is not None
        return is_skipped

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, LinearBase):
            if self._is_layer_skipped(prefix):
            if is_layer_skipped(prefix, self.ignore_list):
                return UnquantizedLinearMethod()
            return FBGEMMFp8LinearMethod(self)
        return None
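With the private _is_layer_skipped removed, FBGEMMFp8Config.get_quant_method delegates to the shared is_layer_skipped helper, passing its ignore_list. A minimal sketch of the equivalent check, with a hypothetical ignore list; illustrative only, not part of the diff, and assuming vLLM with this change applied:

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)

ignore_list = ["model.layers.0.mlp.down_proj"]  # hypothetical FBGEMM ignore_list

# Same behavior the deleted _is_layer_skipped method provided:
assert is_layer_skipped("model.layers.0.mlp.down_proj", ignore_list)
assert not is_layer_skipped("model.layers.0.self_attn.qkv_proj", ignore_list)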
13 changes: 11 additions & 2 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -8,11 +8,14 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                  fused_moe)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    all_close_1d, apply_fp8_linear, create_per_tensor_scale_param,
    cutlass_fp8_supported, per_tensor_dequantize, requantize_with_max_scale)
@@ -32,6 +35,7 @@ def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
        activation_scheme: str = "dynamic",
        ignored_layers: Optional[List[str]] = None,
    ) -> None:
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
@@ -41,6 +45,7 @@ def __init__(
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme
        self.ignored_layers = ignored_layers or []

    @classmethod
    def get_name(cls) -> str:
@@ -63,14 +68,18 @@ def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = ("fp8" in quant_method)
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
                   activation_scheme=activation_scheme)
                   activation_scheme=activation_scheme,
                   ignored_layers=ignored_layers)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
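The new ignored_layers field flows from the checkpoint's quantization config into Fp8Config, and get_quant_method returns an UnquantizedLinearMethod for skipped linear layers. A minimal sketch of that config path; the dict below is a hypothetical serialized quantization config, illustrative only and not part of the diff:

from vllm.model_executor.layers.quantization.fp8 import Fp8Config

quant_config = {  # hypothetical checkpoint quantization config
    "quant_method": "fp8",
    "activation_scheme": "static",
    "ignored_layers": ["lm_head", "model.layers.0.mlp.gate_proj"],
}

cfg = Fp8Config.from_config(quant_config)
assert cfg.is_checkpoint_fp8_serialized
assert cfg.ignored_layers == ["lm_head", "model.layers.0.mlp.gate_proj"]
# Inside get_quant_method(), is_layer_skipped(prefix, cfg.ignored_layers)
# decides whether a LinearBase layer gets UnquantizedLinearMethod.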
38 changes: 38 additions & 0 deletions vllm/model_executor/layers/quantization/utils/quant_utils.py
@@ -1,10 +1,48 @@
"""This file is used for /tests and /benchmarks"""
from typing import List

import numpy
import torch

SUPPORTED_NUM_BITS = [4, 8]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
# fused_name: List[shard_name]
FUSED_LAYER_NAME_MAPPING = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}


def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
    # prefix: model.layers.0.self_attn.q_proj
    # proj_name: q_proj
    proj_name = prefix.split(".")[-1]
    if proj_name in FUSED_LAYER_NAME_MAPPING:
        shard_prefixes = [
            prefix.replace(proj_name, shard_proj_name)
            for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
        ]

        is_skipped = None
        for shard_prefix in shard_prefixes:
            is_shard_skipped = shard_prefix in ignored_layers

            if is_skipped is None:
                is_skipped = is_shard_skipped
            elif is_shard_skipped != is_skipped:
                raise ValueError(
                    f"Detected some but not all shards of {prefix} "
                    "are quantized. All shards of fused layers "
                    "must have the same precision.")
    else:
        is_skipped = prefix in ignored_layers

    assert is_skipped is not None
    return is_skipped


def get_pack_factor(num_bits):
    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
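A short sketch of the edge case the helper guards against: if only some shards of a fused layer appear in the ignore list, is_layer_skipped raises instead of silently mixing precisions. The layer names are hypothetical; illustrative only, not part of the diff, and assuming vLLM with this change applied:

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)

# Only one of the three shards behind qkv_proj is ignored -> inconsistent.
partial_ignore = ["model.layers.0.self_attn.q_proj"]  # hypothetical

try:
    is_layer_skipped("model.layers.0.self_attn.qkv_proj", partial_ignore)
except ValueError as err:
    print(err)  # "Detected some but not all shards of ..."

# With every shard listed, the fused layer is consistently skipped.
full_ignore = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
]
assert is_layer_skipped("model.layers.0.self_attn.qkv_proj", full_ignore)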