
Commit 1c9a9b9

BowenBao authored and Mu Huai committed
[Quantization] Quark MXFP4 format loading (vllm-project#16943)
Signed-off-by: Mu Huai <[email protected]>
1 parent 32309f7 commit 1c9a9b9

9 files changed: +289 -3 lines changed

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+# SPDX-License-Identifier: Apache-2.0
+# flake8: noqa
+"""Tests Quark mxfp4 models against ground truth generation
+"""
+import pytest
+
+from vllm import LLM, SamplingParams
+
+MODELS = ["amd/Llama-2-7b-chat-hf-wmxfp4-amxfp4-kvfp8-scale-uint8"]
+
+EXPECTED_STRS_MAP = {
+    "amd/Llama-2-7b-chat-hf-wmxfp4-amxfp4-kvfp8-scale-uint8": [
+        '\n### Key Features\n\n* **High-throughput Inference**: vLL',
+        '\nArtificial intelligence (AI) has evolved significantly since its inception in the 1',
+        'Artificial intelligence (AI) and human intelligence (HI) are two distinct concepts that have been',
+        'A neural network is a machine learning model inspired by the structure of the human brain. It consists of',
+        '\nTitle: The Dreaming Robot\n\nAs the sun set on the bustling metropol',
+        '\nThe COVID-19 pandemic has had a profound impact on global economic structures and business',
+        'The Mona Lisa painting, created by Leonardo da Vinci in the early 16th',
+        " everybody knows this proverbial saying, but did you know that it's not entirely accurate?",
+    ]
+}
+
+
+@pytest.mark.skip(reason="Model to be released in the future")
+@pytest.mark.quant_model
+@pytest.mark.parametrize("model_name", MODELS)
+def test_models(example_prompts, model_name) -> None:
+    sampling_params = SamplingParams(max_tokens=20, temperature=0)
+    llm = LLM(
+        model=model_name,
+        kv_cache_dtype="fp8",
+        quantization="quark",
+    )
+    outputs = llm.generate(example_prompts, sampling_params)
+    for i, output in enumerate(outputs):
+        output_str = output.outputs[0].text
+        expected_str = EXPECTED_STRS_MAP[model_name][i]
+        assert expected_str == output_str, (
+            f"Expected: {expected_str!r}\nvLLM: {output_str!r}")

vllm/envs.py

Lines changed: 9 additions & 0 deletions
@@ -84,6 +84,7 @@
     VLLM_ROCM_FP8_PADDING: bool = True
     VLLM_ROCM_MOE_PADDING: bool = True
     VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
+    VLLM_QUARK_EMU_MEM_OPT: bool = False
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
     VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
     VLLM_DISABLE_COMPILE_CACHE: bool = False
@@ -583,6 +584,14 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in
              ("true", "1")),

+    # If set, when running in Quark emulation mode, do not dequantize the
+    # weights at load time. Instead, dequantize weights on-the-fly during
+    # kernel execution.
+    # This allows running larger models at the cost of slower inference.
+    # This flag has no effect when not running in Quark emulation mode.
+    "VLLM_QUARK_EMU_MEM_OPT":
+    lambda: bool(int(os.getenv("VLLM_QUARK_EMU_MEM_OPT", "0"))),
+
     # Divisor for dynamic query scale factor calculation for FP8 KV Cache
     "Q_SCALE_CONSTANT":
     lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")),
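
As a usage sketch (not part of the diff): the flag is read through os.getenv when the lambda runs, so it can be enabled from Python before the engine is constructed. The model name below is reused from the test above; everything else is an illustrative assumption rather than the PR's documented workflow.

import os

# Keep weights packed as MXFP4 and dequantize on the fly during emulation;
# trades inference speed for lower memory use at load time.
os.environ["VLLM_QUARK_EMU_MEM_OPT"] = "1"

from vllm import LLM

llm = LLM(model="amd/Llama-2-7b-chat-hf-wmxfp4-amxfp4-kvfp8-scale-uint8",
          quantization="quark",
          kv_cache_dtype="fp8")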

vllm/model_executor/layers/quantization/quark/quark.py

Lines changed: 55 additions & 1 deletion
@@ -5,6 +5,7 @@

 import torch

+from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
@@ -15,13 +16,15 @@
 from vllm.model_executor.layers.quantization.quark.quark_moe import (  # noqa: E501
     QuarkMoEMethod)
 from vllm.model_executor.layers.quantization.quark.schemes import (
-    QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8)
+    QuarkScheme, QuarkW4A4MXFP4, QuarkW8A8Fp8, QuarkW8A8Int8)
 from vllm.model_executor.layers.quantization.quark.utils import (
     deep_compare, should_ignore_layer)
 from vllm.platforms import current_platform

 __all__ = ["QuarkLinearMethod"]

+logger = init_logger(__name__)
+

 class QuarkConfig(QuantizationConfig):

@@ -67,6 +70,7 @@ def get_quant_method(self, layer: torch.nn.Module,
             return QuarkLinearMethod(self)
         if isinstance(layer, Attention):
             return QuarkKVCacheMethod(self)
+
         if isinstance(layer, FusedMoE):
             return QuarkMoEMethod.get_moe_method(self,
                                                  module=layer,
@@ -205,6 +209,54 @@ def _is_static_tensor_w8a8(self, weight_quant: Optional[Dict[str, Any]],
         # Only symmetric weight quantization supported.
         return is_int8_dtype and is_tensor and is_weight_symmetric and is_static

+    def _is_mx_fp4(self, weight_quant: Optional[Dict[str, Any]],
+                   input_quant: Optional[Dict[str, Any]]) -> bool:
+        # Confirm weights and input quantized.
+        if weight_quant is None or input_quant is None:
+            logger.debug("Quark model is not in MX-FP4 format: "
+                         "weight_quant or input_quant not set")
+            return False
+
+        # Input and weight dtype needs to be fp4.
+        if weight_quant.get("dtype") != "fp4" or input_quant.get(
+                "dtype") != "fp4":
+            logger.debug("Quark model is not in MX-FP4 format: dtype not fp4")
+            return False
+
+        # Input and weight qscheme needs to be per group.
+        if weight_quant.get("qscheme") != "per_group" or input_quant.get(
+                "qscheme") != "per_group":
+            logger.debug("Quark model is not in MX-FP4 format: not per_group")
+            return False
+
+        # Input and weight group size needs to be 32.
+        if weight_quant.get("group_size") != 32 or input_quant.get(
+                "group_size") != 32:
+            logger.debug(
+                "Quark model is not in MX-FP4 format: not group_size=32")
+            return False
+
+        # Weights need to use static quantization.
+        if weight_quant.get("is_dynamic") is True:
+            logger.debug(
+                "Quark model is not in MX-FP4 format: not weight static")
+            return False
+
+        # Activations need to use dynamic quantization.
+        if input_quant.get("is_dynamic") is False:
+            logger.debug(
+                "Quark model is not in MX-FP4 format: not activation dynamic")
+            return False
+
+        # Activations and weight scales need to be in e8m0 format.
+        if weight_quant.get("scale_format") != "e8m0" or input_quant.get(
+                "scale_format") != "e8m0":
+            logger.debug(
+                "Quark model is not in MX-FP4 format: not scale_format e8m0")
+            return False
+
+        return True
+
     def _find_matched_config(self, layer_name: str,
                              module: torch.nn.Module) -> Dict[str, Any]:

@@ -269,6 +321,8 @@ def _get_scheme_from_config(self, config: Dict[str, Any]) -> "QuarkScheme":
             return QuarkW8A8Int8(qscheme=weight_qscheme,
                                  is_static_input_scheme=True,
                                  input_symmetric=input_config.get("symmetric"))
+        elif self._is_mx_fp4(weight_config, input_config):
+            return QuarkW4A4MXFP4(weight_config, input_config)

         raise NotImplementedError("No quark compatible scheme was found. "
                                   f"Weight config: {weight_config}, "
vllm/model_executor/layers/quantization/quark/schemes/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0

 from .quark_scheme import QuarkScheme
+from .quark_w4a4_mxfp4 import QuarkW4A4MXFP4
 from .quark_w8a8_fp8 import QuarkW8A8Fp8
 from .quark_w8a8_int8 import QuarkW8A8Int8

-__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8"]
+__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkW4A4MXFP4"]
vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Dict, List, Optional
+
+import torch
+import torch.nn.functional as F
+
+import vllm.envs as envs
+from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
+from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
+    OCP_MX_BLOCK_SIZE, per_token_group_quant_mxfp4)
+from vllm.model_executor.parameter import (GroupQuantScaleParameter,
+                                           PackedvLLMParameter)
+from vllm.platforms import current_platform
+
+__all__ = ["QuarkW4A4MXFP4"]
+
+
+class QuarkW4A4MXFP4(QuarkScheme):
+
+    def __init__(self, weight_quant_spec: Dict[str, Any],
+                 input_quant_spec: Dict[str, Any]):
+        self.out_dtype = torch.get_default_dtype()
+        self.qscheme = "per_group"
+        self.weight_quant_spec = weight_quant_spec
+        self.input_quant_spec = input_quant_spec
+        self.emulate = not current_platform.supports_mx()
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 70
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.weight = torch.nn.Parameter(layer.weight.data,
+                                          requires_grad=False)
+        layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
+                                                requires_grad=False)
+
+        if self.emulate:
+            try:
+                from quark.torch.export.nn.modules import realquantizer
+                from quark.torch.quantization.config.config import (
+                    QuantizationSpec)
+            except ImportError as err:
+                raise ImportError(
+                    "The package `amd-quark` is required to use AMD Quark "
+                    "MX-FP4 models. Please install it with `pip install "
+                    "amd-quark`.") from err
+
+            weight_quant_spec = QuantizationSpec.from_dict(
+                self.weight_quant_spec)
+
+            weight_quantizer = realquantizer.get_real_quantizer(
+                qspec=weight_quant_spec,
+                quantizer=None,
+                real_quantized=True,
+                reorder=False,
+                float_dtype=self.out_dtype,
+                scale_shape=layer.weight_scale.shape,
+                zero_point_shape=None,
+            )
+            weight_quantizer.scale.data = layer.weight_scale.data
+
+            if not envs.VLLM_QUARK_EMU_MEM_OPT:
+                layer.weight = torch.nn.Parameter(
+                    weight_quantizer(layer.weight.data).to(self.out_dtype),
+                    requires_grad=False,
+                )
+            else:
+                self.weight_quantizer = weight_quantizer
+            layer.weight_scale = None
+
+            # This call is necessary to release the scales memory.
+            torch.cuda.empty_cache()
+
+    def create_weights(self, layer: torch.nn.Module,
+                       output_partition_sizes: List[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, weight_loader: Callable,
+                       **kwargs):
+        output_size_per_partition = sum(output_partition_sizes)
+        layer.logical_widths = output_partition_sizes
+
+        # WEIGHT
+        weight = PackedvLLMParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition // 2,
+                dtype=torch.uint8,
+            ),
+            input_dim=1,
+            output_dim=0,
+            packed_dim=1,
+            packed_factor=2,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        # WEIGHT SCALE
+        weight_scale = GroupQuantScaleParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition // OCP_MX_BLOCK_SIZE,
+                dtype=torch.uint8,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight_scale", weight_scale)
+
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+        if self.emulate:
+            if envs.VLLM_QUARK_EMU_MEM_OPT:
+                dq_w = self.weight_quantizer(layer.weight).to(self.out_dtype)
+            else:
+                dq_w = layer.weight
+            qdq_x, _ = per_token_group_quant_mxfp4(x, OCP_MX_BLOCK_SIZE)
+            return F.linear(qdq_x, dq_w, bias)
+        else:
+            raise NotImplementedError()
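
A small sketch of the tensor layout create_weights sets up, using illustrative sizes that are not taken from the diff: two FP4 values are packed per uint8 byte along the input dimension (packed_factor=2), and one uint8 e8m0 scale is stored per 32-element block.

import torch

OCP_MX_BLOCK_SIZE = 32
out_features, in_features = 4096, 4096  # assumed example sizes

# Packed FP4 weight: two 4-bit values per byte along the input dimension.
weight = torch.empty(out_features, in_features // 2, dtype=torch.uint8)

# One e8m0 scale byte per 32-element block of the input dimension.
weight_scale = torch.empty(out_features,
                           in_features // OCP_MX_BLOCK_SIZE,
                           dtype=torch.uint8)

print(weight.shape)        # torch.Size([4096, 2048])
print(weight_scale.shape)  # torch.Size([4096, 128])
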
vllm/model_executor/layers/quantization/utils/mxfp4_utils.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import Tuple
+
+import torch
+
+OCP_MX_BLOCK_SIZE = 32
+
+
+def per_token_group_quant_mxfp4(x: torch.Tensor,
+                                block_k: int,
+                                scale_calculation_mode: str = "even"
+                                ) -> Tuple[torch.Tensor, torch.Tensor]:
+    try:
+        from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
+            fake_quantize_fp4_fp6_per_group_with_scale)
+        from quark.torch.quantization.utils import (even_round,
+                                                    reshape_to_blocks)
+    except ImportError as err:
+        raise ImportError("The package `amd-quark` is required to use "
+                          "MX-FP4 models. Please install it with `pip install "
+                          "amd-quark`.") from err
+
+    axis = -1
+    block_x = reshape_to_blocks(x, block_k, axis)
+    amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True)
+    amax = amax.squeeze(-1)
+
+    # TODO: there are other rounding strategies supported in quark and in the
+    # config.json that we do not check for here!
+    if scale_calculation_mode != "even":
+        raise NotImplementedError(
+            f"Scale calculation mode {scale_calculation_mode} is not yet "
+            "supported in MX-FP4 quantization")
+    scale = even_round(amax, "fp4")
+
+    # Apply dequantize(quantize(x)).
+    x = fake_quantize_fp4_fp6_per_group_with_scale(
+        x,
+        scale.to(x.device),
+        axis=axis,
+        group_size=block_k,
+        quant_dtype="fp4",
+    )
+
+    return x, scale
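
For readers without amd-quark installed, the following is a self-contained PyTorch approximation of the quantize-dequantize pass above. The scale rule is assumed here to follow the OCP MX convention (shared exponent = floor(log2(amax)) minus the element format's maximum exponent, which is 2 for FP4/E2M1); quark's even_round and fake_quantize_fp4_fp6_per_group_with_scale may round differently at the margins, so treat this purely as an illustration.

import torch

OCP_MX_BLOCK_SIZE = 32
# Non-negative values representable in OCP FP4 (E2M1).
FP4_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def fake_quant_mxfp4_reference(x: torch.Tensor,
                               block_k: int = OCP_MX_BLOCK_SIZE):
    # Group the last dimension into blocks of block_k elements.
    blocks = x.reshape(*x.shape[:-1], -1, block_k)
    amax = blocks.abs().amax(dim=-1, keepdim=True)
    # Assumed e8m0 scale: power of two derived from the per-block amax.
    eps = torch.finfo(x.dtype).tiny
    shared_exp = torch.floor(torch.log2(amax.clamp(min=eps))) - 2
    scale = torch.exp2(shared_exp)
    # Quantize each element to the nearest FP4 grid point, then dequantize.
    grid = FP4_GRID.to(device=x.device, dtype=x.dtype)
    scaled = blocks / scale
    idx = (scaled.abs().unsqueeze(-1) - grid).abs().argmin(dim=-1)
    qdq = torch.sign(scaled) * grid[idx] * scale
    return qdq.reshape_as(x), scale.squeeze(-1)


# Example: fake-quantize a (4, 64) activation tensor in blocks of 32.
x = torch.randn(4, 64)
qdq_x, scales = fake_quant_mxfp4_reference(x)
print(qdq_x.shape, scales.shape)  # torch.Size([4, 64]) torch.Size([4, 2])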

vllm/model_executor/model_loader/utils.py

Lines changed: 1 addition & 1 deletion
@@ -220,7 +220,7 @@ def get_model_architecture(
     # Special handling for quantized Mixtral.
     # FIXME(woosuk): This is a temporary hack.
     mixtral_supported = [
-        "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"
+        "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin", "quark"
     ]

     if (model_config.quantization is not None

vllm/platforms/interface.py

Lines changed: 7 additions & 0 deletions
@@ -339,6 +339,13 @@ def get_device_communicator_cls(cls) -> str:
         """
         return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa

+    @classmethod
+    def supports_mx(cls) -> bool:
+        """
+        Returns whether the current platform supports MX types.
+        """
+        return False
+
     @classmethod
     def supports_fp8(cls) -> bool:
         """

vllm/platforms/rocm.py

Lines changed: 5 additions & 0 deletions
@@ -327,6 +327,11 @@ def get_current_memory_usage(cls,
     def get_device_communicator_cls(cls) -> str:
         return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa

+    @classmethod
+    def supports_mx(cls) -> bool:
+        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
+        return any(gfx in gcn_arch for gfx in ["gfx95"])
+
     @classmethod
     def supports_fp8(cls) -> bool:
         gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
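
A brief sketch of how this capability flag feeds back into the new scheme: QuarkW4A4MXFP4.__init__ (shown earlier) switches to emulation whenever the platform does not report MX support, so on devices other than gfx95-class ROCm GPUs the quark-based fake-quantization path is taken and the native path remains unimplemented.

from vllm.platforms import current_platform

# Mirrors QuarkW4A4MXFP4.__init__: native MX path only where supported,
# quark-based emulation everywhere else.
emulate = not current_platform.supports_mx()
print("MXFP4 path:", "emulated" if emulate else "native (not yet implemented)")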
