diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 8c707c2561af..aeb485193f72 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -362,6 +362,19 @@ def __init__(
         self.use_upstream_fa = True
         if current_platform.is_xpu():
             self.use_upstream_fa = False
+        # Flash attention requires head_dim to be a multiple of 32.
+        # Fall back to TORCH_SDPA if the head dimension is incompatible.
+        if self.attn_backend in {
+            AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.ROCM_AITER_FA,
+        } and self.hidden_size_per_attention_head % 32 != 0:
+            logger.warning(
+                "Flash attention backend requires head_dim to be a multiple "
+                "of 32, but got %d. Falling back to TORCH_SDPA backend.",
+                self.hidden_size_per_attention_head,
+            )
+            self.attn_backend = AttentionBackendEnum.TORCH_SDPA
+            self.flash_attn_varlen_func = None
         self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
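
For illustration only (not part of the patch), a minimal standalone sketch of the divisibility check the hunk adds; the helper name and the sample head dimensions below are hypothetical, not taken from any particular model config:

    def needs_sdpa_fallback(head_dim: int) -> bool:
        # Mirrors the new check: fall back to TORCH_SDPA when head_dim
        # is not a multiple of 32.
        return head_dim % 32 != 0

    assert needs_sdpa_fallback(64) is False  # 64 % 32 == 0 -> flash attention kept
    assert needs_sdpa_fallback(80) is True   # 80 % 32 != 0 -> falls back to TORCH_SDPA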