
Commit fadffba

BowenBao and fxmarty-amd committed
Move all kernels into Quark (#3)
Co-authored-by: Felix Marty <[email protected]>
Signed-off-by: Felix Marty <[email protected]>
1 parent 09fafb6 commit fadffba

File tree

6 files changed: +35 −1027 lines

vllm/envs.py

Lines changed: 0 additions & 5 deletions
@@ -593,11 +593,6 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     "VLLM_QUARK_EMU_MEM_OPT":
     lambda: bool(int(os.getenv("VLLM_QUARK_EMU_MEM_OPT", "0"))),

-    # Selects the Q/DQ/QDQ implementation to use with mxfp4.
-    # Available: "hip", "torch", "triton".
-    "VLLM_QUARK_MXFP4_Q_DQ_QDQ_IMPLEM":
-    lambda: os.getenv("VLLM_QUARK_MXFP4_Q_DQ_QDQ_IMPLEM", "hip"),
-
     # Divisor for dynamic query scale factor calculation for FP8 KV Cache
     "Q_SCALE_CONSTANT":
     lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")),

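For reference, a minimal sketch of the lazy environment-variable pattern used in vllm/envs.py, restricted to the two entries visible in the hunk context above: each key maps to a zero-argument lambda that reads and converts the value only when accessed. The dict name follows vllm's convention but should be treated as illustrative here.

import os

environment_variables = {
    # Memory-optimized MX-FP4 emulation toggle (kept by this commit).
    "VLLM_QUARK_EMU_MEM_OPT":
    lambda: bool(int(os.getenv("VLLM_QUARK_EMU_MEM_OPT", "0"))),

    # Divisor for dynamic query scale factor calculation for FP8 KV Cache.
    "Q_SCALE_CONSTANT":
    lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")),
}

# Values are resolved on access, so environment changes made before the
# lookup are still picked up.
assert environment_variables["VLLM_QUARK_EMU_MEM_OPT"]() is False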
vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 8 additions & 7 deletions
@@ -20,7 +20,10 @@
     per_token_group_quant_fp8)
 from vllm.model_executor.layers.quantization.utils.int8_utils import (
     per_token_group_quant_int8, per_token_quant_int8)
-from vllm.model_executor.layers.quantization.utils.mxfp4_utils import OCP_MX_BLOCK_SIZE, per_token_group_quant_mxfp4, per_token_group_dequant_mxfp4
+from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
+    quant_dequant_mxfp4,
+    dequant_mxfp4,
+)
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
@@ -1232,7 +1235,7 @@ def moe_kernel_prepare_input(
     elif use_mxfp4_w4a4:
         assert block_shape is None
         if not current_platform.supports_mx():
-            A = per_token_group_quant_mxfp4(A, OCP_MX_BLOCK_SIZE)
+            A = quant_dequant_mxfp4(A)
         else:
             raise NotImplementedError()
     else:
@@ -1345,13 +1348,11 @@ def fused_experts_impl(hidden_states: torch.Tensor,
     if use_mxfp4_w4a4 and not current_platform.supports_mx(
     ) and envs.VLLM_QUARK_EMU_MEM_OPT:
         # Weight has to be dequantized for mxfp4 emulation.
-        w1 = per_token_group_dequant_mxfp4(w1, w1_scale, OCP_MX_BLOCK_SIZE,
-                                           hidden_states.dtype)
+        w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
         w1_scale = None
-        w2 = per_token_group_dequant_mxfp4(w2, w2_scale, OCP_MX_BLOCK_SIZE,
-                                           hidden_states.dtype)
+        w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
         w2_scale = None
-
+
     for chunk in range((num_tokens // CHUNK_SIZE) + 1):
         begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
                                           min((chunk + 1) * CHUNK_SIZE,

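A minimal usage sketch (not part of the commit) of the emulation flow these hunks wire up: on hardware without native MX support, packed weights are dequantized once with dequant_mxfp4 and activations are fake-quantized with quant_dequant_mxfp4 before a high-precision matmul. The import path and call signatures mirror the diff; the wrapper function itself is illustrative.

import torch

from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
    dequant_mxfp4, quant_dequant_mxfp4)


def emulated_mxfp4_linear(a: torch.Tensor, w_packed: torch.Tensor,
                          w_scale: torch.Tensor) -> torch.Tensor:
    w = dequant_mxfp4(w_packed, w_scale, a.dtype)  # unpack weight to a.dtype
    a = quant_dequant_mxfp4(a)  # simulate mxfp4 rounding on the activation
    return a @ w.t()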
vllm/model_executor/layers/quantization/quark/quark.py

Lines changed: 1 addition & 2 deletions
@@ -177,7 +177,7 @@ def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]],
         is_static_weight = not weight_quant.get("is_dynamic")
         is_per_tensor_or_channel_weight = (weight_quant.get("qscheme")
                                            in ["per_tensor", "per_channel"])
-
+
         if not (is_fp8_dtype and is_static_weight
                 and is_per_tensor_or_channel_weight):
             return False
@@ -325,7 +325,6 @@ def _get_scheme_from_config(self, config: Dict[str, Any]) -> "QuarkScheme":
                 is_static_input_scheme=True,
                 input_symmetric=input_config.get("symmetric"))
         elif self._is_mx_fp4(weight_config, input_config):
-            logger.info(f"Using VLLM_QUARK_MXFP4_Q_DQ_QDQ_IMPLEM='{envs.VLLM_QUARK_MXFP4_Q_DQ_QDQ_IMPLEM}'.")
             return QuarkW4A4MXFP4(weight_config, input_config)
 
         raise NotImplementedError("No quark compatible scheme was found. "

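For context, a hedged sketch of the dispatch this file implements: _get_scheme_from_config inspects the per-layer weight/input quantization configs and returns QuarkW4A4MXFP4 when they describe MX-FP4 (the commit itself only drops the log line for the removed env var). The predicate below is an illustrative stand-in for the real _is_mx_fp4 check, and the import path is assumed from the repository layout.

from typing import Any, Dict

from vllm.model_executor.layers.quantization.quark.schemes.quark_w4a4_mxfp4 import (
    QuarkW4A4MXFP4)


def get_mxfp4_scheme(weight_config: Dict[str, Any],
                     input_config: Dict[str, Any]) -> QuarkW4A4MXFP4:
    # Stand-in for self._is_mx_fp4(weight_config, input_config): treat an fp4
    # weight dtype with per-group scales as MX-FP4. The real check lives in
    # quark.py and may differ.
    is_mx_fp4 = (weight_config.get("dtype") == "fp4"
                 and weight_config.get("qscheme") == "per_group")
    if is_mx_fp4:
        return QuarkW4A4MXFP4(weight_config, input_config)
    raise NotImplementedError("No quark compatible scheme was found.")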
vllm/model_executor/layers/quantization/quark/quark_moe.py

Lines changed: 4 additions & 42 deletions
@@ -10,13 +10,11 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
-from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
-    OCP_MX_BLOCK_SIZE)
+from vllm.model_executor.layers.quantization.utils.mxfp4_utils import OCP_MX_BLOCK_SIZE, dequant_mxfp4
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.model_executor.layers.quantization.utils.mxfp4_utils import SUPPORTED_IMPLEMS
 
 logger = init_logger(__name__)
 
@@ -263,9 +261,6 @@ def __init__(self, weight_config: Dict[str, Any], input_config: Dict[str,
         self.static_input_scales = not self.input_quant.get("is_dynamic")
         self.emulate = not current_platform.supports_mx()
 
-        if envs.VLLM_QUARK_MXFP4_Q_DQ_QDQ_IMPLEM not in SUPPORTED_IMPLEMS:
-            raise ValueError(f"VLLM_QUARK_MXFP4_Q_DQ_QDQ_IMPLEM='{envs.VLLM_QUARK_MXFP4_Q_DQ_QDQ_IMPLEM}' is not supported, only {SUPPORTED_IMPLEMS} are.")
-
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
@@ -327,54 +322,21 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         float_dtype = torch.get_default_dtype()
 
         if self.emulate and not envs.VLLM_QUARK_EMU_MEM_OPT:
-            try:
-                from quark.torch.export.nn.modules import realquantizer
-                from quark.torch.quantization.config.config import (
-                    QuantizationSpec)
-            except ImportError as err:
-                raise ImportError(
-                    "The package `amd-quark` is required to use AMD Quark "
-                    "MX-FP4 models. Please install it with `pip install "
-                    "amd-quark`.") from err
-
-            weight_quant_spec = QuantizationSpec.from_dict(self.weight_quant)
-
             # Unpack and dequantize the weights (the operators are in high-precision, with simulated quantization).
-            w13_quantizer = realquantizer.get_real_quantizer(
-                qspec=weight_quant_spec,
-                quantizer=None,
-                real_quantized=True,
-                reorder=False,  # TODO: load from config
-                float_dtype=float_dtype,
-                scale_shape=layer.w13_weight_scale.shape,
-                zero_point_shape=None,
-            )
-            w13_quantizer.scale.data = layer.w13_weight_scale.data
-
             layer.w13_weight = torch.nn.Parameter(
-                w13_quantizer(layer.w13_weight.data).to(float_dtype),
+                dequant_mxfp4(layer.w13_weight.data, layer.w13_weight_scale.data, float_dtype),
                 requires_grad=False,
             )
             layer.w13_weight_scale = None
 
-            w2_quantizer = realquantizer.get_real_quantizer(
-                qspec=weight_quant_spec,
-                quantizer=None,
-                real_quantized=True,
-                reorder=False,  # TODO: load from config
-                float_dtype=float_dtype,
-                scale_shape=layer.w2_weight_scale.shape,
-                zero_point_shape=None,
-            )
-            w2_quantizer.scale.data = layer.w2_weight_scale.data
-
             layer.w2_weight = torch.nn.Parameter(
-                w2_quantizer(layer.w2_weight.data).to(float_dtype),
+                dequant_mxfp4(layer.w2_weight.data, layer.w2_weight_scale.data, float_dtype),
                 requires_grad=False,
            )
             layer.w2_weight_scale = None
 
             # This call is necessary to release the scales memory.
+            # TODO: is it still?
             torch.cuda.empty_cache()
 
     def apply(

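A minimal sketch of the simplified weight post-processing this hunk lands in quark_moe.py: when MX is only emulated and memory-optimized emulation is off, each packed expert weight is dequantized once with dequant_mxfp4 and its scale released. The helper below is illustrative; it assumes `layer` carries the same attribute names as in the diff (w13_weight / w13_weight_scale, w2_weight / w2_weight_scale).

import torch

from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4


def dequantize_expert_weights(layer: torch.nn.Module,
                              float_dtype: torch.dtype) -> None:
    for name in ("w13_weight", "w2_weight"):
        packed = getattr(layer, name)
        scale = getattr(layer, name + "_scale")
        setattr(layer, name,
                torch.nn.Parameter(dequant_mxfp4(packed.data, scale.data,
                                                 float_dtype),
                                   requires_grad=False))
        setattr(layer, name + "_scale", None)  # scale is folded into the weight
    # Release the scale storage (see the TODO retained in the diff).
    torch.cuda.empty_cache()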
vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py

Lines changed: 16 additions & 39 deletions
@@ -7,7 +7,11 @@
 
 import vllm.envs as envs
 from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
-from vllm.model_executor.layers.quantization.utils.mxfp4_utils import OCP_MX_BLOCK_SIZE, per_token_group_quant_mxfp4, per_token_group_dequant_mxfp4, SUPPORTED_IMPLEMS
+from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
+    OCP_MX_BLOCK_SIZE,
+    quant_dequant_mxfp4,
+    dequant_mxfp4,
+)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            PackedvLLMParameter)
 from vllm.platforms import current_platform
@@ -25,9 +29,6 @@ def __init__(self, weight_quant_spec: Dict[str, Any],
         self.input_quant_spec = input_quant_spec
         self.emulate = not current_platform.supports_mx()
 
-        if envs.VLLM_QUARK_MXFP4_Q_DQ_QDQ_IMPLEM not in SUPPORTED_IMPLEMS:
-            raise ValueError(f"VLLM_QUARK_MXFP4_Q_DQ_QDQ_IMPLEM='{envs.VLLM_QUARK_MXFP4_Q_DQ_QDQ_IMPLEM}' is not supported, only {SUPPORTED_IMPLEMS} are.")
-
     @classmethod
     def get_min_capability(cls) -> int:
         return 70
@@ -38,40 +39,16 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                 requires_grad=False)
 
-        if self.emulate:
-            try:
-                from quark.torch.export.nn.modules import realquantizer
-                from quark.torch.quantization.config.config import (
-                    QuantizationSpec)
-            except ImportError as err:
-                raise ImportError(
-                    "The package `amd-quark` is required to use AMD Quark "
-                    "MX-FP4 models. Please install it with `pip install "
-                    "amd-quark`.") from err
-
-            weight_quant_spec = QuantizationSpec.from_dict(
-                self.weight_quant_spec)
-
-            weight_quantizer = realquantizer.get_real_quantizer(
-                qspec=weight_quant_spec,
-                quantizer=None,
-                real_quantized=True,
-                reorder=False,
-                float_dtype=self.out_dtype,
-                scale_shape=layer.weight_scale.shape,
-                zero_point_shape=None,
+        if self.emulate and not envs.VLLM_QUARK_EMU_MEM_OPT:
+            layer.weight = torch.nn.Parameter(
+                dequant_mxfp4(layer.weight.data, layer.weight_scale.data, self.out_dtype),
+                requires_grad=False,
             )
-            weight_quantizer.scale.data = layer.weight_scale.data
+            layer.weight_scale = None
 
-            if not envs.VLLM_QUARK_EMU_MEM_OPT:
-                layer.weight = torch.nn.Parameter(
-                    weight_quantizer(layer.weight.data).to(self.out_dtype),
-                    requires_grad=False,
-                )
-                layer.weight_scale = None
-
-            # This call is necessary to release the scales memory.
-            torch.cuda.empty_cache()
+            # This call is necessary to release the scales memory.
+            # TODO: is it still?
+            torch.cuda.empty_cache()
 
     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
@@ -116,11 +93,11 @@ def apply_weights(self,
 
         if self.emulate:
             if envs.VLLM_QUARK_EMU_MEM_OPT:
-                dq_w = per_token_group_dequant_mxfp4(layer.weight, layer.weight_scale, OCP_MX_BLOCK_SIZE, x.dtype)
+                dq_w = dequant_mxfp4(layer.weight, layer.weight_scale, x.dtype)
             else:
                 dq_w = layer.weight
-
-            x = per_token_group_quant_mxfp4(x, OCP_MX_BLOCK_SIZE)
+
+            x = quant_dequant_mxfp4(x)
 
             return F.linear(x, dq_w, bias)
         else:

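To make the emulation path concrete, a self-contained toy illustration of what block-wise MX-FP4 fake quantization ("quant_dequant") does conceptually: values are grouped into blocks of OCP_MX_BLOCK_SIZE (32), each block shares a power-of-two scale, and elements are snapped to an FP4 (E2M1) value. This is an assumption-laden sketch using round-to-nearest, not the hip/torch/triton kernels that now live in Quark.

import torch

OCP_MX_BLOCK_SIZE = 32  # block size from the OCP Microscaling (MX) spec
# Non-negative values representable in FP4 E2M1; sign is handled separately.
FP4_E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def toy_quant_dequant_mxfp4(x: torch.Tensor,
                            block_size: int = OCP_MX_BLOCK_SIZE) -> torch.Tensor:
    # Toy only: assumes x.numel() is a multiple of block_size.
    orig_shape, orig_dtype = x.shape, x.dtype
    x = x.reshape(-1, block_size).float()
    # One power-of-two scale per block so the block max lands near the FP4
    # maximum of 6.0 (= 1.5 * 2**2): scale = 2 ** (floor(log2(amax)) - 2).
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-30)
    scale = torch.exp2(torch.floor(torch.log2(amax)) - 2)
    scaled = x / scale
    # Snap each magnitude to the nearest representable E2M1 value.
    idx = (scaled.abs().unsqueeze(-1) - FP4_E2M1_VALUES).abs().argmin(dim=-1)
    dq = FP4_E2M1_VALUES[idx] * scaled.sign() * scale
    return dq.reshape(orig_shape).to(orig_dtype)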