5 changes: 4 additions & 1 deletion src/transformers/integrations/flex_attention.py
Expand Up @@ -290,6 +290,9 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
enable_gqa = False

kernel_options = kwargs.get("kernel_options")
# dynamically set return_lse to True if output_attentions is True
return_lse = module.config.output_attentions

attn_output, attention_weights = compile_friendly_flex_attention(
query,
key,
Expand All @@ -301,7 +304,7 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
kernel_options=kernel_options,
# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
# For simplification, we thus always return it as no additional computations are introduced.
return_lse=True,
return_lse=return_lse,
training=module.training,
)
# lse is returned in float32
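For context, a minimal sketch of the return_lse flag that this change makes conditional. This is not the transformers wrapper (compile_friendly_flex_attention) itself; it calls PyTorch's flex_attention directly, assumes PyTorch >= 2.5 where torch.nn.attention.flex_attention exists, and the tensor shapes and the output_attentions variable are illustrative stand-ins for module.config.output_attentions.

# Hedged sketch: the log-sum-exp (lse) is only returned when return_lse=True,
# mirroring the conditional introduced in the diff above.
import torch
from torch.nn.attention.flex_attention import flex_attention

batch, num_heads, seq_len, head_dim = 1, 4, 128, 64
query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, num_heads, seq_len, head_dim)
value = torch.randn(batch, num_heads, seq_len, head_dim)

# Stand-in for module.config.output_attentions in the changed code.
output_attentions = True
return_lse = output_attentions

if return_lse:
    # Returns (attn_output, lse); per the comment in the diff, lse comes back in float32.
    attn_output, lse = flex_attention(query, key, value, return_lse=True)
else:
    # Only the attention output is returned when return_lse=False.
    attn_output = flex_attention(query, key, value, return_lse=False)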