python/sglang/srt/models/llama4.py (30 changes: 18 additions & 12 deletions)
@@ -240,37 +240,43 @@ def __init__(
     def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
         floor = torch.floor((positions + 1.0) / self.floor_scale)
         attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
-
         return attn_scale.unsqueeze(-1)
 
+    @torch.compile(dynamic=True, backend=get_compiler_backend())
+    def _mul_attn_scale(self, positions, q):
+        attn_scale = self._get_attn_scale(positions)
+        return (q * attn_scale).to(q.dtype)
+
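For intuition, the scale this helper produces stays at exactly 1.0 until positions reach floor_scale and then grows logarithmically, which is why the tuning is a no-op for short prompts. A minimal standalone sketch of the same formula; the values floor_scale=8192.0 and attn_scale=0.1 are illustrative assumptions, not read from the model config:

    import torch

    def get_attn_scale(positions, floor_scale=8192.0, attn_scale=0.1):
        # Same formula as _get_attn_scale above, with config values inlined.
        floor = torch.floor((positions + 1.0) / floor_scale)
        return (torch.log(floor + 1.0) * attn_scale + 1.0).unsqueeze(-1)

    positions = torch.tensor([0.0, 1000.0, 8192.0, 131072.0, 1048576.0])
    print(get_attn_scale(positions).squeeze(-1))
    # tensor([1.0000, 1.0000, 1.0693, 1.2833, 1.4860])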
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        qk, v = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
 
         if self.rotary_emb is not None:
-            q, k = self.rotary_emb(positions, q, k)
+            q_view, k_view = qk.split([self.q_size, self.kv_size], dim=-1)
+            q_out_unused, k_out_unused = self.rotary_emb(positions, q_view, k_view)
+            assert (q_out_unused is q_view) and (k_out_unused is k_view)
+            del q_view, k_view, q_out_unused, k_out_unused
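The assert/del pattern above leans on the fact that Tensor.split returns views of the fused buffer, so a rotary embedding that writes in place through q_view and k_view also updates qk; the assert will trip if the rotary implementation ever starts returning fresh tensors instead. A small aliasing sanity check in plain PyTorch (the in-place multiply is a stand-in for an in-place RoPE kernel, not SGLang's actual rotary_emb):

    import torch

    qk = torch.arange(12, dtype=torch.float32).reshape(2, 6)
    q_view, k_view = qk.split([4, 2], dim=-1)

    # split() returns views sharing storage with the fused tensor.
    assert q_view.data_ptr() == qk.data_ptr()

    q_view.mul_(2.0)   # in-place update through the view...
    print(qk[0, :4])   # ...is visible in the fused buffer: tensor([0., 2., 4., 6.])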

         if self.qk_norm is not None:
             # TODO: support float
-            q = q.reshape(-1, self.head_dim).contiguous().bfloat16()
-            k = k.reshape(-1, self.head_dim).contiguous().bfloat16()
-            q = self.qk_norm(q).to(q.dtype)
-            k = self.qk_norm(k).to(k.dtype)
-            q = q.reshape(-1, self.q_size)
-            k = k.reshape(-1, self.kv_size)
+            # TODO: there are still 2 redundant direct_copy_kernel_cuda calls for this `reshape` and (in the attn backend) q.contiguous(); maybe we can fuse them later
+            qk = qk.reshape(-1, self.head_dim).contiguous().bfloat16()
+            qk = self.qk_norm(qk).to(torch.bfloat16)
+            qk = qk.reshape(-1, self.q_size + self.kv_size)
+
+        q, k = qk.split([self.q_size, self.kv_size], dim=-1)
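Collapsing the two qk_norm calls into one works because the norm is applied per head_dim row, so normalizing the concatenated q/k buffer row by row gives the same result as normalizing q and k separately. A quick equivalence check, using a plain RMSNorm as a hypothetical stand-in for self.qk_norm:

    import torch

    def rms_norm(x, eps=1e-6):
        # Per-row RMSNorm over the last dim; a stand-in for self.qk_norm.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

    head_dim, q_size, kv_size = 8, 32, 16
    qk = torch.randn(4, q_size + kv_size)

    fused = rms_norm(qk.reshape(-1, head_dim)).reshape(4, -1)
    q, k = qk.split([q_size, kv_size], dim=-1)
    separate = torch.cat(
        [rms_norm(q.reshape(-1, head_dim)).reshape(4, -1),
         rms_norm(k.reshape(-1, head_dim)).reshape(4, -1)],
        dim=-1,
    )
    assert torch.allclose(fused, separate)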

         # We apply temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers,
         # where the inference-time temperature tuning function is customized to leave
         # short contexts unaffected while still taking effect at very long context.
         if self.attn_temperature_tuning and not self.use_rope:
-            attn_scale = self._get_attn_scale(positions)
-            q = (q * attn_scale).to(q.dtype)
+            q = self._mul_attn_scale(positions=positions, q=q)
 
         attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
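Wrapping the scale-and-multiply in a @torch.compile'd helper lets the compiler fuse the floor/log/mul/cast chain into a single kernel instead of several elementwise launches, while computing the same values as the inlined pair of lines it replaces. A hedged eager-vs-compiled sketch of that equivalence, using the default inductor backend rather than SGLang's get_compiler_backend() and the same assumed config values as above:

    import torch

    floor_scale, attn_scale = 8192.0, 0.1  # assumed config values, for illustration only

    def get_attn_scale(positions):
        floor = torch.floor((positions + 1.0) / floor_scale)
        return (torch.log(floor + 1.0) * attn_scale + 1.0).unsqueeze(-1)

    @torch.compile(dynamic=True)
    def mul_attn_scale(positions, q):
        return (q * get_attn_scale(positions)).to(q.dtype)

    positions = torch.arange(4, dtype=torch.float32) * 100000.0
    q = torch.randn(4, 16)
    expected = (q * get_attn_scale(positions)).to(q.dtype)
    assert torch.allclose(mul_attn_scale(positions, q), expected)

dynamic=True also avoids recompiling the helper as the batch's sequence length changes between forward passes.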