Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/user_guide/diffusion/parallelism_acceleration.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ The following table shows which models are currently supported by parallelism me
| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ✅ | ✅ | ✅ | ✅ |
| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ✅ (TP=2 only) |
| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ❌ | ❌ | ❌ |
| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | | | ❌ | ✅ |
| **FLUX.2-klein** | `black-forest-labs/FLUX.2-klein-4B` | | | ❌ | ✅ |
| **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ❌ | ❌ | ✅ |

!!! note "TP Limitations for Diffusion Models"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.normalization import AdaLayerNormContinuous
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
Expand All @@ -39,8 +40,18 @@

from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
from vllm_omni.diffusion.attention.layer import Attention
from vllm_omni.diffusion.data import OmniDiffusionConfig
from vllm_omni.diffusion.distributed.parallel_state import (
get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group,
)
from vllm_omni.diffusion.distributed.sp_sharding import sp_shard_with_padding
from vllm_omni.diffusion.forward_context import get_forward_context
from vllm_omni.diffusion.layers.rope import RotaryEmbedding

logger = init_logger(__name__)


class Flux2SwiGLU(nn.Module):
"""SwiGLU activation used by Flux2."""
Expand Down Expand Up @@ -334,6 +345,12 @@ def forward(


class Flux2SingleTransformerBlock(nn.Module):
"""
Single-stream Transformer block for Flux 2 with SP (Sequence Parallelism) support.

SP handling is delegated to Flux2Attention via the forward context.
"""

def __init__(
self,
dim: int,
Expand Down Expand Up @@ -367,6 +384,13 @@ def forward(
split_hidden_states: bool = False,
text_seq_len: int | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass for Flux2SingleTransformerBlock with SP support.

In SP mode: image hidden_states is chunked (B, img_len/SP, D),
text encoder_hidden_states is full (B, txt_len, D).
The block concatenates them for joint attention.
"""
if encoder_hidden_states is not None:
text_seq_len = encoder_hidden_states.shape[1]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
Expand Down Expand Up @@ -556,6 +580,8 @@ def forward(self, temb: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor,
class Flux2Transformer2DModel(nn.Module):
"""
The Transformer model introduced in Flux 2.

Supports Sequence Parallelism (Ulysses and Ring) when configured via OmniDiffusionConfig.
"""

_repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
Expand All @@ -580,6 +606,7 @@ def __init__(
rope_theta: int = 2000,
eps: float = 1e-6,
guidance_embeds: bool = True,
od_config: OmniDiffusionConfig = None,
):
super().__init__()
self.out_channels = out_channels or in_channels
Expand All @@ -601,6 +628,13 @@ def __init__(
guidance_embeds=guidance_embeds,
)

if od_config is not None:
self.parallel_config = od_config.parallel_config
else:
from vllm_omni.diffusion.data import DiffusionParallelConfig

self.parallel_config = DiffusionParallelConfig()

self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=list(axes_dims_rope))
self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
in_channels=timestep_guidance_channels,
Expand Down Expand Up @@ -672,6 +706,26 @@ def forward(

num_txt_tokens = encoder_hidden_states.shape[1]

sp_size = self.parallel_config.sequence_parallel_size
get_forward_context().sequence_parallel_size = sp_size
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To fetch the sp_size, using self.parallel_config.sequence_parallel_size would be sufficient. I don't see why you need to set get_forward_context().sequence_parallel_size here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find that LongCatImageTransformer2DModel also edits get_forward_context...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

od_config = get_forward_context().omni_diffusion_config
parallel_config = od_config.parallel_config
sequence_parallel_size = parallel_config.sequence_parallel_size

This would be my recommendation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To fetch the sp_size, using self.parallel_config.sequence_parallel_size would be sufficient. I don't see why you need to set get_forward_context().sequence_parallel_size here.

Thanks — I now keep only `sp_size = self.parallel_config.sequence_parallel_size`, which matches what you suggested.

sp_pad_size = 0
if sp_size > 1:
sp_world_size = get_sequence_parallel_world_size()
sp_rank = get_sequence_parallel_rank()
original_shape = hidden_states.shape
hidden_states, sp_pad_size = sp_shard_with_padding(hidden_states, dim=1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a little concerned about the padding behavior here. It looks like flux2_klein does not support hidden_states_mask (unlike qwen_image); therefore, the padded tokens will participate in the attention computation, which is wrong.

If using Non-intrusive _sp_plan, sp_padding_size will be automatically set in get_forward_context(), however, hidden_states_mask is required to exclude the padded tokens.

# vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py
        hidden_states_mask = None  # default
        if self.parallel_config is not None and self.parallel_config.sequence_parallel_size > 1:
            ctx = get_forward_context()
            if ctx.sp_original_seq_len is not None and ctx.sp_padding_size > 0:
                # Create mask for the full (padded) sequence
                # valid positions = True, padding positions = False
                batch_size = hidden_states.shape[0]
                padded_seq_len = ctx.sp_original_seq_len + ctx.sp_padding_size
                hidden_states_mask = torch.ones(
                    batch_size,
                    padded_seq_len,
                    dtype=torch.bool,
                    device=hidden_states.device,
                )
                hidden_states_mask[:, ctx.sp_original_seq_len :] = False

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@RuixiangMa I would like to invite you to join the discussion here #1324.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have two suggestions:

  1. refer to [Feature] Support Wan2.2 output with irregular shapes #1279 , you may use auto_pad=True and use hidden_states_mask to exclude the padded tokens.
  2. You may set SequenceParallelInput(auto_pad=False), this will raise an error when seq_len not divisible by sp_size, which I think might be rare for Flux.2-klein. We can take care of rare cases in the future, after we make our discussion clear in [RFC]: Ulysses-SP Constraints Solution #1324.

get_forward_context().split_text_embed_in_sp = False
if not hasattr(self, "_sp_forward_logged"):
self._sp_forward_logged = True
logger.info(
f"[Flux2 Transformer] SP enabled: sp_size={sp_size}, world_size={sp_world_size}, "
f"rank={sp_rank}, original_shape={original_shape}, chunked_shape={hidden_states.shape}"
)
else:
if not hasattr(self, "_sp_forward_logged"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is self._sp_forward_logged used for debugging only?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is self._sp_forward_logged used for debugging only?

Yes, I'm removing it now.

self._sp_forward_logged = True
logger.info(f"[Flux2 Transformer] SP disabled: sp_size={sp_size}")

timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
Expand All @@ -690,11 +744,21 @@ def forward(
if txt_ids.ndim == 3:
txt_ids = txt_ids[0]

image_rotary_emb = self.pos_embed(img_ids)
text_rotary_emb = self.pos_embed(txt_ids)
img_freqs_cos, img_freqs_sin = self.pos_embed(img_ids)
txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids)

if sp_size > 1:
sp_world_size = get_sequence_parallel_world_size()
sp_rank = get_sequence_parallel_rank()
img_freqs_cos, _ = sp_shard_with_padding(img_freqs_cos, dim=0)
img_freqs_sin, _ = sp_shard_with_padding(img_freqs_sin, dim=0)
if get_forward_context().split_text_embed_in_sp:
txt_freqs_cos, _ = sp_shard_with_padding(txt_freqs_cos, dim=0)
txt_freqs_sin, _ = sp_shard_with_padding(txt_freqs_sin, dim=0)

concat_rotary_emb = (
torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
torch.cat([txt_freqs_cos, img_freqs_cos], dim=0),
torch.cat([txt_freqs_sin, img_freqs_sin], dim=0),
)

for block in self.transformer_blocks:
Expand Down Expand Up @@ -722,6 +786,11 @@ def forward(
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)

if self.parallel_config.sequence_parallel_size > 1:
output = get_sp_group().all_gather(output, dim=1)
if sp_pad_size > 0:
output = output[:, :-sp_pad_size, ...]

if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def __init__(
).to(self._execution_device)

transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, Flux2Transformer2DModel)
self.transformer = Flux2Transformer2DModel(**transformer_kwargs)
self.transformer = Flux2Transformer2DModel(od_config=od_config, **transformer_kwargs)

self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
Expand Down