Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions unsloth/models/qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,19 @@ def Qwen3Attention_fast_forward(
assert(n_kv_heads * n_groups == n_heads)

Q, K, V = self.apply_qkv(self, hidden_states)
Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
Q = Q.view(bsz, q_len, n_heads, head_dim)#.transpose(1, 2) # we will transpose after normalisation
K = K.view(bsz, q_len, n_kv_heads, head_dim)#.transpose(1, 2) # we will transpose after normalisation
V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)

# Qwen3 applies QK-norm (self.q_norm / self.k_norm on Q and K). This seems to be the only difference from Qwen2.
Q = fast_layernorm_compiled(self.q_norm, Q)
K = fast_layernorm_compiled(self.k_norm, K)
# NOTE: using fast_layernorm_compiled here causes issues because the dimensions don't match up.
# I tried adding a compiled version of the new norm, but its outputs don't match those from Transformers.
# TODO: Investigate where the numerical differences come from.
Q = fast_rms_layernorm(self.q_norm, Q)
K = fast_rms_layernorm(self.k_norm, K)

Q = Q.transpose(1, 2)
K = K.transpose(1, 2)

kv_seq_len = K.shape[-2]
if past_key_value is not None:
Expand Down