@@ -11,14 +11,11 @@
     set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
-from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
-    marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
-    marlin_permute_scales, moe_awq_to_marlin_zero_points)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
-    marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
+    marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
+    marlin_permute_scales, moe_awq_to_marlin_zero_points,
     verify_marlin_supported, verify_marlin_supports_shape)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
@@ -379,7 +376,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
             num_bits=self.quant_config.weight_bits,
         )
-        replace_tensor(layer, "w13_qweight", marlin_w13_qweight)
+        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)

         marlin_w2_qweight = ops.awq_marlin_moe_repack(
             layer.w2_qweight,
@@ -388,7 +385,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
             num_bits=self.quant_config.weight_bits,
         )
-        replace_tensor(layer, "w2_qweight", marlin_w2_qweight)
+        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)

         # Why does this take the intermediate size for size_k?
         marlin_w13_scales = marlin_moe_permute_scales(
@@ -398,29 +395,29 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             group_size=self.quant_config.group_size,
         )

-        replace_tensor(layer, "w13_scales", marlin_w13_scales)
+        replace_parameter(layer, "w13_scales", marlin_w13_scales)

         marlin_w2_scales = marlin_moe_permute_scales(
             s=layer.w2_scales,
             size_k=layer.intermediate_size_per_partition,
             size_n=layer.w2_scales.shape[2],
             group_size=self.quant_config.group_size,
         )
-        replace_tensor(layer, "w2_scales", marlin_w2_scales)
+        replace_parameter(layer, "w2_scales", marlin_w2_scales)

         marlin_w13_zp = moe_awq_to_marlin_zero_points(
             layer.w13_qzeros,
             size_k=layer.w13_qzeros.shape[1],
             size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
             num_bits=self.quant_config.weight_bits)
-        replace_tensor(layer, "w13_qzeros", marlin_w13_zp)
+        replace_parameter(layer, "w13_qzeros", marlin_w13_zp)

         marlin_w2_zp = moe_awq_to_marlin_zero_points(
             layer.w2_qzeros,
             size_k=layer.w2_qzeros.shape[1],
             size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
             num_bits=self.quant_config.weight_bits)
-        replace_tensor(layer, "w2_qzeros", marlin_w2_zp)
+        replace_parameter(layer, "w2_qzeros", marlin_w2_zp)

     def apply(
         self,
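Every change in process_weights_after_loading above is the same one-line swap: the repacked Marlin tensors are now installed with replace_parameter (imported from vllm.model_executor.layers.quantization.utils) instead of the old replace_tensor helper. As a rough illustration of what such a helper has to do, here is a minimal sketch, assuming it simply re-registers the repacked tensor as a fresh parameter under the same name; this is a hypothetical stand-in for illustration, not vLLM's actual implementation:

import torch


def replace_parameter(mod: torch.nn.Module, name: str,
                      new: torch.Tensor) -> None:
    # Hypothetical sketch: the repacked tensor may have a different shape,
    # stride, or dtype than the parameter it replaces, so rather than copying
    # into the old storage in place, re-register it under the same name. This
    # keeps the module's parameter registry (state_dict, named_parameters)
    # consistent with the attribute that the kernels will actually read.
    mod.register_parameter(name,
                           torch.nn.Parameter(new, requires_grad=False))

In the hunks above, each output of ops.awq_marlin_moe_repack, marlin_moe_permute_scales, and moe_awq_to_marlin_zero_points replaces the originally loaded parameter this way.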