Skip to content

Commit b3587d7

Browse files
mengniwang95 and Liangyx2
authored and committed
Support loading 4 bit Qwen2 (huggingface#1476)
Signed-off-by: Mengni Wang <mengni.wang@intel.com>
1 parent 161aafd commit b3587d7

1 file changed

Lines changed: 19 additions & 4 deletions

File tree

optimum/habana/transformers/models/qwen2/modeling_qwen2.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,9 +198,22 @@ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
198198
self.block_size = 4096
199199
self.rotary_emb = GaudiRotaryEmbedding(config=self.config)
200200

201+
def get_k_proj_weight(self):
    """Return the k_proj weight tensor, preferring the GPTQ 4-bit `qweight`.

    GPTQ 4-bit quantization replaces `k_proj.weight` with `k_proj.qweight`;
    fall back to the regular `weight` for unquantized models.
    """
    return getattr(self.k_proj, "qweight", self.k_proj.weight)
206+
207+
def get_k_proj_weight_dtype(self):
    """Return the dtype to use for k_proj-derived tensors.

    GPTQ 4-bit quantization replaces `k_proj.weight` with `k_proj.qweight`;
    in that case the `scales` tensor carries the working dtype, otherwise
    the plain `weight` dtype is used.
    """
    proj = self.k_proj
    carrier = proj.scales if hasattr(proj, "qweight") else proj.weight
    return carrier.dtype
213+
201214
def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
202215
cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim)
203-
device = self.k_proj.weight.device
216+
device = self.get_k_proj_weight().device
204217
dtype = self.config.torch_dtype
205218
self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape)
206219
self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape)
@@ -211,7 +224,7 @@ def update_sincos_cache(self, seq_len):
211224
# reduce memory consumption and improve performance.
212225
if seq_len > self.max_position_embeddings:
213226
self.max_position_embeddings = seq_len
214-
_, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len)
227+
_, _ = self.rotary_emb(self.get_k_proj_weight(), seq_len=seq_len)
215228

216229
def reorder(self, tensor, beam_idx, dim_a, dim_b):
217230
updated = tensor.index_select(0, beam_idx)
@@ -316,9 +329,11 @@ def pre_attn_forward(
316329
past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape())
317330
else:
318331
if past_key_value is None:
319-
past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device)
332+
past_key = torch.zeros(
333+
key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device
334+
)
320335
past_value = torch.zeros(
321-
key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device
336+
key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device
322337
)
323338
past_key_value = [past_key, past_value]
324339
key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)

0 commit comments

Comments (0)