Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ dependencies = [
"cache-dit==1.2.0",
"tqdm>=4.66.0",
"torchsde>=0.2.6", # Required for Stable Audio scheduler
"fa3-fwd", # flash attention 3, maintained by @ZJY0516
"openai-whisper>=20250625",
# "vllm==0.14.0", # TODO: fix the entrypoints overwrite problem
]
Expand Down
3 changes: 2 additions & 1 deletion tests/e2e/offline_inference/test_ovis_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,8 @@ def test_real_transformer_init_and_forward():
}
)

od_config = OmniDiffusionConfig(model="dummy-ovis", tf_model_config=tf_config, dtype=torch.float32, num_gpus=1)
od_config = OmniDiffusionConfig(model="dummy-ovis", tf_model_config=tf_config, dtype=torch.bfloat16, num_gpus=1)
torch.set_default_dtype(torch.bfloat16)

# Mock distributed state for QKVParallelLinear initialization
# We patch get_tp_group because get_tensor_model_parallel_rank calls it and asserts _TP is not None
Expand Down
13 changes: 1 addition & 12 deletions vllm_omni/diffusion/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
from fa3_fwd_interface import flash_attn_func, flash_attn_varlen_func
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ZJY0516 can we use this helper function from upstream now that we are at v0.14.0?

from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func

https://github.com/vllm-project/vllm/blob/c80f92c14d5e6c52691f586052af68d1495aac74/vllm/v1/attention/ops/vit_attn_wrappers.py#L38

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But upstream doesn't provide flash_attn_func, so we still need fa3_fwd

from vllm.logger import init_logger

from vllm_omni.diffusion.attention.backends.abstract import (
Expand All @@ -13,18 +14,6 @@

logger = init_logger(__name__)

try:
# only tested with flash_attn v3
# from flash_attn_interface import flash_attn_func as flash_attn_3_func # not available in flash-attn 2.8.1
from flash_attn import flash_attn_func, flash_attn_varlen_func # can be FA2 or FA3
except ImportError:
logger.warning(
"FlashAttentionBackend is not available. You may install flash-attn "
"by running `uv pip install flash-attn==2.8.1 --no-build-isolation`"
" or install pre-built flash-attn from https://github.com/Dao-AILab/flash-attention/releases"
)
raise ImportError


class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
Expand Down
19 changes: 18 additions & 1 deletion vllm_omni/diffusion/attention/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import os
from functools import cache

import torch
from vllm.logger import init_logger

from vllm_omni.diffusion.attention.backends.abstract import (
AttentionBackend,
)
from vllm_omni.diffusion.attention.backends.sdpa import SDPABackend
from vllm_omni.utils.platform_utils import detect_device_type, is_rocm

logger = init_logger(__name__)

Expand Down Expand Up @@ -63,7 +65,22 @@ def get_attn_backend(head_size: int) -> type[AttentionBackend]:
The selected attention backend class
"""
# Check environment variable
backend_name: str | None = os.environ.get("DIFFUSION_ATTENTION_BACKEND")

backend_name = os.environ.get("DIFFUSION_ATTENTION_BACKEND", None)

if detect_device_type() == "cuda" and not is_rocm():
compute_capability = torch.cuda.get_device_capability()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have seen a FA3 support issue in flash_attn on Blackwell devices, such as GB200 chips.
Do you think that we should keep both FA2 and FA3, and select different FA accordingly?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I plan to keep FA2 and FA3. For Blackwell, fall back to sdpa or something else.

For FA2, I don't have enough bandwidth to build and upload it recently. I'll do this later.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI, I found FA3 also supports sm80, so I don't plan to maintain FA2 now

major, minor = compute_capability
if 80 <= major * 10 + minor < 100:
if backend_name is None:
backend_name = "FLASH_ATTN"
else:
if backend_name == "FLASH_ATTN":
logger.warning(
"Flash Attention requires GPU with compute capability "
">= 8.0 and < 10.0. Falling back to TORCH_SDPA backend."
)
backend_name = "TORCH_SDPA"

if backend_name is not None:
backend_name_upper = backend_name.upper()
Expand Down