@@ -198,9 +198,22 @@ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
         self.block_size = 4096
         self.rotary_emb = GaudiRotaryEmbedding(config=self.config)

+    def get_k_proj_weight(self):
+        """4-bit GPTQ quantization replaces k_proj.weight with qweight."""
+        if hasattr(self.k_proj, "qweight"):
+            return self.k_proj.qweight
+        return self.k_proj.weight
+
+    def get_k_proj_weight_dtype(self):
+        """4-bit GPTQ quantization replaces k_proj.weight with qweight;
+        the scales tensor carries the original weight dtype."""
+        if hasattr(self.k_proj, "qweight"):
+            return self.k_proj.scales.dtype
+        return self.k_proj.weight.dtype
+
     def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
         cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim)
-        device = self.k_proj.weight.device
+        device = self.get_k_proj_weight().device
         dtype = self.config.torch_dtype
         self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape)
         self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape)
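Why the helper is needed, as a minimal runnable sketch: a 4-bit GPTQ layer stores packed int32 `qweight` plus fp16 `scales` instead of a float `weight`, so `self.k_proj.weight` would raise `AttributeError` after quantization. `FakeQuantLinear` below is a hypothetical stand-in for GPTQ's quantized linear module, not the real class.

```python
import torch
import torch.nn as nn


class FakeQuantLinear(nn.Module):
    """Hypothetical stand-in for a GPTQ 4-bit layer: packed weights + scales."""

    def __init__(self, in_features, out_features):
        super().__init__()
        # Packed 4-bit weights are stored as int32; scales carry the compute dtype.
        self.register_buffer("qweight", torch.zeros(in_features // 8, out_features, dtype=torch.int32))
        self.register_buffer("scales", torch.zeros(1, out_features, dtype=torch.float16))


class Attn:
    """Holds k_proj and dispatches on whether it was GPTQ-quantized."""

    def __init__(self, k_proj):
        self.k_proj = k_proj

    def get_k_proj_weight(self):
        if hasattr(self.k_proj, "qweight"):
            return self.k_proj.qweight
        return self.k_proj.weight

    def get_k_proj_weight_dtype(self):
        if hasattr(self.k_proj, "qweight"):
            return self.k_proj.scales.dtype
        return self.k_proj.weight.dtype


print(Attn(nn.Linear(64, 64)).get_k_proj_weight_dtype())       # torch.float32
print(Attn(FakeQuantLinear(64, 64)).get_k_proj_weight_dtype()) # torch.float16
```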
@@ -211,7 +224,7 @@ def update_sincos_cache(self, seq_len):
         # reduce memory consumption and improve performance.
         if seq_len > self.max_position_embeddings:
             self.max_position_embeddings = seq_len
-            _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len)
+            _, _ = self.rotary_emb(self.get_k_proj_weight(), seq_len=seq_len)

     def reorder(self, tensor, beam_idx, dim_a, dim_b):
         updated = tensor.index_select(0, beam_idx)
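The swap in this hunk is safe because HF-style rotary embeddings only consult the passed tensor for its device (and sometimes dtype) when rebuilding the cached sin/cos tables; the tensor's values never enter the math, so the packed int32 `qweight` is as good a reference as a float weight. A rough sketch of that pattern (`TinyRotary` is illustrative, not GaudiRotaryEmbedding's actual implementation):

```python
import torch


class TinyRotary:
    """Illustrative rotary embedding: the input tensor only supplies a device."""

    def __init__(self, dim, base=10000.0):
        self.dim = dim
        self.base = base

    def __call__(self, x, seq_len):
        # Only x.device is read; x's dtype and values are irrelevant.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=x.device).float() / self.dim))
        t = torch.arange(seq_len, device=x.device).float()
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()


ref = torch.zeros(1, dtype=torch.int32)          # stand-in for a packed qweight
cos, sin = TinyRotary(dim=64)(ref, seq_len=128)  # works regardless of ref's dtype
```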
@@ -316,9 +329,11 @@ def pre_attn_forward(
             past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape())
         else:
             if past_key_value is None:
-                past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device)
+                past_key = torch.zeros(
+                    key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device
+                )
                 past_value = torch.zeros(
-                    key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device
+                    key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device
                 )
                 past_key_value = [past_key, past_value]
             key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)
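For the last hunk, a hedged illustration of the dtype issue it fixes: under GPTQ there is no float `k_proj.weight`, and `qweight.dtype` would be `torch.int32`, the wrong element type for the zero-initialized KV-cache tensors. Using `scales.dtype` (fp16/bf16) keeps the cache compatible with the activations. The shapes and dtypes below are assumptions for the sketch:

```python
import torch

# Simulated key_states: (batch, kv_heads, seq_len, head_dim) in half precision.
key_states = torch.randn(1, 8, 16, 64, dtype=torch.float16)

# What get_k_proj_weight_dtype() would return for a GPTQ layer (scales dtype).
scales_dtype = torch.float16

past_key = torch.zeros(key_states.shape, dtype=scales_dtype, device=key_states.device)
past_key.copy_(key_states)  # dtypes match, so no lossy implicit cast into the cache
```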