 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 
 import gguf
 import torch
 from gguf import GGMLQuantizationType as WeightType
 from torch.nn.parameter import Parameter, UninitializedParameter
 
 from vllm import _custom_ops as ops
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
+                                                        FusedMoEMethodBase)
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
@@ -29,7 +32,7 @@ def get_name(self) -> str:
         return "gguf"
 
     def get_supported_act_dtypes(self) -> List[torch.dtype]:
-        return [torch.half, torch.bfloat16]
+        return [torch.half]
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -49,6 +52,8 @@ def get_quant_method(self, layer: torch.nn.Module,
             return GGUFLinearMethod(self)
         elif isinstance(layer, VocabParallelEmbedding):
             return GGUFEmbeddingMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return GGUFMoEMethod(self)
         return None
 
 
@@ -184,6 +189,124 @@ def apply(self,
         return out
 
 
+class GGUFMoEMethod(FusedMoEMethodBase):
+    """MoE method for GGUF.
+
+    Args:
+        quant_config: The GGUF quantization config.
+    """
+
+    def __init__(self, quant_config: GGUFConfig):
+        self.quant_config = quant_config
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+
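+        # w13 fuses the gate and up projections along the output dimension,
+        # hence 2 * intermediate_size_per_partition output rows per expert.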
+        tensor_shape = (num_experts, 2 * intermediate_size_per_partition,
+                        hidden_size)
+        # Gate/up projection (w13)
+        w13_qweight = GGUFUninitializedParameter(requires_grad=False)
+        set_weight_attrs(
+            w13_qweight, {
+                "input_dim": 1,
+                "output_dim": 0,
+                "tensor_shape": tensor_shape,
+                "is_gguf_weight": True,
+                "data_container": [],
+            })
+        set_weight_attrs(w13_qweight, extra_weight_attrs)
+        layer.register_parameter("w13_qweight", w13_qweight)
+
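+        # GGML quantization type id for w13; the placeholder value 0 is
+        # replaced with the actual type when the GGUF weights are loaded.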
+        w13_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
+                                     requires_grad=False)
+        set_weight_attrs(w13_qweight_type, {
+            "is_gguf_weight_type": True,
+            "weight_type": 0,
+            "ignore_warning": True
+        })
+        set_weight_attrs(w13_qweight_type, extra_weight_attrs)
+        layer.register_parameter("w13_qweight_type", w13_qweight_type)
+
+        tensor_shape = (num_experts, intermediate_size_per_partition,
+                        hidden_size)
+        # Down projection (w2)
+        w2_qweight = GGUFUninitializedParameter(requires_grad=False)
+        set_weight_attrs(
+            w2_qweight, {
+                "input_dim": 1,
+                "output_dim": 0,
+                "tensor_shape": tensor_shape,
+                "is_gguf_weight": True,
+                "data_container": [],
+            })
+        set_weight_attrs(w2_qweight, extra_weight_attrs)
+        layer.register_parameter("w2_qweight", w2_qweight)
+
+        w2_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
+                                    requires_grad=False)
+        set_weight_attrs(w2_qweight_type, {
+            "is_gguf_weight_type": True,
+            "weight_type": 0,
+            "ignore_warning": True
+        })
+
+        set_weight_attrs(w2_qweight_type, extra_weight_attrs)
+        layer.register_parameter("w2_qweight_type", w2_qweight_type)
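+        # SiluAndMul splits its input in half along the last dimension and
+        # computes SiLU(gate) * up, matching the fused w13 layout above.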
+        self.act = SiluAndMul()
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+    ):
+        assert activation == "silu", "Only SiLU activation is supported."
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
+        final_hidden_states = torch.empty_like(x)
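+        # Process one token at a time: each selected expert performs a fused
+        # dequantize+matmul over its GGUF weights, followed by SiLU-and-mul,
+        # and the expert outputs are accumulated with their routing weights.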
+        for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
+            inp = x[tok].reshape((1, ) + x.shape[1:])
+            current_hidden_state = None
+            for ww, ii in zip(w, idx):
+                expert_up = layer.w13_qweight[ii]
+
+                out = _fuse_mul_mat(inp, expert_up,
+                                    layer.w13_qweight_type.weight_type)
+                out = self.act(out)
+
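+                # Down-project with the same expert's w2 weights and scale the
+                # result by this token's routing weight for the expert.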
+                expert_down = layer.w2_qweight[ii]
+                current_state = _fuse_mul_mat(
+                    out, expert_down,
+                    layer.w2_qweight_type.weight_type).mul_(ww)
+                if current_hidden_state is None:
+                    current_hidden_state = current_state
+                else:
+                    current_hidden_state.add_(current_state)
+            final_hidden_states[tok] = current_hidden_state
+        return final_hidden_states
+
+
 class GGUFEmbeddingMethod(GGUFLinearMethod):
     """Embedding method for GGUF.
 