Skip to content

Commit f37738f

Browse files
committed
raisae error when return_lse is False yet s_aux is providewd
1 parent 6f2e4c9 commit f37738f

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

src/transformers/integrations/flex_attention.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,13 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
290290
# On CPU we must skip returning LSE due to a runtime issue; elsewhere, follow PyTorch API and return it
291291
return_lse = query.device.type != "cpu"
292292

293+
# Validate that s_aux is not silently ignored
294+
if not return_lse and s_aux is not None:
295+
raise ValueError(
296+
"s_aux is not supported when return_lse=False (e.g., on CPU). "
297+
"Attention sinks require LSE computation which is not available on this device."
298+
)
299+
293300
flex_attention_output = compile_friendly_flex_attention(
294301
query,
295302
key,

0 commit comments

Comments
 (0)