2 changes: 2 additions & 0 deletions docs/source/en/model_doc/gpt_oss.md
@@ -35,6 +35,8 @@ The abstract from the paper is the following:
*<INSERT PAPER ABSTRACT HERE>*

Tips:
- **Attention Sinks with Flex Attention**: When using flex attention, attention sinks require special handling. Unlike standard attention implementations, where the sinks can be added directly to the attention scores, flex attention's `score_mod` function operates on individual score elements rather than on the full attention matrix. The sink contribution therefore has to be applied after the flex attention computation, by renormalizing its outputs with the log-sum-exp (LSE) values that flex attention returns (see the sketch after this file's diff).


<INSERT TIPS ABOUT MODEL HERE>

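To make the tip above concrete, here is a minimal sketch of what "adding sinks directly to the attention scores" means in an eager implementation. This is an illustration, not code from the PR: the function name `eager_attention_with_sinks` and the shape conventions are assumptions, and masking is omitted for brevity.

```python
import torch

def eager_attention_with_sinks(query, key, value, s_aux):
    # query/key/value: [batch, heads, seq_len, head_dim]; s_aux: [heads], one sink logit per head
    scores = query @ key.transpose(-2, -1) / query.shape[-1] ** 0.5
    sink = s_aux.view(1, -1, 1, 1).expand(*scores.shape[:-1], 1)
    # The sink logit joins every query's softmax row and its column is dropped
    # afterwards: the real keys keep their relative weights but cede some
    # probability mass to the sink.
    probs = torch.softmax(torch.cat([scores, sink], dim=-1), dim=-1)[..., :-1]
    return probs @ value
```

Flex attention never exposes a full score row to user code, which is why the diff below removes the sink handling from `score_mod` and instead rescales the output afterwards using the returned LSE.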
29 changes: 23 additions & 6 deletions src/transformers/integrations/flex_attention.py
@@ -272,12 +272,9 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
score = score + score_mask[batch_idx][0][q_idx][kv_idx]
if head_mask is not None:
score = score + head_mask[batch_idx][head_idx][0][0]
-if s_aux is not None:
-    logits_max = torch.max(score, dim=-1, keepdim=True).values
-    sinks = torch.exp(s_aux - logits_max)
-    unnormalized_scores = torch.exp(score - logits_max)
-    normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
-    score = unnormalized_scores / normalizer
+# Note: attention sinks cannot be correctly implemented in score_mod
+# because it requires operating on the full attention matrix before softmax.
+# ==> this is done after flex attention
return score

enable_gqa = True
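A rough illustration of the constraint described in the comment above (my sketch, not part of the PR): `score_mod` is called once per scalar score, so elementwise biases are expressible, but the row-wise normalizer a sink needs is not.

```python
# Hypothetical score_mod: flex attention invokes it per (batch, head, q_idx, kv_idx)
# element, before the kernel performs its own softmax over each row.
def biased_score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
    # Fine: purely elementwise adjustments such as a relative-position penalty.
    return score - 0.1 * (q_idx - kv_idx)
    # Not expressible: sum(exp(row)) + exp(s_aux), because the rest of the row is
    # never visible here; the sink normalizer has to be applied after the kernel call.
```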
@@ -293,6 +290,11 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
# On CPU we must skip returning LSE due to a runtime issue; elsewhere, follow PyTorch API and return it
return_lse = query.device.type != "cpu"

# Validate that s_aux is not silently ignored
if not return_lse and s_aux is not None:
logger.warning_once("s_aux provided with return_lse=False - forcing return_lse=True to avoid silent failure")
return_lse = True

flex_attention_output = compile_friendly_flex_attention(
query,
key,
@@ -311,6 +313,21 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
if return_lse:
attention_output, lse = flex_attention_output # type: ignore[misc]
lse = lse.to(value.dtype)

if s_aux is not None:
# Apply attention sinks by renormalizing using LSE
batch_size, num_heads, seq_len_q, _ = attention_output.shape # batch, num_heads, seq_len, head_dim
sinks = s_aux.view(1, -1, 1, 1).expand(batch_size, num_heads, seq_len_q, 1)

# We need to compute the normalization that includes the sinks
# since log(sum(exp(scores))) = lse, exp(log(sum(exp(scores)))) = exp(lse)
# NB: log(sum(exp(scores)) + exp(sink)) = log(exp(lse) + exp(sink))
lse_expanded = lse.unsqueeze(-1) # [batch, num_heads, seq_len, 1]
combined_lse = torch.logsumexp(torch.cat([lse_expanded, sinks], dim=-1), dim=-1, keepdim=True)

# Renormalize with old_norm / new_norm = exp(lse - combined_lse) and apply it to the output
renorm_factor = torch.exp(lse_expanded - combined_lse)
attention_output = attention_output * renorm_factor
else:
attention_output = flex_attention_output # type: ignore[assignment]
lse = None
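As a rough sanity check of the renormalization added above (my sketch, not part of the PR), the LSE-rescaled flex attention output should match an eager softmax that includes the sink in its denominator. This assumes a PyTorch version whose `torch.nn.attention.flex_attention.flex_attention` supports `return_lse=True`, runs in eager mode without compilation, and leaves out masks and GQA.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 1, 2, 128, 64
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))
s_aux = torch.randn(H)  # one sink logit per head

# Plain flex attention (no sinks) plus the per-row log-sum-exp of the scores.
out, lse = flex_attention(q, k, v, return_lse=True)

# Fold the sink into the normalizer: combined_lse = log(exp(lse) + exp(s_aux)),
# then rescale the output by old_norm / new_norm = exp(lse - combined_lse).
sinks = s_aux.view(1, H, 1, 1).expand(B, H, S, 1)
lse_expanded = lse.unsqueeze(-1)
combined_lse = torch.logsumexp(torch.cat([lse_expanded, sinks], dim=-1), dim=-1, keepdim=True)
out_with_sinks = out * torch.exp(lse_expanded - combined_lse)

# Eager reference: softmax over [scores, sink] with the sink column dropped.
scores = q @ k.transpose(-2, -1) / D ** 0.5
sink_col = s_aux.view(1, H, 1, 1).expand(B, H, S, 1)
ref = torch.softmax(torch.cat([scores, sink_col], dim=-1), dim=-1)[..., :-1] @ v
print(torch.allclose(out_with_sinks, ref, atol=1e-4))
```

The rescaling works because the plain flex attention output is already normalized by exp(lse); multiplying by exp(lse - combined_lse) swaps that denominator for exp(lse) + exp(s_aux), which is exactly the denominator the eager reference uses.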