-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
+from torch.nn import Parameter
 
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.fused_moe.layer import (
+    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
+from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
-    marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
+    marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
+    marlin_permute_scales, moe_awq_to_marlin_zero_points,
     verify_marlin_supported, verify_marlin_supports_shape)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
@@ -35,12 +40,13 @@ def __init__(self, weight_bits: int, group_size: int, has_zp: bool,
         self.group_size = group_size
         self.has_zp = has_zp
         self.lm_head_quantized = lm_head_quantized
+        self.weight_bits = weight_bits
 
-        if weight_bits not in self.TYPE_MAP:
-            raise ValueError(f"Unsupported num_bits = {weight_bits}. "
+        if self.weight_bits not in self.TYPE_MAP:
+            raise ValueError(f"Unsupported num_bits = {self.weight_bits}. "
                              f"Supported num_bits = {self.TYPE_MAP.keys()}")
 
-        self.quant_type = self.TYPE_MAP[weight_bits]
+        self.quant_type = self.TYPE_MAP[self.weight_bits]
 
         verify_marlin_supported(self.quant_type,
                                 group_size=self.group_size,
@@ -98,10 +104,12 @@ def override_quantization_method(cls, hf_quant_cfg,
         return None
 
     def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["AWQMarlinLinearMethod"]:
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
         if (isinstance(layer, LinearBase) or
                 (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
             return AWQMarlinLinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return AWQMoEMethod(self)
         return None
 
     def get_scaled_act_names(self) -> List[str]:
@@ -271,4 +279,182 @@ def apply(
             quant_type=self.quant_config.quant_type,
             output_size_per_partition=layer.output_size_per_partition,
             input_size_per_partition=layer.input_size_per_partition,
-            bias=bias)
+            bias=bias)
+
+
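+# AWQ support for FusedMoE layers: weights are created in AWQ's packed int32
+# format and repacked into the Marlin MoE layout after loading.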
+class AWQMoEMethod(FusedMoEMethodBase):
+
+    def __init__(self, quant_config: AWQMarlinConfig):
+        self.quant_config = quant_config
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+        extra_weight_attrs.update({
+            "is_transposed":
+            True,
+            "quant_method":
+            FusedMoeWeightScaleSupported.GROUP.value,
+        })
+
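+        # WEIGHTS
+        # AWQ packs several quantized values into each int32 (pack_factor per
+        # int32), so the output dim is divided by pack_factor. The MoE weights
+        # are stored transposed: [num_experts, K, N // pack_factor].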
+        w13_qweight = Parameter(torch.empty(num_experts,
+                                            hidden_size,
+                                            2 * intermediate_size //
+                                            self.quant_config.pack_factor,
+                                            dtype=torch.int32),
+                                requires_grad=False)
+        layer.register_parameter("w13_qweight", w13_qweight)
+        set_weight_attrs(w13_qweight, extra_weight_attrs)
+
+        w2_qweight = Parameter(torch.empty(num_experts,
+                                           intermediate_size,
+                                           hidden_size //
+                                           self.quant_config.pack_factor,
+                                           dtype=torch.int32),
+                               requires_grad=False)
+        layer.register_parameter("w2_qweight", w2_qweight)
+        set_weight_attrs(w2_qweight, extra_weight_attrs)
+
+        num_groups_w13 = hidden_size // self.quant_config.group_size
+        num_groups_w2 = intermediate_size // self.quant_config.group_size
+
+        # WEIGHT_SCALES
+        # Allocate 2 scales for w1 and w3 respectively.
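+        # One scale per quantization group per output channel, in params_dtype.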
+        w13_scales = Parameter(torch.empty(num_experts,
+                                           num_groups_w13,
+                                           intermediate_size * 2,
+                                           dtype=params_dtype),
+                               requires_grad=False)
+        layer.register_parameter("w13_scales", w13_scales)
+        set_weight_attrs(w13_scales, extra_weight_attrs)
+
+        w2_scales = Parameter(torch.empty(num_experts,
+                                          num_groups_w2,
+                                          hidden_size,
+                                          dtype=params_dtype),
+                              requires_grad=False)
+        layer.register_parameter("w2_scales", w2_scales)
+        set_weight_attrs(w2_scales, extra_weight_attrs)
+
+        # WEIGHT_ZERO_POINT
+        # Allocate 2 zero points for w1 and w3 respectively.
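+        # Zero points are packed int32 like the weights: one per group along
+        # the input dim, packed by pack_factor along the output dim.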
+        w13_qzeros = Parameter(torch.empty(num_experts,
+                                           num_groups_w13,
+                                           2 * intermediate_size //
+                                           self.quant_config.pack_factor,
+                                           dtype=torch.int32),
+                               requires_grad=False)
+        layer.register_parameter("w13_qzeros", w13_qzeros)
+        set_weight_attrs(w13_qzeros, extra_weight_attrs)
+
+        w2_qzeros = Parameter(torch.empty(num_experts,
+                                          num_groups_w2,
+                                          hidden_size //
+                                          self.quant_config.pack_factor,
+                                          dtype=torch.int32),
+                              requires_grad=False)
+        layer.register_parameter("w2_qzeros", w2_qzeros)
+        set_weight_attrs(w2_qzeros, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        num_experts = layer.w13_qweight.shape[0]
+        device = layer.w13_qweight.device
+
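+        # AWQ has no activation reordering, so the g_idx sort indices expected
+        # by the Marlin MoE repack op are empty placeholder tensors.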
+        layer.w13_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+            requires_grad=False,
+        )
+        layer.w2_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+            requires_grad=False,
+        )
+
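+        # Repack the AWQ-packed weights into the Marlin MoE tile layout.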
+        marlin_w13_qweight = ops.awq_marlin_moe_repack(
+            layer.w13_qweight,
+            layer.w13_g_idx_sort_indices,
+            size_k=layer.w13_qweight.shape[1],
+            size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
+            num_bits=self.quant_config.weight_bits,
+        )
+        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
+
+        marlin_w2_qweight = ops.awq_marlin_moe_repack(
+            layer.w2_qweight,
+            layer.w2_g_idx_sort_indices,
+            size_k=layer.w2_qweight.shape[1],
+            size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
+            num_bits=self.quant_config.weight_bits,
+        )
+        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
+
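+        # Permute the group scales into the ordering the Marlin kernels expect.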
+        # Why does this take the intermediate size for size_k?
+        marlin_w13_scales = marlin_moe_permute_scales(
+            s=layer.w13_scales,
+            size_k=layer.intermediate_size_per_partition,
+            size_n=layer.w13_scales.shape[2],
+            group_size=self.quant_config.group_size,
+        )
+
+        replace_parameter(layer, "w13_scales", marlin_w13_scales)
+
+        marlin_w2_scales = marlin_moe_permute_scales(
+            s=layer.w2_scales,
+            size_k=layer.intermediate_size_per_partition,
+            size_n=layer.w2_scales.shape[2],
+            group_size=self.quant_config.group_size,
+        )
+        replace_parameter(layer, "w2_scales", marlin_w2_scales)
+
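+        # Convert the packed AWQ zero points into the Marlin zero-point layout.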
+        marlin_w13_zp = moe_awq_to_marlin_zero_points(
+            layer.w13_qzeros,
+            size_k=layer.w13_qzeros.shape[1],
+            size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
+            num_bits=self.quant_config.weight_bits)
+        replace_parameter(layer, "w13_qzeros", marlin_w13_zp)
+
+        marlin_w2_zp = moe_awq_to_marlin_zero_points(
+            layer.w2_qzeros,
+            size_k=layer.w2_qzeros.shape[1],
+            size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
+            num_bits=self.quant_config.weight_bits)
+        replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+    ) -> torch.Tensor:
+
+        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+            fused_marlin_moe)
+
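+        # Compute the top-k routing (expert weights and ids) from the router
+        # logits before calling the fused Marlin MoE kernel.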
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function)
+
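+        # Run the fused Marlin MoE kernel directly on the repacked quantized
+        # weights, scales, and zero points.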
+        return fused_marlin_moe(
+            x,
+            layer.w13_qweight,
+            layer.w2_qweight,
+            layer.w13_scales,
+            layer.w2_scales,
+            router_logits,
+            topk_weights,
+            topk_ids,
+            w1_zeros=layer.w13_qzeros,
+            w2_zeros=layer.w2_qzeros,
+            num_bits=self.quant_config.weight_bits,
+        )