 class GGUFConfig(QuantizationConfig):
     """Config class for GGUF."""
 
-    def __init__(self, unquantized_modules: list[str] | None = None) -> None:
+    def __init__(
+        self, unquantized_modules: list[str] | None = None, model_arch: str = ""
+    ) -> None:
         super().__init__()
         self.unquantized_modules = unquantized_modules or []
+        self.model_arch = model_arch
 
     def __repr__(self) -> str:
         return "GGUFConfig()"
@@ -59,17 +62,19 @@ def get_config_filenames(cls) -> list[str]:
 
     @classmethod
     def from_config(cls, config: dict[str, Any]) -> "GGUFConfig":
-        return cls()
+        # Extract model_arch from config if available
+        model_arch = config.get("model_arch", "")
+        return cls(model_arch=model_arch)
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, LinearBase):
             if is_layer_skipped_gguf(prefix, self.unquantized_modules):
                 return UnquantizedLinearMethod()
-            return GGUFLinearMethod(self)
+            return GGUFLinearMethod(self, self.model_arch)
         elif isinstance(layer, VocabParallelEmbedding):
-            return GGUFEmbeddingMethod(self)
+            return GGUFEmbeddingMethod(self, self.model_arch)
         elif isinstance(layer, FusedMoE):
             return GGUFMoEMethod(self, layer.moe_config)
         return None
@@ -115,7 +120,10 @@ def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]):
 
 
 def _fused_mul_mat_gguf(
-    x: torch.Tensor, qweight: torch.Tensor, qweight_type: int
+    x: torch.Tensor,
+    qweight: torch.Tensor,
+    qweight_type: int,
+    target_dtype: torch.dtype | None = None,
 ) -> torch.Tensor:
     if qweight_type in IMATRIX_QUANT_TYPES:
         mmvq_safe = 8 if qweight.shape[0] > 5120 else 16
@@ -125,9 +133,15 @@ def _fused_mul_mat_gguf(
     # so input to logits generator is empty which causes invalid parameter
     if x.shape[0] == 0:
         return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device)
-    # there is no need to call any kernel for fp16/bf16
+    # Unquantized weights (F16/BF16) use direct matrix multiplication
     if qweight_type in UNQUANTIZED_TYPES:
-        return x @ qweight.T
+        # GGUF stores some weights as F16, but the model may use bfloat16 for
+        # computation. Convert weight dtype to match target_dtype (typically
+        # params_dtype=bfloat16) to ensure consistency in mixed-precision.
+        weight = qweight
+        if target_dtype is not None and weight.dtype != target_dtype:
+            weight = weight.to(target_dtype)
+        return x @ weight.T
     # enable MMVQ in contiguous batching with batch_size=1
     if x.shape[0] <= mmvq_safe and qweight_type in MMVQ_QUANT_TYPES:
         y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
@@ -138,7 +152,9 @@ def _fused_mul_mat_gguf(
     elif qweight_type in DEQUANT_TYPES:
         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
         shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
-        weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
+        # Use target_dtype if provided, otherwise fall back to x.dtype
+        dequant_dtype = target_dtype if target_dtype is not None else x.dtype
+        weight = ops.ggml_dequantize(qweight, qweight_type, *shape, dequant_dtype)
         y = x @ weight.T
     else:
         # Raise an error if the quantization type is not supported.
@@ -153,6 +169,7 @@ def _fused_mul_mat_gguf_fake(
     x: torch.Tensor,
     qweight: torch.Tensor,
     qweight_type: int,
+    target_dtype: torch.dtype | None = None,
 ) -> torch.Tensor:
     return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device)
 
@@ -311,7 +328,13 @@ def _apply_gguf_embedding(
     dtype: torch.dtype | None = None,
 ) -> torch.Tensor:
     if qweight_type in UNQUANTIZED_TYPES:
-        return torch.embedding(qweight, x)
+        # torch.embedding preserves weight tensor dtype (F16 for GGUF).
+        # Convert result to model's computation dtype (typically bfloat16)
+        # to maintain consistency throughout forward pass.
+        result = torch.embedding(qweight, x)
+        if dtype is not None and result.dtype != dtype:
+            result = result.to(dtype)
+        return result
     elif qweight_type in DEQUANT_TYPES:
         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
         x_flat = x.flatten()
@@ -355,7 +378,9 @@ class GGUFLinearMethod(LinearMethodBase):
         quant_config: The GGUF quantization config.
     """
 
-    def __init__(self, quant_config: GGUFConfig):
+    def __init__(self, quant_config: GGUFConfig, model_arch: str = ""):
+        self.model_arch = model_arch.lower()
+        self.is_gemma3 = "gemma3" in self.model_arch
         self.quant_config = quant_config
 
     def create_weights(
@@ -369,6 +394,9 @@ def create_weights(
         **extra_weight_attrs,
     ):
         self.params_dtype = params_dtype
+        # Gemma3 isolation: only apply dtype conversion for Gemma3 models
+        self.target_dtype = params_dtype if self.is_gemma3 else None
+
         output_size_per_partition = sum(output_partition_sizes)
 
         tensor_shape = (output_size_per_partition, input_size_per_partition)
@@ -467,14 +495,18 @@ def apply(
                 qweight_type = layer.qweight_type.shard_weight_type[idx]
                 result.append(
                     fused_mul_mat_gguf(
-                        x, qweight[start:end, :offset].contiguous(), qweight_type
+                        x,
+                        qweight[start:end, :offset].contiguous(),
+                        qweight_type,
+                        self.target_dtype,
                     )
                 )
             out = torch.cat(result, axis=1)
         else:
             qweight = layer.qweight
             qweight_type = layer.qweight_type.weight_type
-            out = fused_mul_mat_gguf(x, qweight, qweight_type)
+            # Gemma3 isolation: only apply dtype conversion for Gemma3
+            out = fused_mul_mat_gguf(x, qweight, qweight_type, self.target_dtype)
         if bias is not None:
             out.add_(bias)
         return out
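
For reference, here is a minimal standalone sketch of the dtype-matching behaviour that the new `target_dtype` argument enables on the unquantized (F16/BF16) path. This is plain PyTorch, not vLLM API; `matmul_with_target_dtype` is a hypothetical helper used only for illustration of why the cast matters when activations are bfloat16 but the GGUF weight is stored as F16:

```python
# Illustrative sketch only: mirrors the dtype-matching logic added in
# _fused_mul_mat_gguf for the unquantized branch. Plain torch, runs standalone.
import torch


def matmul_with_target_dtype(
    x: torch.Tensor, qweight: torch.Tensor, target_dtype: torch.dtype | None = None
) -> torch.Tensor:
    # Cast the stored weight to the requested computation dtype (e.g. bfloat16
    # for Gemma3) before the matmul; without the cast, bf16 @ fp16 would raise
    # a dtype-mismatch error.
    weight = qweight
    if target_dtype is not None and weight.dtype != target_dtype:
        weight = weight.to(target_dtype)
    return x @ weight.T


x = torch.randn(2, 8, dtype=torch.bfloat16)       # activations in bfloat16
qweight = torch.randn(4, 8, dtype=torch.float16)  # GGUF weight stored as F16
out = matmul_with_target_dtype(x, qweight, target_dtype=torch.bfloat16)
print(out.dtype)  # torch.bfloat16
```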