diff --git a/docs/source/en/model_doc/gpt_oss.md b/docs/source/en/model_doc/gpt_oss.md
index 47c970eb17e6..60741d8473fa 100644
--- a/docs/source/en/model_doc/gpt_oss.md
+++ b/docs/source/en/model_doc/gpt_oss.md
@@ -35,6 +35,8 @@ The abstract from the paper is the following:
 
 **
 
 Tips:
 
+- **Attention Sinks with Flex Attention**: When using flex attention, attention sinks require special handling. Unlike standard attention implementations, where sinks can be added directly to the attention scores, flex attention's `score_mod` function operates on individual score elements rather than the full attention matrix. Therefore, the sink correction has to be applied after the flex attention computation, by renormalizing the outputs with the log-sum-exp (LSE) values returned by flex attention.
+
diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py
index 85ddc433e67a..15ab20c600f2 100644
--- a/src/transformers/integrations/flex_attention.py
+++ b/src/transformers/integrations/flex_attention.py
@@ -272,12 +272,9 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
             score = score + score_mask[batch_idx][0][q_idx][kv_idx]
         if head_mask is not None:
             score = score + head_mask[batch_idx][head_idx][0][0]
-        if s_aux is not None:
-            logits_max = torch.max(score, dim=-1, keepdim=True).values
-            sinks = torch.exp(s_aux - logits_max)
-            unnormalized_scores = torch.exp(score - logits_max)
-            normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
-            score = unnormalized_scores / normalizer
+        # Note: attention sinks cannot be implemented correctly in score_mod,
+        # because they require access to the full attention matrix before softmax.
+        # ==> the sink correction is applied after the flex attention call instead
         return score
 
     enable_gqa = True
@@ -293,6 +290,11 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
     # On CPU we must skip returning LSE due to a runtime issue; elsewhere, follow PyTorch API and return it
     return_lse = query.device.type != "cpu"
 
+    # Ensure s_aux is not silently ignored
+    if not return_lse and s_aux is not None:
+        logger.warning_once("s_aux provided with return_lse=False; forcing return_lse=True to avoid silently ignoring attention sinks")
+        return_lse = True
+
     flex_attention_output = compile_friendly_flex_attention(
         query,
         key,
@@ -311,6 +313,21 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
     if return_lse:
         attention_output, lse = flex_attention_output  # type: ignore[misc]
         lse = lse.to(value.dtype)
+
+        if s_aux is not None:
+            # Apply attention sinks by renormalizing the output with the LSE
+            batch_size, num_heads, seq_len_q, _ = attention_output.shape  # batch, num_heads, seq_len, head_dim
+            sinks = s_aux.view(1, -1, 1, 1).expand(batch_size, num_heads, seq_len_q, 1)
+
+            # The normalizer must also include the sink term:
+            # since lse = log(sum(exp(scores))), the combined normalizer is
+            # log(sum(exp(scores)) + exp(sink)) = logsumexp([lse, sink])
+            lse_expanded = lse.unsqueeze(-1)  # [batch, num_heads, seq_len, 1]
+            combined_lse = torch.logsumexp(torch.cat([lse_expanded, sinks], dim=-1), dim=-1, keepdim=True)
+
+            # Renormalize: multiply by old_norm / new_norm = exp(lse - combined_lse)
+            renorm_factor = torch.exp(lse_expanded - combined_lse)
+            attention_output = attention_output * renorm_factor
     else:
         attention_output = flex_attention_output  # type: ignore[assignment]
         lse = None
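
For reference, a small self-contained sketch of the renormalization identity the patch relies on. It does not call flex attention: the sink-free output and its LSE (what flex attention returns with `return_lse=True`) are mimicked with a plain softmax and `torch.logsumexp`, and the check confirms that multiplying by `exp(lse - combined_lse)` reproduces a softmax that includes one sink logit per head. The toy shapes and the random `s_aux` values are made up for illustration.

```python
import torch

torch.manual_seed(0)

B, H, S, D = 1, 2, 5, 8           # batch, heads, sequence length, head dim (toy sizes)
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)
s_aux = torch.randn(H)            # hypothetical sink logits, one per head

scores = (q @ k.transpose(-2, -1)) / D**0.5          # [B, H, S, S]
sink = s_aux.view(1, H, 1, 1).expand(B, H, S, 1)     # [B, H, S, 1]

# Reference: softmax over [scores, sink]; the sink absorbs probability mass
# but contributes no value vector.
probs = torch.softmax(torch.cat([scores, sink], dim=-1), dim=-1)
reference = probs[..., :S] @ v

# Patch approach: take the sink-free attention output and its LSE,
# then renormalize so the denominator also includes exp(sink).
out_no_sink = torch.softmax(scores, dim=-1) @ v
lse = torch.logsumexp(scores, dim=-1, keepdim=True)  # [B, H, S, 1]
combined_lse = torch.logsumexp(torch.cat([lse, sink], dim=-1), dim=-1, keepdim=True)
renormalized = out_no_sink * torch.exp(lse - combined_lse)  # old_norm / new_norm

print(torch.allclose(reference, renormalized, atol=1e-5))   # True
```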