
Commit 061c6b7

SamuelBarryCS authored and ArthurZucker committed
Fix attention sink implementation in flex attention (huggingface#41083)
* Fix attention sink implementation in flex attention
* fix dim
* fix
* Remove print
* raise error when return_lse is False yet s_aux is provided
* Clean test files for merge
* Update src/transformers/integrations/flex_attention.py
  Co-authored-by: Arthur <[email protected]>
* force return lse
* Add to doc

---------

Co-authored-by: Arthur <[email protected]>
1 parent a9b4e25 commit 061c6b7

File tree

2 files changed: +25 −6 lines

- docs/source/en/model_doc/gpt_oss.md
- src/transformers/integrations/flex_attention.py

docs/source/en/model_doc/gpt_oss.md

Lines changed: 2 additions & 0 deletions

@@ -35,6 +35,8 @@ The abstract from the paper is the following:
 *<INSERT PAPER ABSTRACT HERE>*
 
 Tips:
+- **Attention Sinks with Flex Attention**: When using flex attention, attention sinks require special handling. Unlike standard attention implementations, where the sinks can be added directly to the attention scores, flex attention's `score_mod` function operates on individual score elements rather than on the full attention matrix. The sink renormalization therefore has to be applied after the flex attention computation, by rescaling the outputs with the log-sum-exp (LSE) values that flex attention returns.
+
 
 <INSERT TIPS ABOUT MODEL HERE>
 
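The rescaling described in the tip can be checked in isolation. Below is a minimal standalone sketch (not transformers code; the tensor names are illustrative) showing that, for a single query row, a softmax whose denominator also contains the sink logit equals the plain softmax rescaled by `exp(lse - logsumexp([lse, sink]))`:

```python
import torch

s = torch.randn(6)        # attention scores for one query row
sink = torch.tensor(0.5)  # learned sink logit

# Direct form: the sink only enlarges the softmax denominator
p_direct = torch.exp(s) / (torch.exp(s).sum() + torch.exp(sink))

# LSE form: plain softmax rescaled after the fact, as the tip describes
lse = torch.logsumexp(s, dim=-1)
combined = torch.logsumexp(torch.stack([lse, sink]), dim=-1)
p_lse = torch.softmax(s, dim=-1) * torch.exp(lse - combined)

assert torch.allclose(p_direct, p_lse, atol=1e-6)
```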

src/transformers/integrations/flex_attention.py

Lines changed: 23 additions & 6 deletions

@@ -272,12 +272,9 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
             score = score + score_mask[batch_idx][0][q_idx][kv_idx]
         if head_mask is not None:
            score = score + head_mask[batch_idx][head_idx][0][0]
-        if s_aux is not None:
-            logits_max = torch.max(score, dim=-1, keepdim=True).values
-            sinks = torch.exp(s_aux - logits_max)
-            unnormalized_scores = torch.exp(score - logits_max)
-            normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
-            score = unnormalized_scores / normalizer
+        # Note: attention sinks cannot be implemented correctly in score_mod,
+        # because they require operating on the full attention matrix before softmax.
+        # ==> this is done after flex attention instead
         return score
 
     enable_gqa = True
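For context on why the removed block was wrong: `score_mod` (from `torch.nn.attention.flex_attention`, available in recent PyTorch releases) is invoked with one scalar score per `(batch, head, q_idx, kv_idx)` element, so elementwise biases work but row-wise reductions such as a softmax denominator do not. A hedged sketch, assuming a PyTorch build with FlexAttention support; `head_bias` and the shapes are made up for illustration:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 1, 2, 128, 64
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)
head_bias = torch.randn(H)  # illustrative per-head additive bias

def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
    # OK: an elementwise additive bias, since `score` is a single element
    # identified by (batch_idx, head_idx, q_idx, kv_idx)
    return score + head_bias[head_idx]
    # Not expressible here: anything that needs the whole score row, e.g.
    # torch.max(score, dim=-1) or a softmax denominator that includes a sink.

out = flex_attention(q, k, v, score_mod=score_mod)
```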
@@ -293,6 +290,11 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
     # On CPU we must skip returning LSE due to a runtime issue; elsewhere, follow PyTorch API and return it
     return_lse = query.device.type != "cpu"
 
+    # Validate that s_aux is not silently ignored
+    if not return_lse and s_aux is not None:
+        logger.warning_once("s_aux provided with return_lse=False - forcing return_lse=True to avoid silent failure")
+        return_lse = True
+
     flex_attention_output = compile_friendly_flex_attention(
         query,
         key,
@@ -311,6 +313,21 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
     if return_lse:
         attention_output, lse = flex_attention_output  # type: ignore[misc]
         lse = lse.to(value.dtype)
+
+        if s_aux is not None:
+            # Apply the attention sinks by renormalizing the output with the LSE
+            batch_size, num_heads, seq_len_q, _ = attention_output.shape  # batch, num_heads, seq_len, head_dim
+            sinks = s_aux.view(1, -1, 1, 1).expand(batch_size, num_heads, seq_len_q, 1)
+
+            # We need the normalizer that also includes the sinks:
+            # since lse = log(sum(exp(scores))), we have exp(lse) = sum(exp(scores)), and therefore
+            # log(sum(exp(scores)) + exp(sink)) = log(exp(lse) + exp(sink)) = logsumexp([lse, sink])
+            lse_expanded = lse.unsqueeze(-1)  # [batch, num_heads, seq_len, 1]
+            combined_lse = torch.logsumexp(torch.cat([lse_expanded, sinks], dim=-1), dim=-1, keepdim=True)
+
+            # Rescale by old_norm / new_norm = exp(lse - combined_lse)
+            renorm_factor = torch.exp(lse_expanded - combined_lse)
+            attention_output = attention_output * renorm_factor
     else:
         attention_output = flex_attention_output  # type: ignore[assignment]
         lse = None
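The shape handling above (per-head `s_aux` broadcast to `[batch, num_heads, seq_len, 1]` and the combined LSE) can be sanity-checked against an eager reference that appends one sink column to each score row. A standalone sketch with made-up shapes, not part of this commit:

```python
import torch

B, H, Q, K, D = 2, 4, 3, 5, 8
q = torch.randn(B, H, Q, D)
k = torch.randn(B, H, K, D)
v = torch.randn(B, H, K, D)
s_aux = torch.randn(H)  # one learned sink logit per head

scores = (q @ k.transpose(-1, -2)) / D**0.5        # [B, H, Q, K]

# Eager reference: append a sink column, softmax over K + 1, drop the sink's probability mass
sinks = s_aux.view(1, H, 1, 1).expand(B, H, Q, 1)  # [B, H, Q, 1]
probs = torch.softmax(torch.cat([scores, sinks], dim=-1), dim=-1)
ref = probs[..., :K] @ v

# Flex-style: plain softmax output plus its LSE, then rescale as in the diff above
out = torch.softmax(scores, dim=-1) @ v
lse = torch.logsumexp(scores, dim=-1, keepdim=True)  # [B, H, Q, 1]
combined_lse = torch.logsumexp(torch.cat([lse, sinks], dim=-1), dim=-1, keepdim=True)
out = out * torch.exp(lse - combined_lse)

assert torch.allclose(out, ref, atol=1e-5)
```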
