[BugFix] Cosmos-Reason1-7B Model Flash Attention requires head dim to be a multiple of 32 #29615
```diff
@@ -362,6 +362,19 @@
             self.use_upstream_fa = True
         if current_platform.is_xpu():
             self.use_upstream_fa = False
+        # Flash attention requires head_dim to be a multiple of 32
+        # Fall back to TORCH_SDPA if the head dimension is incompatible
+        if self.attn_backend in {
+            AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.ROCM_AITER_FA,
+        } and self.hidden_size_per_attention_head % 32 != 0:
+            logger.warning(
+                f"Flash attention backend requires head_dim to be a multiple of 32, "
+                f"but got {self.hidden_size_per_attention_head}. "
+                f"Falling back to TORCH_SDPA backend."
+            )
+            self.attn_backend = AttentionBackendEnum.TORCH_SDPA
+            self.flash_attn_varlen_func = None
         self.is_flash_attn_backend = self.attn_backend in {
             AttentionBackendEnum.FLASH_ATTN,
             AttentionBackendEnum.ROCM_AITER_FA,
```

**Member** (on the new backend check): I think #28763 has added …

**Author**: Thanks! I have verified that this bug does not occur in the latest nightly build, so this PR and the issue can be closed.

**Member** (on the `logger.warning` call): Please fix pre-commit.
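For context, the arithmetic behind the check is simple. The sketch below uses assumed Qwen2.5-VL-style vision-tower numbers (hidden size 1280, 16 heads), which are not stated in this diff, to show a per-head dimension that fails the multiple-of-32 requirement.

```python
# Illustrative sketch only: the config values below are assumptions
# (a Qwen2.5-VL-style vision tower), not taken from this PR.
hidden_size = 1280
num_attention_heads = 16

head_dim = hidden_size // num_attention_heads  # 80
print(head_dim)            # 80
print(head_dim % 32 == 0)  # False -> the new check would fall back to TORCH_SDPA
```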
**Review comment** (at the end of the hunk): @ECMGit can you skip this check on `rocm` for now? We have different conditions.
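A minimal, self-contained sketch of the suggestion above, i.e. only enforcing the multiple-of-32 `head_dim` requirement on the non-ROCm flash-attention path. The enum and the `is_rocm` flag are stand-ins for the vLLM internals shown in the diff, and the exact condition used upstream may differ.

```python
# Hypothetical sketch of skipping the head_dim check on ROCm; the enum below
# is a stand-in for vLLM's AttentionBackendEnum, not the real class.
from enum import Enum, auto


class AttentionBackendEnum(Enum):
    FLASH_ATTN = auto()
    ROCM_AITER_FA = auto()
    TORCH_SDPA = auto()


def pick_backend(
    requested: AttentionBackendEnum, head_dim: int, is_rocm: bool
) -> AttentionBackendEnum:
    """Fall back to TORCH_SDPA only for the non-ROCm flash-attention path."""
    if (
        requested == AttentionBackendEnum.FLASH_ATTN
        and not is_rocm
        and head_dim % 32 != 0
    ):
        return AttentionBackendEnum.TORCH_SDPA
    return requested


print(pick_backend(AttentionBackendEnum.FLASH_ATTN, 80, is_rocm=False))    # TORCH_SDPA
print(pick_backend(AttentionBackendEnum.ROCM_AITER_FA, 80, is_rocm=True))  # unchanged
```

Using lazy `%`-style arguments in `logger.warning` instead of an f-string, as in standard `logging` usage, would also likely address the pre-commit complaint noted earlier.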