 
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                        FusedMoEMethodBase)
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.utils import direct_register_custom_op
 
 logger = init_logger(__name__)
 
@@ -96,8 +96,8 @@ def get_quant_method(self, layer: torch.nn.Module,
 MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
 
 
-def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
-                  qweight_type: int) -> torch.Tensor:
+def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
+                        qweight_type: int) -> torch.Tensor:
     # HACK: when doing chunked prefill we don't generate output tokens
     # so input to logits generator is empty which causes invalid parameter
     if x.shape[0] == 0:
@@ -130,6 +130,30 @@ def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
     return y
 
 
+def _fused_mul_mat_gguf_fake(
+    x: torch.Tensor,
+    qweight: torch.Tensor,
+    qweight_type: int,
+) -> torch.Tensor:
+    return torch.empty(x.shape[0],
+                       qweight.shape[0],
+                       dtype=x.dtype,
+                       device=x.device)
+
+
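+# Register the quantized matmul as a torch custom op; the fake impl lets
+# torch.compile trace output shapes without running the real kernel.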
+try:
+    direct_register_custom_op(
+        op_name="_fused_mul_mat_gguf",
+        op_func=_fused_mul_mat_gguf,
+        mutates_args=[],
+        fake_impl=_fused_mul_mat_gguf_fake,
+    )
+    fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf
+
+except AttributeError as error:
+    raise error
+
+
 def _fused_moe_gguf(
     x: torch.Tensor,
     w1: torch.Tensor,
@@ -138,8 +162,21 @@ def _fused_moe_gguf(
     topk_ids: torch.Tensor,
     qweight_type: int,
     qweight_type2: int,
-    act,
+    activation: str,
 ) -> torch.Tensor:
+
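+    # Gated activation (SiluAndMul / GeluAndMul semantics): the last dim
+    # packs [gate, up] halves, so the output is half the input width.
+    # Defined inline because custom ops cannot take Python callables, only
+    # schema types such as the `activation` string.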
+    def act(x: torch.Tensor):
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        if activation == "silu":
+            torch.ops._C.silu_and_mul(out, x)
+        elif activation == "gelu":
+            torch.ops._C.gelu_and_mul(out, x)
+        else:
+            raise ValueError(f"Unsupported activation: {activation}")
+        return out
+
     # lazy import to avoid triggering triton import in CPU backend
     from vllm.model_executor.layers.fused_moe.fused_moe import (
         moe_align_block_size)
@@ -189,12 +226,12 @@ def _fused_moe_gguf(
         for ww, ii in zip(w, idx):
             expert_up = w1[ii]
 
-            out = _fuse_mul_mat(inp, expert_up, qweight_type)
+            out = fused_mul_mat_gguf(inp, expert_up, qweight_type)
             out = act(out)
 
             expert_down = w2[ii]
-            current_state = _fuse_mul_mat(out, expert_down,
-                                          qweight_type2).mul_(ww)
+            current_state = fused_mul_mat_gguf(out, expert_down,
+                                               qweight_type2).mul_(ww)
             if current_hidden_state is None:
                 current_hidden_state = current_state
             else:
@@ -203,6 +240,78 @@ def _fused_moe_gguf(
     return out_hidden_states
 
 
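+# The fused MoE op returns hidden states shaped like its input, so the
+# shape-only fake impl can simply mirror x for tracing.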
+def _fused_moe_gguf_fake(
+    x: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    qweight_type: int,
+    qweight_type2: int,
+    activation: str,
+) -> torch.Tensor:
+    return torch.empty_like(x)
+
+
+try:
+    direct_register_custom_op(
+        op_name="_fused_moe_gguf",
+        op_func=_fused_moe_gguf,
+        mutates_args=[],
+        fake_impl=_fused_moe_gguf_fake,
+    )
+    fused_moe_gguf = torch.ops.vllm._fused_moe_gguf
+
+except AttributeError as error:
+    raise error
+
+
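+# Embedding lookup for GGUF weights: gather the quantized rows for the
+# requested token ids, then dequantize only those rows to the target dtype.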
+def _apply_gguf_embedding(
+    x: torch.Tensor,
+    qweight: torch.Tensor,
+    qweight_type: int,
+    hidden_size: int,
+    dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    if qweight_type in UNQUANTIZED_TYPES:
+        return torch.embedding(qweight, x)
+    elif qweight_type in DEQUANT_TYPES:
+        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
+        x_flat = x.flatten()
+        assert (hidden_size == qweight.shape[1] // type_size * block_size)
+        quant = torch.index_select(qweight, dim=0, index=x_flat)
+        dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
+                                      x_flat.shape[0], dtype)
+        return dequant.view(*x.shape, hidden_size)
+    else:
+        qweight_type = WeightType(qweight_type)
+        raise NotImplementedError(
+            f"Unsupported GGUF quantization type: {qweight_type}")
+
+
+def _apply_gguf_embedding_fake(
+    x: torch.Tensor,
+    qweight: torch.Tensor,
+    qweight_type: int,
+    hidden_size: int,
+    dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    return torch.empty(x.shape[0], hidden_size, dtype=dtype, device=x.device)
+
+
+try:
+    direct_register_custom_op(
+        op_name="_apply_gguf_embedding",
+        op_func=_apply_gguf_embedding,
+        mutates_args=[],
+        fake_impl=_apply_gguf_embedding_fake,
+    )
+    apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding
+
+except AttributeError as error:
+    raise error
+
+
 class GGUFLinearMethod(LinearMethodBase):
     """Linear method for GGUF.
 
@@ -249,26 +358,76 @@ def create_weights(self, layer: torch.nn.Module,
         set_weight_attrs(qweight_type, extra_weight_attrs)
         layer.register_parameter("qweight_type", qweight_type)
 
+    def process_weights_after_loading(self, layer: torch.nn.Module):
+        qweight_type = layer.qweight_type.weight_type
+        if not (qweight_type in UNQUANTIZED_TYPES
+                or qweight_type in DEQUANT_TYPES):
+            qweight_type = WeightType(qweight_type)
+            raise ValueError(
+                f"Unsupported GGUF quantization type {qweight_type} in "
+                f"layer {layer}.")
+        # For MergedColumnParallelLinear and QKVParallelLinear, we need to
+        # materialize the padded weight parameter for CUDA Graph
+        # compatibility.
+        self._create_padded_weight_param(layer)
+
+    def _create_padded_weight_param(self, layer: torch.nn.Module):
+        """Create padded weight parameter for GGUF MergedLinear layer."""
+        qweight = layer.qweight
+        shard_id_map = qweight.shard_id_map
+        shard_id = qweight.shard_id
+        if len(data_container := qweight.data_container) > 1:
+            dtype = {data.dtype for data in data_container}
+            assert len(dtype) == 1, ValueError(
+                f"Data container has mixed dtypes: {dtype}")
+            dtype = next(iter(dtype))
+            # concat dim0 and pad dim1
+            padded_side = max(x.size(1) for x in data_container)
+            concat_side = sum(x.size(0) for x in data_container)
+            # Pad the quantized weights to a dense tensor, and create a map
+            # with the location of each shard in the padded tensor.
+            padded_data = torch.zeros((concat_side, padded_side),
+                                      dtype=dtype,
+                                      device=qweight.device)
+            # (dim0_start, dim0_end, dim1_size)
+            shard_offset_map = dict[str, tuple[int, int, int]]()
+            for idx in shard_id:
+                id_in_container = shard_id_map[idx]
+                start = sum(
+                    x.size(0) for x in data_container[:id_in_container])
+                end = start + data_container[id_in_container].size(0)
+                size = data_container[id_in_container].size(1)
+                padded_data[start:end, :size] = data_container[id_in_container]
+                shard_offset_map[idx] = (start, end, size)
+            qweight.data_container.clear()
+            padded_param = Parameter(padded_data, requires_grad=False)
+            set_weight_attrs(padded_param, vars(qweight))
+            set_weight_attrs(padded_param,
+                             {"shard_offset_map": shard_offset_map})
+            layer.register_parameter("qweight", padded_param)
+
     def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        shard_id = getattr(layer.qweight, "shard_id", None)
+        shard_id = layer.qweight.shard_id
 
         if shard_id:
             # dequantize shard weights respectively
             shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
-            qweight = layer.qweight.unbind(0)
+            qweight = layer.qweight
             result = []
             for idx in shard_id:
-                q_idx = layer.qweight.shard_id_map[idx]
+                start, end, offset = layer.qweight.shard_offset_map[idx]
                 qweight_type = layer.qweight_type.shard_weight_type[idx]
-                result.append(_fuse_mul_mat(x, qweight[q_idx], qweight_type))
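+                # Slice this shard out of the padded merged weight; the
+                # slice is a view, so make it contiguous before handing it
+                # to the matmul kernel.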
+                result.append(
+                    fused_mul_mat_gguf(
+                        x, qweight[start:end, :offset].contiguous(),
+                        qweight_type))
             out = torch.cat(result, axis=1)
         else:
             qweight = layer.qweight
             qweight_type = layer.qweight_type.weight_type
-            out = _fuse_mul_mat(x, qweight, qweight_type)
+            out = fused_mul_mat_gguf(x, qweight, qweight_type)
         if bias is not None:
             out.add_(bias)
         return out
@@ -338,7 +497,6 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
 
         set_weight_attrs(w2_qweight_type, extra_weight_attrs)
         layer.register_parameter("w2_qweight_type", w2_qweight_type)
-        self.act = SiluAndMul()
 
     def apply(
         self,
@@ -375,10 +533,10 @@ def apply(
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias)
-        return _fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
-                               topk_weights, topk_ids,
-                               layer.w13_qweight_type.weight_type,
-                               layer.w2_qweight_type.weight_type, self.act)
+        return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
+                              topk_weights, topk_ids,
+                              layer.w13_qweight_type.weight_type,
+                              layer.w2_qweight_type.weight_type, activation)
 
 
 class GGUFEmbeddingMethod(GGUFLinearMethod):
@@ -392,34 +550,15 @@ def embedding(self, layer: torch.nn.Module,
                   x: torch.Tensor) -> torch.Tensor:
         qweight = layer.qweight
         qweight_type = layer.qweight_type.weight_type
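+        # tensor_shape records the logical (dequantized) shape, whereas
+        # qweight.shape is the packed quantized layout, so the latter cannot
+        # be used as the hidden size directly.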
+        hidden_size = qweight.tensor_shape[1]
 
-        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
-        hidden_size = qweight.shape[1] // type_size * block_size
-        if qweight_type < 2:
-            return torch.embedding(qweight, x)
-        x_flat = x.flatten()
-        quant = torch.index_select(qweight, dim=0, index=x_flat)
-        dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
-                                      x_flat.shape[0], self.params_dtype)
-        return dequant.view(*x.shape, hidden_size)
+        return apply_gguf_embedding(x,
+                                    qweight,
+                                    qweight_type,
+                                    hidden_size,
+                                    dtype=self.params_dtype)
 
 
 class GGUFUninitializedParameter(UninitializedParameter):
     cls_to_become = Parameter
     data_container: list[torch.Tensor]
-
-    def materialize_nested(self) -> Parameter:
-        dtype = {data.dtype for data in self.data_container}
-        assert len(dtype) == 1, ValueError(
-            f"Data container has mixed dtypes: {dtype}")
-        dtype = next(iter(dtype))
-        nested_data = torch.nested.nested_tensor(self.data_container,
-                                                 device=self.device,
-                                                 dtype=dtype)
-        self.data_container.clear()
-        param = torch.Tensor._make_subclass(self.cls_to_become,
-                                            nested_data,
-                                            require_grad=False)
-        for k, v in self.__dict__.items():
-            setattr(param, k, v)
-        return param