Commit bbf575e (1 parent 793b065)

    use replace_parameters; clean-up

1 file changed: vllm/model_executor/layers/quantization/awq_marlin.py
Lines changed: 8 additions & 11 deletions
@@ -11,14 +11,11 @@
     set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
-from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
-    marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
-    marlin_permute_scales, moe_awq_to_marlin_zero_points)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
-    marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
+    marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
+    marlin_permute_scales, moe_awq_to_marlin_zero_points,
     verify_marlin_supported, verify_marlin_supports_shape)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
@@ -379,7 +376,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
             num_bits=self.quant_config.weight_bits,
         )
-        replace_tensor(layer, "w13_qweight", marlin_w13_qweight)
+        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)

         marlin_w2_qweight = ops.awq_marlin_moe_repack(
             layer.w2_qweight,
@@ -388,7 +385,7 @@
             size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
             num_bits=self.quant_config.weight_bits,
         )
-        replace_tensor(layer, "w2_qweight", marlin_w2_qweight)
+        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)

         # Why does this take the intermediate size for size_k?
         marlin_w13_scales = marlin_moe_permute_scales(
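
An aside on the size_n arithmetic in the repack calls above (not part of the diff): AWQ packs several low-bit weights into each 32-bit word, so the packed tensor's last dimension is multiplied by pack_factor to recover the logical output width. A minimal sketch of that relationship, assuming 4-bit weights so that pack_factor = 32 // weight_bits; the shapes are illustrative and not taken from this commit:

import torch

weight_bits = 4
pack_factor = 32 // weight_bits  # 8 int4 values per int32 word

# Illustrative packed MoE weight: [num_experts, size_k, size_n // pack_factor]
num_experts, size_k, logical_n = 8, 4096, 14336
w13_qweight = torch.zeros(num_experts, size_k, logical_n // pack_factor,
                          dtype=torch.int32)

# Mirrors the size_n computation passed to ops.awq_marlin_moe_repack above.
size_n = w13_qweight.shape[2] * pack_factor
assert size_n == logical_n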
@@ -398,29 +395,29 @@
             group_size=self.quant_config.group_size,
         )

-        replace_tensor(layer, "w13_scales", marlin_w13_scales)
+        replace_parameter(layer, "w13_scales", marlin_w13_scales)

         marlin_w2_scales = marlin_moe_permute_scales(
             s=layer.w2_scales,
             size_k=layer.intermediate_size_per_partition,
             size_n=layer.w2_scales.shape[2],
             group_size=self.quant_config.group_size,
         )
-        replace_tensor(layer, "w2_scales", marlin_w2_scales)
+        replace_parameter(layer, "w2_scales", marlin_w2_scales)

         marlin_w13_zp = moe_awq_to_marlin_zero_points(
             layer.w13_qzeros,
             size_k=layer.w13_qzeros.shape[1],
             size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
             num_bits=self.quant_config.weight_bits)
-        replace_tensor(layer, "w13_qzeros", marlin_w13_zp)
+        replace_parameter(layer, "w13_qzeros", marlin_w13_zp)

         marlin_w2_zp = moe_awq_to_marlin_zero_points(
             layer.w2_qzeros,
             size_k=layer.w2_qzeros.shape[1],
             size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
             num_bits=self.quant_config.weight_bits)
-        replace_tensor(layer, "w2_qzeros", marlin_w2_zp)
+        replace_parameter(layer, "w2_qzeros", marlin_w2_zp)

     def apply(
         self,
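
A note on the helper this commit switches to (its body is not shown in the diff): the repacked Marlin tensors generally have different shapes from the parameters registered at load time, so they need to be re-registered on the module rather than copied into the old storage. A minimal sketch of what a replace_parameter-style helper plausibly does; this is an assumption for illustration, not the body of vLLM's actual utility:

import torch


def replace_parameter_sketch(module: torch.nn.Module, name: str,
                             new: torch.Tensor) -> None:
    # Hypothetical stand-in for replace_parameter; vLLM's real helper may
    # differ (e.g. it could reuse the existing storage when shapes match).
    # Re-register the repacked tensor under the same attribute name, wrapped
    # as a non-trainable Parameter so it still appears in state_dict().
    module.register_parameter(
        name, torch.nn.Parameter(new.detach(), requires_grad=False))

The practical effect is that layer.w13_qweight and the other repacked tensors remain registered parameters under their original names after process_weights_after_loading runs.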
