 class GGUFConfig(QuantizationConfig):
     """Config class for GGUF."""
 
-    def __init__(self, unquantized_modules: list[str] | None = None) -> None:
+    def __init__(self, unquantized_modules: list[str] | None = None,
+                 model_arch: str = "") -> None:
         super().__init__()
         self.unquantized_modules = unquantized_modules or []
+        self.model_arch = model_arch
 
     def __repr__(self) -> str:
         return "GGUFConfig()"
@@ -59,17 +61,19 @@ def get_config_filenames(cls) -> list[str]:
 
     @classmethod
     def from_config(cls, config: dict[str, Any]) -> "GGUFConfig":
-        return cls()
+        # Extract model_arch from config if available
+        model_arch = config.get("model_arch", "")
+        return cls(model_arch=model_arch)
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, LinearBase):
             if is_layer_skipped_gguf(prefix, self.unquantized_modules):
                 return UnquantizedLinearMethod()
-            return GGUFLinearMethod(self)
+            return GGUFLinearMethod(self, self.model_arch)
         elif isinstance(layer, VocabParallelEmbedding):
-            return GGUFEmbeddingMethod(self)
+            return GGUFEmbeddingMethod(self, self.model_arch)
         elif isinstance(layer, FusedMoE):
             return GGUFMoEMethod(self, layer.moe_config)
         return None
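
A quick sketch of how the new field threads through, with a hypothetical config dict (how the loader actually populates "model_arch" is outside this hunk):

config = GGUFConfig.from_config({"model_arch": "gemma3"})
assert config.model_arch == "gemma3"
# get_quant_method then forwards it to the per-layer methods:
#   GGUFLinearMethod(config, config.model_arch)
#   GGUFEmbeddingMethod(config, config.model_arch)
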
@@ -115,7 +119,8 @@ def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]):
 
 
 def _fused_mul_mat_gguf(
-    x: torch.Tensor, qweight: torch.Tensor, qweight_type: int
+    x: torch.Tensor, qweight: torch.Tensor, qweight_type: int,
+    target_dtype: Optional[torch.dtype] = None
 ) -> torch.Tensor:
     if qweight_type in IMATRIX_QUANT_TYPES:
         mmvq_safe = 8 if qweight.shape[0] > 5120 else 16
@@ -125,9 +130,15 @@ def _fused_mul_mat_gguf(
     # so input to logits generator is empty which causes invalid parameter
     if x.shape[0] == 0:
         return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device)
-    # there is no need to call any kernel for fp16/bf16
+    # Unquantized weights (F16/BF16) use direct matrix multiplication
     if qweight_type in UNQUANTIZED_TYPES:
+        # GGUF stores some weights as F16, but the model may use bfloat16 for
+        # computation. Convert the weight dtype to match target_dtype (typically
+        # params_dtype=bfloat16) to ensure consistency in mixed precision.
+        weight = qweight
+        if target_dtype is not None and weight.dtype != target_dtype:
+            weight = weight.to(target_dtype)
+        return x @ weight.T
-        return x @ qweight.T
     # enable MMVQ in contiguous batching with batch_size=1
     if x.shape[0] <= mmvq_safe and qweight_type in MMVQ_QUANT_TYPES:
         y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
@@ -138,7 +149,9 @@ def _fused_mul_mat_gguf(
     elif qweight_type in DEQUANT_TYPES:
         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
         shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
-        weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
+        # Use target_dtype if provided, otherwise fall back to x.dtype
+        dequant_dtype = target_dtype if target_dtype is not None else x.dtype
+        weight = ops.ggml_dequantize(qweight, qweight_type, *shape, dequant_dtype)
         y = x @ weight.T
     else:
         # Raise an error if the quantization type is not supported.
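
To see why the cast matters, here is a standalone sketch with plain tensors and no GGUF kernels (the dtypes mirror a Gemma3 GGUF checkpoint: F16 weights, bfloat16 params_dtype):

import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)        # activations in params_dtype
qweight = torch.randn(16, 8, dtype=torch.float16)  # GGUF F16 weight tensor

# x @ qweight.T would raise a dtype-mismatch RuntimeError here, since
# torch.matmul does not promote between Half and BFloat16.
target_dtype = torch.bfloat16
weight = qweight if qweight.dtype == target_dtype else qweight.to(target_dtype)
y = x @ weight.T
assert y.dtype == torch.bfloat16

The dequantize branch gets the same treatment by requesting dequant_dtype from ops.ggml_dequantize directly, which avoids a second cast after dequantization.
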
@@ -153,6 +166,7 @@ def _fused_mul_mat_gguf_fake(
     x: torch.Tensor,
     qweight: torch.Tensor,
     qweight_type: int,
+    target_dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
     return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device)
 
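
The fake variant has to mirror the real op's signature: the fake (meta) implementation registered for a custom op must accept the same arguments as the real one, even though the new parameter does not affect the inferred output shape.
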
@@ -311,7 +325,13 @@ def _apply_gguf_embedding(
     dtype: torch.dtype | None = None,
 ) -> torch.Tensor:
     if qweight_type in UNQUANTIZED_TYPES:
-        return torch.embedding(qweight, x)
+        # torch.embedding preserves the weight tensor's dtype (F16 for GGUF).
+        # Convert the result to the model's computation dtype (typically
+        # bfloat16) to maintain consistency throughout the forward pass.
+        result = torch.embedding(qweight, x)
+        if dtype is not None and result.dtype != dtype:
+            result = result.to(dtype)
+        return result
     elif qweight_type in DEQUANT_TYPES:
         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
         x_flat = x.flatten()
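
A standalone sketch of the embedding behavior being worked around: torch.embedding returns the weight table's dtype, so an F16 table yields F16 hidden states unless the result is cast.

import torch

table = torch.randn(10, 4, dtype=torch.float16)  # F16 embedding table
ids = torch.tensor([1, 3])
out = torch.embedding(table, ids)
assert out.dtype == torch.float16                # dtype follows the table

dtype = torch.bfloat16                           # model computation dtype
if out.dtype != dtype:
    out = out.to(dtype)
assert out.dtype == torch.bfloat16
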
@@ -355,7 +375,9 @@ class GGUFLinearMethod(LinearMethodBase):
         quant_config: The GGUF quantization config.
     """
 
-    def __init__(self, quant_config: GGUFConfig):
+    def __init__(self, quant_config: GGUFConfig, model_arch: str = ""):
+        self.model_arch = model_arch.lower()
+        self.is_gemma3 = "gemma3" in self.model_arch
         self.quant_config = quant_config
 
     def create_weights(
@@ -369,6 +391,9 @@ def create_weights(
         **extra_weight_attrs,
     ):
         self.params_dtype = params_dtype
+        # Gemma3 isolation: only apply dtype conversion for Gemma3 models
+        self.target_dtype = params_dtype if self.is_gemma3 else None
+
         output_size_per_partition = sum(output_partition_sizes)
 
         tensor_shape = (output_size_per_partition, input_size_per_partition)
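
A small sketch of the gating, using made-up architecture strings (the exact model_arch values depend on what the loader passes in):

config = GGUFConfig()
gemma = GGUFLinearMethod(config, model_arch="Gemma3ForCausalLM")
assert gemma.is_gemma3        # "gemma3" in "gemma3forcausallm"

llama = GGUFLinearMethod(config, model_arch="LlamaForCausalLM")
assert not llama.is_gemma3    # create_weights will leave target_dtype as None,
                              # so fused_mul_mat_gguf performs no conversion
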
@@ -467,14 +492,16 @@ def apply(
                 qweight_type = layer.qweight_type.shard_weight_type[idx]
                 result.append(
                     fused_mul_mat_gguf(
-                        x, qweight[start:end, :offset].contiguous(), qweight_type
+                        x, qweight[start:end, :offset].contiguous(), qweight_type,
+                        self.target_dtype
                     )
                 )
             out = torch.cat(result, axis=1)
         else:
             qweight = layer.qweight
             qweight_type = layer.qweight_type.weight_type
-            out = fused_mul_mat_gguf(x, qweight, qweight_type)
+            # Gemma3 isolation: only apply dtype conversion for Gemma3
+            out = fused_mul_mat_gguf(x, qweight, qweight_type, self.target_dtype)
         if bias is not None:
             out.add_(bias)
         return out
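
Passing self.target_dtype in both branches keeps sharded weights (e.g. fused QKV) on the same dtype path as unsharded ones, and because target_dtype is None for every architecture except Gemma3, non-Gemma3 models run through exactly the code they did before.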