[diffusion] use fa3 by default when device supports it #783
SamitHuang merged 11 commits into vllm-project:main
Conversation
Signed-off-by: zjy0516 <[email protected]>
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: ecd0d5378e
```python
if os.environ.get("DIFFUSION_ATTENTION_BACKEND") is not None:
    backend_name = os.environ.get("DIFFUSION_ATTENTION_BACKEND")
# ...
if backend_name is not None:
```
Initialize backend_name before conditional selection
When neither `DIFFUSION_ATTENTION_BACKEND` is set nor the CUDA compute capability falls in the 8.x–9.x range, `backend_name` is never assigned, but it is still referenced in the `if backend_name is not None:` check. This raises `UnboundLocalError` and prevents the function from falling back to SDPA on CPU or older GPUs. Initialize `backend_name` to `None` (as in the previous version) before the conditional selection to avoid this runtime crash.
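The suggested fix can be sketched as a minimal standalone function (the env var and SDPA fallback come from the thread; the `"FLASH_ATTN"` label, signature, and capability cutoffs here are illustrative, not the merged code):

```python
import os


def select_backend_name(device_type, compute_capability):
    # Initialize first so the final check can never raise UnboundLocalError.
    backend_name = os.environ.get("DIFFUSION_ATTENTION_BACKEND")

    if backend_name is None and device_type == "cuda" and compute_capability is not None:
        major, _ = compute_capability
        if major in (8, 9):  # sm80-sm90: FA3 forward wheel is available
            backend_name = "FLASH_ATTN"

    # None here means the caller falls back to SDPA (CPU / older GPUs).
    return backend_name
```

With this ordering, the CPU/older-GPU path simply returns `None` instead of crashing.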
```python
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
from fa3_fwd_interface import flash_attn_func
```
Will this lead to an import error on NPUs?
It should be handled in selector.py, which means we will not import this file on NPU.
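One way a selector can keep the CUDA-only import off the NPU path is to resolve module names before importing anything; this is a hedged sketch (the lookup table and function names are illustrative, not the actual selector.py):

```python
import importlib

_BACKEND_MODULES = {
    "cuda": "fa3_fwd_interface",     # CUDA-only; safe because it is never named on NPU
    "npu": "torch.nn.functional",    # fall back to SDPA
    "cpu": "torch.nn.functional",
}


def backend_module_name(device_type):
    """Resolve which module the selector would import for this device.

    Keeping the CUDA-only module behind this lookup means NPU/CPU runs
    never execute `import fa3_fwd_interface` at all.
    """
    return _BACKEND_MODULES.get(device_type, "torch.nn.functional")


def load_backend(device_type):
    # The import happens only here, after device detection.
    return importlib.import_module(backend_module_name(device_type))
```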
```python
backend_name: str | None = os.environ.get("DIFFUSION_ATTENTION_BACKEND")

if detect_device_type() == "cuda":
    compute_capability = torch.cuda.get_device_capability()
```
I have seen an FA3 support issue in flash_attn on Blackwell devices, such as GB200 chips.
Do you think we should keep both FA2 and FA3 and select between them accordingly?
I plan to keep FA2 and FA3. For Blackwell, fall back to SDPA or something else.
For FA2, I don't have the bandwidth to build and upload it right now; I'll do this later.
FYI, I found FA3 also supports sm80, so I don't plan to maintain FA2 now.
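The capability-based selection discussed in this thread can be sketched as follows (the cutoffs and backend labels are assumptions based on the discussion above, not the merged code):

```python
def pick_fa_backend(compute_capability):
    """Map CUDA compute capability to an attention backend name.

    Per the thread: FA3 wheels cover sm80 (Ampere) through sm90 (Hopper),
    while Blackwell (sm100+) falls back to SDPA until FA3 supports it there.
    """
    major, _ = compute_capability
    if 8 <= major <= 9:
        return "FA3"
    return "SDPA"
```

Covering sm80 with FA3 is what makes a separate FA2 build unnecessary here.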
It seems the CI failed due to acc.
```python
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
from fa3_fwd_interface import flash_attn_func, flash_attn_varlen_func
```
@ZJY0516 can we use this helper function from upstream now that we are at v0.14.0?

```python
from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func
```
But upstream doesn't provide `flash_attn_func`, so we still need `fa3_fwd`.
Purpose
I have uploaded an FA3 package to PyPI that only contains the forward kernel.
Now users can install it automatically.
https://pypi.org/project/fa3-fwd/
https://github.com/ZJY0516/fa3-fwd
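A minimal, hedged availability check for the package (the module name `fa3_fwd_interface` is taken from the diff in this PR; installation is via `pip install fa3-fwd`):

```python
# pip install fa3-fwd   (forward-only FA3 kernels, per the PR description)
try:
    from fa3_fwd_interface import flash_attn_func  # noqa: F401
    HAS_FA3 = True
except ImportError:
    # Wheel missing or unsupported platform: the caller should fall back to SDPA.
    HAS_FA3 = False
```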
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
- supported_models.md and examples for a new model.