[diffusion] use fa3 by default when device supports it #783

Merged
SamitHuang merged 11 commits into vllm-project:main from ZJY0516:fa
Jan 22, 2026
Conversation

@ZJY0516
Collaborator

@ZJY0516 ZJY0516 commented Jan 14, 2026

Purpose

I have uploaded an FA3 package to PyPI that contains only the forward kernel.

Users can now install it automatically.

https://pypi.org/project/fa3-fwd/
https://github.com/ZJY0516/fa3-fwd

Test Plan

Test Result

[image: qwen_image_output]

Signed-off-by: zjy0516 <[email protected]>
@ZJY0516 ZJY0516 requested a review from SamitHuang January 14, 2026 13:25

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: ecd0d5378e


Comment on lines 73 to 76
if os.environ.get("DIFFUSION_ATTENTION_BACKEND") is not None:
    backend_name = os.environ.get("DIFFUSION_ATTENTION_BACKEND")

if backend_name is not None:

P1: Initialize backend_name before conditional selection

When neither DIFFUSION_ATTENTION_BACKEND is set nor the CUDA compute capability falls in the 8.x–9.x range, backend_name is never assigned, but it is still referenced in the if backend_name is not None: check. This raises UnboundLocalError and prevents the function from falling back to SDPA on CPU or older GPUs. Initialize backend_name to None (as in the previous version) before the conditional selection to avoid this runtime crash.
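The fix the bot asks for can be sketched as a small, self-contained selection helper. The function name, the backend strings, and taking the compute capability as a parameter are illustrative assumptions, not the PR's actual code:

```python
import os


def select_attn_backend(compute_capability=None):
    """Pick an attention backend; compute_capability is (major, minor) on CUDA, None otherwise."""
    # Initialize from the env var so backend_name is always bound,
    # avoiding the UnboundLocalError described above.
    backend_name = os.environ.get("DIFFUSION_ATTENTION_BACKEND")
    if backend_name is None and compute_capability is not None:
        major, _minor = compute_capability
        if 8 <= major <= 9:  # sm80-sm90: the fa3-fwd forward kernel applies
            backend_name = "FA3"
    # CPU, pre-sm80, and newer devices (e.g. Blackwell) fall back to SDPA.
    return backend_name if backend_name is not None else "SDPA"
```

Because the env-var read happens unconditionally, every path through the function leaves `backend_name` bound, which is the core of the suggested fix.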


@ZJY0516 ZJY0516 added the ready label to trigger buildkite CI label Jan 15, 2026
Signed-off-by: zjy0516 <[email protected]>
Signed-off-by: zjy0516 <[email protected]>
@ZJY0516
Collaborator Author

ZJY0516 commented Jan 15, 2026

@wtomin Is there anything I need to do to enable FA by default in parallel mode after #760 merged?

# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
from fa3_fwd_interface import flash_attn_func
Collaborator

will it lead to import error on NPUs?

Collaborator Author

It should be handled in selector.py, which means we will not import this file on NPU.
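The gating described here can be sketched as a device check that runs before anything touches the CUDA-only package. The helper name is hypothetical; the module name `fa3_fwd_interface` follows the package linked in the PR description:

```python
import importlib.util


def fa3_usable(device_type):
    """Return True only when FA3 can be imported safely: CUDA device and package present."""
    # Check the device first, so NPU/CPU never reach the import machinery at all.
    if device_type != "cuda":
        return False
    # Probe for the fa3-fwd module without actually importing it.
    return importlib.util.find_spec("fa3_fwd_interface") is not None
```

With this ordering, the `from fa3_fwd_interface import ...` line in the backend file is only ever reached on CUDA devices where the package is installed.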

@wtomin
Contributor

wtomin commented Jan 15, 2026

@wtomin Is there anything I need to do to enable FA by default in parallel mode after #760 merged?

Can you make sure that flash_attn_varlen_func can be supported in FA3? #760 needs this function.

@ZJY0516
Collaborator Author

ZJY0516 commented Jan 15, 2026

@wtomin Is there anything I need to do to enable FA by default in parallel mode after #760 merged?

Can you make sure that flash_attn_varlen_func can be supported in FA3? #760 needs this function.

yes

backend_name: str | None = os.environ.get("DIFFUSION_ATTENTION_BACKEND")

if detect_device_type() == "cuda":
    compute_capability = torch.cuda.get_device_capability()
Contributor

I have seen an FA3 support issue in flash_attn on Blackwell devices such as GB200 chips.
Do you think we should keep both FA2 and FA3 and select between them accordingly?

Collaborator Author

I plan to keep both FA2 and FA3. For Blackwell, fall back to SDPA or something else.

For FA2, I don't have enough bandwidth to build and upload it right now. I'll do this later.

Collaborator Author

FYI, I found that FA3 also supports sm80, so I don't plan to maintain FA2 now.

Signed-off-by: zjy0516 <[email protected]>
Signed-off-by: zjy0516 <[email protected]>
@hsliuustc0106
Collaborator

it seems the CI failed due to accuracy

Signed-off-by: zjy0516 <[email protected]>
Signed-off-by: zjy0516 <[email protected]>
Signed-off-by: zjy0516 <[email protected]>
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
from fa3_fwd_interface import flash_attn_func, flash_attn_varlen_func
Contributor

@ZJY0516 can we use this helper function from upstream now that we are at v0.14.0?

from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func

https://github.com/vllm-project/vllm/blob/c80f92c14d5e6c52691f586052af68d1495aac74/vllm/v1/attention/ops/vit_attn_wrappers.py#L38

Collaborator Author

But upstream doesn't provide flash_attn_func, so we still need fa3_fwd.
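The resolution order discussed in this thread can be sketched as a fallback chain: prefer the upstream vLLM helper for `flash_attn_varlen_func`, then the fa3-fwd package. The module paths come from the comments above and are assumptions, not verified imports:

```python
import importlib


def resolve_varlen_func():
    """Prefer upstream vLLM's flash_attn_varlen_func; fall back to fa3-fwd, then None."""
    candidates = (
        "vllm.v1.attention.backends.fa_utils",  # upstream helper, per the review comment
        "fa3_fwd_interface",                    # the fa3-fwd package from this PR
    )
    for mod_name in candidates:
        try:
            return getattr(importlib.import_module(mod_name), "flash_attn_varlen_func")
        except (ImportError, AttributeError):
            continue
    return None  # neither is available; the caller should fall back to SDPA
```

Since upstream does not export `flash_attn_func`, that symbol would still have to come from `fa3_fwd_interface` regardless of which source wins for the varlen variant.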

@SamitHuang SamitHuang merged commit 0df8e80 into vllm-project:main Jan 22, 2026
7 checks passed
@ZJY0516 ZJY0516 deleted the fa branch January 23, 2026 15:07

Labels

ready label to trigger buildkite CI

6 participants