diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py
index 2ccbad24261c..48daf0072aa2 100644
--- a/src/transformers/integrations/flex_attention.py
+++ b/src/transformers/integrations/flex_attention.py
@@ -282,10 +282,10 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
 
     # On CPU we must skip returning LSE due to a runtime issue; elsewhere, follow PyTorch API and return it
     return_lse = query.device.type != "cpu"
-    # Validate that s_aux is not silently ignored
     if not return_lse and s_aux is not None:
-        logger.warning_once("s_aux provided with return_lse=False - forcing return_lse=True to avoid silent failure")
-        return_lse = True
+        raise ValueError(
+            "Attention sinks cannot be run on CPU with flex attention. Please switch to a different device, e.g. CUDA"
+        )
 
     flex_attention_output = compile_friendly_flex_attention(
         query,
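
For context, here is a minimal sketch of the new guard in isolation. The function name `check_sinks_supported` is hypothetical and not part of the library; only the control flow mirrors the diff above, which replaces the previous silent coercion of `return_lse` with an explicit error when attention sinks (`s_aux`) are requested on CPU.

```python
import torch


def check_sinks_supported(query: torch.Tensor, s_aux=None) -> bool:
    # Illustrative only: the LSE cannot be returned on CPU, and attention
    # sinks require it, so the combination is rejected outright rather
    # than silently forcing return_lse=True as the old code did.
    return_lse = query.device.type != "cpu"
    if not return_lse and s_aux is not None:
        raise ValueError(
            "Attention sinks cannot be run on CPU with flex attention. "
            "Please switch to a different device, e.g. CUDA"
        )
    return return_lse
```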