
Commit 6ed434c

Add env var to control whether weights are dequantized at load time
Signed-off-by: Bowen Bao <[email protected]>
1 parent 697905e commit 6ed434c

File tree: 2 files changed, 25 additions & 12 deletions

vllm/envs.py

Lines changed: 9 additions & 0 deletions
@@ -82,6 +82,7 @@
     VLLM_ROCM_FP8_PADDING: bool = True
     VLLM_ROCM_MOE_PADDING: bool = True
     VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
+    VLLM_QUARK_EMU_MEM_OPT: bool = False
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
     VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
     VLLM_DISABLE_COMPILE_CACHE: bool = False
@@ -571,6 +572,14 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in
              ("true", "1")),
 
+    # If set, when running in Quark emulation mode, do not dequantize the
+    # weights at load time. Instead, dequantize weights on-the-fly during
+    # kernel execution.
+    # This allows running larger models at the cost of slower inference.
+    # This flag has no effect when not running in Quark emulation mode.
+    "VLLM_QUARK_EMU_MEM_OPT":
+    lambda: bool(int(os.getenv("VLLM_QUARK_EMU_MEM_OPT", "0"))),
+
     # Divisor for dynamic query scale factor calculation for FP8 KV Cache
     "Q_SCALE_CONSTANT":
     lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")),

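For reference, a minimal sketch (not part of the diff) of how the new flag behaves; the helper name quark_emu_mem_opt_enabled is hypothetical, but the parsing mirrors the lambda registered above, so only integer strings such as "0" or "1" are valid values.

import os

def quark_emu_mem_opt_enabled() -> bool:
    # Mirrors the envs.py lambda: "0" (the default) disables the optimization,
    # any non-zero integer string enables it; non-integer values raise.
    return bool(int(os.getenv("VLLM_QUARK_EMU_MEM_OPT", "0")))

os.environ["VLLM_QUARK_EMU_MEM_OPT"] = "1"  # e.g. export VLLM_QUARK_EMU_MEM_OPT=1
assert quark_emu_mem_opt_enabled()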
vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py

Lines changed: 16 additions & 12 deletions
@@ -5,18 +5,17 @@
 import torch
 import torch.nn.functional as F
 
+import vllm.envs as envs
 from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
+from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
+    OCP_MX_BLOCK_SIZE)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            PackedvLLMParameter)
 from vllm.platforms import current_platform
 
-from vllm.model_executor.layers.quantization.utils.mxfp4_utils import OCP_MX_BLOCK_SIZE
-
 __all__ = ["QuarkW4A4MXFP4"]
 
 
-
-
 class QuarkW4A4MXFP4(QuarkScheme):
 
     def __init__(self, weight_quant_spec: Dict[str, Any],
@@ -48,7 +47,6 @@ def __init__(self, weight_quant_spec: Dict[str, Any],
             float_dtype=self.out_dtype,
         )
 
-
     @classmethod
     def get_min_capability(cls) -> int:
         # lovelace and up
@@ -74,7 +72,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         weight_quant_spec = QuantizationSpec.from_dict(
             self.weight_quant_spec)
 
-        weight_quantizer = realquantizer.get_real_quantizer(
+        self.weight_quantizer = realquantizer.get_real_quantizer(
             qspec=weight_quant_spec,
             quantizer=None,
             real_quantized=True,
@@ -83,12 +81,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             scale_shape=layer.weight_scale.shape,
             zero_point_shape=None,
         )
-        weight_quantizer.scale.data = layer.weight_scale.data
+        self.weight_quantizer.scale.data = layer.weight_scale.data
 
-        layer.weight = torch.nn.Parameter(
-            weight_quantizer(layer.weight.data).to(self.out_dtype),
-            requires_grad=False,
-        )
+        if not envs.VLLM_QUARK_EMU_MEM_OPT:
+            layer.weight = torch.nn.Parameter(
+                self.weight_quantizer(layer.weight.data).to(
+                    self.out_dtype),
+                requires_grad=False,
+            )
 
     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
@@ -132,7 +132,11 @@ def apply_weights(self,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
         if self.emulate:
+            if envs.VLLM_QUARK_EMU_MEM_OPT:
+                dq_w = self.weight_quantizer(layer.weight).to(self.out_dtype)
+            else:
+                dq_w = layer.weight
             qdq_x = self.input_quantizer(x)
-            return F.linear(qdq_x, layer.weight, bias)
+            return F.linear(qdq_x, dq_w, bias)
         else:
             raise NotImplementedError()

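The trade-off this scheme change introduces can be illustrated in isolation. The sketch below is a rough stand-in, not the commit's code: FakeQuantizer and mem_opt are hypothetical placeholders for the Quark real quantizer and envs.VLLM_QUARK_EMU_MEM_OPT. With the flag off, the weight is dequantized once in process_weights_after_loading and kept resident in out_dtype; with it on, the packed weight stays in memory and is dequantized inside every apply_weights call, saving memory at the cost of extra work per forward pass.

import torch
import torch.nn.functional as F

class FakeQuantizer:
    # Hypothetical stand-in for the Quark real quantizer; the real scheme
    # expands a packed MXFP4 tensor, here we simply cast to float.
    def __call__(self, w: torch.Tensor) -> torch.Tensor:
        return w.float()

mem_opt = True                       # stands in for envs.VLLM_QUARK_EMU_MEM_OPT
quantizer = FakeQuantizer()
stored_weight = torch.randn(8, 16)   # stands in for layer.weight after loading

if not mem_opt:
    # Load-time path: dequantize once and keep the larger tensor resident.
    stored_weight = quantizer(stored_weight)

def forward(x: torch.Tensor) -> torch.Tensor:
    # On-the-fly path: dequantize on every call, trading speed for memory.
    w = quantizer(stored_weight) if mem_opt else stored_weight
    return F.linear(x, w)

out = forward(torch.randn(4, 16))    # -> shape (4, 8)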