Commit 489501f

Mxfp4 memory leak fixes (#2)
1 parent 6ed434c commit 489501f

File tree

3 files changed: +18 -18 lines changed


vllm/model_executor/layers/quantization/quark/quark_moe.py

Lines changed: 5 additions & 0 deletions
@@ -350,6 +350,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             w13_quantizer(layer.w13_weight.data).to(float_dtype),
             requires_grad=False,
         )
+        layer.w13_weight_scale = None

         w2_quantizer = realquantizer.get_real_quantizer(
             qspec=weight_quant_spec,
@@ -366,6 +367,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             w2_quantizer(layer.w2_weight.data).to(float_dtype),
             requires_grad=False,
         )
+        layer.w2_weight_scale = None
+
+        # This call is necessary to release the scales memory.
+        torch.cuda.empty_cache()

     def apply(
         self,
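
Both hunks above follow the same pattern: once the real quantizer has folded the scales into the dequantized weight, the scale Parameters keep holding GPU memory only because the layer still references them, and PyTorch's caching allocator hands freed blocks back to the device only after torch.cuda.empty_cache(). Below is a minimal sketch of that pattern, using a hypothetical layer layout and a stand-in dequantization callable rather than the vLLM MoE module.

# Minimal sketch of the release pattern used above; `layer` is a hypothetical
# module holding a quantized weight plus a separate scale Parameter, and
# `dequant_fn` stands in for the quark real quantizer.
import torch


def dequantize_and_release(layer: torch.nn.Module, dequant_fn,
                           float_dtype: torch.dtype) -> None:
    # Materialize the dequantized weight; the scales are consumed here.
    layer.weight = torch.nn.Parameter(
        dequant_fn(layer.weight.data).to(float_dtype),
        requires_grad=False,
    )
    # Drop the last reference so the caching allocator can free the scales ...
    layer.weight_scale = None
    # ... and hand the cached blocks back to the GPU.
    torch.cuda.empty_cache()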

vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py

Lines changed: 10 additions & 16 deletions
@@ -7,12 +7,12 @@

 import vllm.envs as envs
 from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
-from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
-    OCP_MX_BLOCK_SIZE)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            PackedvLLMParameter)
 from vllm.platforms import current_platform

+from vllm.model_executor.layers.quantization.utils.mxfp4_utils import OCP_MX_BLOCK_SIZE, per_token_group_quant_mxfp4
+
 __all__ = ["QuarkW4A4MXFP4"]


@@ -37,16 +37,6 @@ def __init__(self, weight_quant_spec: Dict[str, Any],
                 "MX-FP4 models. Please install it with `pip install "
                 "amd-quark`.") from err

-        input_quant_spec = QuantizationSpec.from_dict(
-            self.input_quant_spec)
-
-        self.input_quantizer = realquantizer.get_real_quantizer(
-            qspec=input_quant_spec,
-            quantizer=None,
-            real_quantized=False,
-            float_dtype=self.out_dtype,
-        )
-
     @classmethod
     def get_min_capability(cls) -> int:
         # lovelace and up
@@ -72,7 +62,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         weight_quant_spec = QuantizationSpec.from_dict(
             self.weight_quant_spec)

-        self.weight_quantizer = realquantizer.get_real_quantizer(
+        weight_quantizer = realquantizer.get_real_quantizer(
             qspec=weight_quant_spec,
             quantizer=None,
             real_quantized=True,
@@ -81,14 +71,18 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             scale_shape=layer.weight_scale.shape,
             zero_point_shape=None,
         )
-        self.weight_quantizer.scale.data = layer.weight_scale.data
+        weight_quantizer.scale.data = layer.weight_scale.data

         if not envs.VLLM_QUARK_EMU_MEM_OPT:
             layer.weight = torch.nn.Parameter(
-                self.weight_quantizer(layer.weight.data).to(
+                weight_quantizer(layer.weight.data).to(
                     self.out_dtype),
                 requires_grad=False,
             )
+        layer.weight_scale = None
+
+        # This call is necessary to release the scales memory.
+        torch.cuda.empty_cache()

     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
@@ -136,7 +130,7 @@ def apply_weights(self,
                 dq_w = self.weight_quantizer(layer.weight).to(self.out_dtype)
             else:
                 dq_w = layer.weight
-            qdq_x = self.input_quantizer(x)
+            qdq_x, _ = per_token_group_quant_mxfp4(x, 32)
             return F.linear(qdq_x, dq_w, bias)
         else:
             raise NotImplementedError()
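
With the per-instance input_quantizer gone, activation quant-dequant in apply_weights goes through the stateless helper instead, so no quantizer object (or its buffers) has to stay alive between forward calls. A usage sketch follows, assuming only the signature visible in this commit: per_token_group_quant_mxfp4(x, block_k) returns the fake-quantized activations and their per-group scales; the linear wrapper itself is hypothetical.

# Usage sketch; `dq_w` is assumed to already be a dequantized weight tensor.
from typing import Optional

import torch
import torch.nn.functional as F

from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
    per_token_group_quant_mxfp4)


def emulated_mxfp4_linear(x: torch.Tensor, dq_w: torch.Tensor,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Quant-dequant the activations in groups of 32 elements (the OCP MX block
    # size); the returned scales are not needed for the emulated matmul.
    qdq_x, _ = per_token_group_quant_mxfp4(x, 32)
    return F.linear(qdq_x, dq_w, bias)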

vllm/model_executor/layers/quantization/utils/mxfp4_utils.py

Lines changed: 3 additions & 2 deletions
@@ -20,7 +20,8 @@ def per_token_group_quant_mxfp4(x: torch.Tensor, block_k: int):
     # TODO: there are other rounding strategies supported in quark and in the config.json that we do not check for here!
     scale = even_round(amax, "fp4")

-    x_qdq = scaled_fake_quantize(
+    # Apply dequantize(quantize(x)).
+    x = scaled_fake_quantize(
         "fp4",
         x,
         scale.to(x.device),
@@ -34,4 +35,4 @@ def per_token_group_quant_mxfp4(x: torch.Tensor, block_k: int):
         'None', # must be a string in quark hw_emulation_interface.py, why?
     )

-    return x_qdq, scale
+    return x, scale
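
The helper's job, in other words, is a fake-quantization round trip: per 32-element group, derive a shared power-of-two scale from the block maximum, snap the scaled values to the FP4 (E2M1) grid, and scale back. The sketch below shows that round trip in plain PyTorch; it is not the quark implementation, and it uses nearest-value rounding where quark supports other modes (see the TODO above).

# Self-contained sketch of an MXFP4 quant-dequant round trip (not the quark
# implementation); assumes x.numel() is a multiple of block_k and groups
# contiguous runs along the last dimension.
import torch

# Magnitudes representable in FP4 E2M1.
FP4_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def mxfp4_fake_quant(x: torch.Tensor, block_k: int = 32) -> torch.Tensor:
    blocks = x.reshape(-1, block_k)
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12)
    # Shared power-of-two scale per block; E2M1's largest magnitude is 6 = 1.5 * 2**2.
    scale = torch.exp2(torch.floor(torch.log2(amax)) - 2)
    scaled = blocks / scale
    # Snap each magnitude to the nearest representable E2M1 value, keep the sign.
    grid = FP4_VALUES.to(device=x.device, dtype=x.dtype)
    idx = (scaled.abs().unsqueeze(-1) - grid).abs().argmin(dim=-1)
    qdq = torch.sign(scaled) * grid[idx] * scale
    return qdq.reshape(x.shape)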
