[diffusion] use fa3 by default when device supports it #783
Changes from all commits:

- ecd0d53
- 71b3790
- fdb2a26
- 2b00f26
- 5503ba6
- 0e7f256
- 0fb4c30
- 8032678
- 1050ab6
- 889426e
- 3211ebd
```diff
@@ -5,12 +5,14 @@
 import os
 from functools import cache

+import torch
 from vllm.logger import init_logger

 from vllm_omni.diffusion.attention.backends.abstract import (
     AttentionBackend,
 )
 from vllm_omni.diffusion.attention.backends.sdpa import SDPABackend
+from vllm_omni.utils.platform_utils import detect_device_type, is_rocm

 logger = init_logger(__name__)
```

```diff
@@ -63,7 +65,22 @@ def get_attn_backend(head_size: int) -> type[AttentionBackend]:
         The selected attention backend class
     """
     # Check environment variable
-    backend_name: str | None = os.environ.get("DIFFUSION_ATTENTION_BACKEND")
+    backend_name = os.environ.get("DIFFUSION_ATTENTION_BACKEND", None)
+
+    if detect_device_type() == "cuda" and not is_rocm():
+        compute_capability = torch.cuda.get_device_capability()
```
**Collaborator:** I have seen a FA3 support issue in flash_attn on Blackwell devices, such as GB200 chips.

**Member (Author):** I plan to keep FA2 and FA3. For Blackwell, fall back to SDPA or something else. For FA2, I don't have enough bandwidth to build and upload it recently; I'll do this later.

**Member (Author):** FYI, I found FA3 also supports sm80, so I don't plan to maintain FA2 now.
```diff
+        major, minor = compute_capability
+        if 80 <= major * 10 + minor < 100:
+            if backend_name is None:
+                backend_name = "FLASH_ATTN"
+        else:
+            if backend_name == "FLASH_ATTN":
+                logger.warning(
+                    "Flash Attention requires GPU with compute capability "
+                    ">= 8.0 and < 10.0. Falling back to TORCH_SDPA backend."
+                )
+                backend_name = "TORCH_SDPA"

     if backend_name is not None:
         backend_name_upper = backend_name.upper()
```
**Comment:** @ZJY0516 can we use this helper function from upstream now that we are at v0.14.0?
`from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func`
https://github.com/vllm-project/vllm/blob/c80f92c14d5e6c52691f586052af68d1495aac74/vllm/v1/attention/ops/vit_attn_wrappers.py#L38

**Reply:** But upstream doesn't provide `flash_attn_func`, so we still need `fa3_fwd`.
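The availability question above is typically handled with a guarded import. A sketch of that pattern (not the PR's actual code; whether the upstream symbol resolves depends on the installed vllm version):

```python
try:
    # Upstream helper referenced in the review comment (vllm >= v0.14.0)
    from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func
except ImportError:
    # vllm absent or too old: callers must fall back, e.g. to a local fa3_fwd
    flash_attn_varlen_func = None

HAS_UPSTREAM_VARLEN = flash_attn_varlen_func is not None
```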