Skip to content
Open
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
262 changes: 219 additions & 43 deletions vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.normalization import AdaLayerNormContinuous
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
Expand All @@ -39,8 +40,15 @@

from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
from vllm_omni.diffusion.attention.layer import Attention
from vllm_omni.diffusion.data import OmniDiffusionConfig
from vllm_omni.diffusion.distributed.sp_plan import (
SequenceParallelInput,
SequenceParallelOutput,
)
from vllm_omni.diffusion.forward_context import get_forward_context
from vllm_omni.diffusion.layers.rope import RotaryEmbedding

logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

Expand Down Expand Up @@ -208,33 +216,84 @@ def forward(
encoder_query = self.norm_added_q(encoder_query)
encoder_key = self.norm_added_k(encoder_key)

query = torch.cat([encoder_query, query], dim=1)
key = torch.cat([encoder_key, key], dim=1)
value = torch.cat([encoder_value, value], dim=1)

if image_rotary_emb is not None:
cos, sin = image_rotary_emb
cos = cos.to(query.dtype)
sin = sin.to(query.dtype)
query = self.rope(query, cos, sin)
key = self.rope(key, cos, sin)

attn_metadata = None
if attention_mask is not None:
if attention_mask.dim() == 3:
attention_mask = attention_mask.unsqueeze(1)
attn_metadata = AttentionMetadata(attn_mask=attention_mask)

hidden_states = self.attn(query, key, value, attn_metadata)
hidden_states = hidden_states.flatten(2, 3).to(query.dtype)

if encoder_hidden_states is not None:
context_len = encoder_hidden_states.shape[1]
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
[context_len, hidden_states.shape[1] - context_len],
dim=1,
)
encoder_hidden_states = self.to_add_out(encoder_hidden_states)
forward_ctx = get_forward_context()
use_sp_joint_attention = forward_ctx.sp_active and not forward_ctx.split_text_embed_in_sp

if use_sp_joint_attention and image_rotary_emb is not None:
cos, sin = image_rotary_emb
cos = cos.to(query.dtype)
sin = sin.to(query.dtype)
txt_len = encoder_query.shape[1]
txt_cos, img_cos = cos[:txt_len], cos[txt_len:]
txt_sin, img_sin = sin[:txt_len], sin[txt_len:]
query = self.rope(query, img_cos, img_sin)
key = self.rope(key, img_cos, img_sin)
encoder_query = self.rope(encoder_query, txt_cos, txt_sin)
encoder_key = self.rope(encoder_key, txt_cos, txt_sin)

attn_metadata = AttentionMetadata(
joint_query=encoder_query,
joint_key=encoder_key,
joint_value=encoder_value,
joint_strategy="front",
)
if attention_mask is not None:
if attention_mask.dim() == 3:
attention_mask = attention_mask.unsqueeze(1)
attn_metadata.attn_mask = attention_mask

hidden_states = self.attn(query, key, value, attn_metadata)
hidden_states = hidden_states.flatten(2, 3).to(query.dtype)

txt_len = encoder_hidden_states.shape[1]
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
[txt_len, hidden_states.shape[1] - txt_len],
dim=1,
)
encoder_hidden_states = self.to_add_out(encoder_hidden_states)
else:
query = torch.cat([encoder_query, query], dim=1)
key = torch.cat([encoder_key, key], dim=1)
value = torch.cat([encoder_value, value], dim=1)

if image_rotary_emb is not None:
cos, sin = image_rotary_emb
cos = cos.to(query.dtype)
sin = sin.to(query.dtype)
query = self.rope(query, cos, sin)
key = self.rope(key, cos, sin)

attn_metadata = None
if attention_mask is not None:
if attention_mask.dim() == 3:
attention_mask = attention_mask.unsqueeze(1)
attn_metadata = AttentionMetadata(attn_mask=attention_mask)

hidden_states = self.attn(query, key, value, attn_metadata)
hidden_states = hidden_states.flatten(2, 3).to(query.dtype)

context_len = encoder_hidden_states.shape[1]
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
[context_len, hidden_states.shape[1] - context_len],
dim=1,
)
encoder_hidden_states = self.to_add_out(encoder_hidden_states)
else:
if image_rotary_emb is not None:
cos, sin = image_rotary_emb
cos = cos.to(query.dtype)
sin = sin.to(query.dtype)
query = self.rope(query, cos, sin)
key = self.rope(key, cos, sin)

attn_metadata = None
if attention_mask is not None:
if attention_mask.dim() == 3:
attention_mask = attention_mask.unsqueeze(1)
attn_metadata = AttentionMetadata(attn_mask=attention_mask)

hidden_states = self.attn(query, key, value, attn_metadata)
hidden_states = hidden_states.flatten(2, 3).to(query.dtype)

hidden_states = self.to_out[0](hidden_states)
hidden_states = self.to_out[1](hidden_states)
Expand Down Expand Up @@ -325,20 +384,59 @@ def forward(
query = self.norm_q(query)
key = self.norm_k(key)

if image_rotary_emb is not None:
forward_ctx = get_forward_context()
text_seq_len = kwargs.get("text_seq_len", None)
use_sp_single_stream = (
forward_ctx.sp_active and not forward_ctx.split_text_embed_in_sp and text_seq_len is not None
)

if use_sp_single_stream and image_rotary_emb is not None:
cos, sin = image_rotary_emb
cos = cos.to(query.dtype)
sin = sin.to(query.dtype)
query = self.rope(query, cos, sin)
key = self.rope(key, cos, sin)
txt_cos, img_cos = cos[:text_seq_len], cos[text_seq_len:]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these added lines (397 to 410) independent of any diffusion model, and can they be extracted?

Copy link
Contributor Author

@RuixiangMa RuixiangMa Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've extracted the RoPE logic into a reusable helper function (only verified flux1, flux2, and z-image so far, but the others (e.g., ovis-image) should be reusable as well), and the SP-specific logic in flux2_klein (lines 393-410) remains unchanged, as it handles model-specific sequence-parallelism splitting.

txt_sin, img_sin = sin[:text_seq_len], sin[text_seq_len:]

img_query = query[:, text_seq_len:]
img_key = key[:, text_seq_len:]
img_value = value[:, text_seq_len:]
text_query = query[:, :text_seq_len]
text_key = key[:, :text_seq_len]
text_value = value[:, :text_seq_len]

img_query = self.rope(img_query, img_cos, img_sin)
img_key = self.rope(img_key, img_cos, img_sin)
text_query = self.rope(text_query, txt_cos, txt_sin)
text_key = self.rope(text_key, txt_cos, txt_sin)

attn_metadata = AttentionMetadata(
joint_query=text_query,
joint_key=text_key,
joint_value=text_value,
joint_strategy="front",
)
if attention_mask is not None:
if attention_mask.dim() == 3:
attention_mask = attention_mask.unsqueeze(1)
attn_metadata.attn_mask = attention_mask

attn_output = self.attn(img_query, img_key, img_value, attn_metadata)
else:
if image_rotary_emb is not None:
cos, sin = image_rotary_emb
cos = cos.to(query.dtype)
sin = sin.to(query.dtype)
query = self.rope(query, cos, sin)
key = self.rope(key, cos, sin)

attn_metadata = None
if attention_mask is not None:
if attention_mask.dim() == 3:
attention_mask = attention_mask.unsqueeze(1)
attn_metadata = AttentionMetadata(attn_mask=attention_mask)

attn_metadata = None
if attention_mask is not None:
if attention_mask.dim() == 3:
attention_mask = attention_mask.unsqueeze(1)
attn_metadata = AttentionMetadata(attn_mask=attention_mask)
attn_output = self.attn(query, key, value, attn_metadata)

attn_output = self.attn(query, key, value, attn_metadata)
attn_output = attn_output.flatten(2, 3).to(query.dtype)

mlp_hidden_states = self.mlp_act_fn(mlp_hidden_states)
Expand Down Expand Up @@ -383,6 +481,13 @@ def forward(
split_hidden_states: bool = False,
text_seq_len: int | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass for Flux2SingleTransformerBlock with SP support.

In SP mode: image hidden_states is chunked (B, img_len/SP, D),
text encoder_hidden_states is full (B, txt_len, D).
The block concatenates them for joint attention.
"""
if encoder_hidden_states is not None:
text_seq_len = encoder_hidden_states.shape[1]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
Expand Down Expand Up @@ -525,6 +630,32 @@ def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
return freqs_cos, freqs_sin


class Flux2RopePrepare(nn.Module):
    """Bundle the input projection with RoPE frequency computation for SP sharding.

    Grouping these operations into one module lets the sequence-parallel plan
    shard the projected hidden_states together with the image RoPE frequencies,
    keeping their sequence dimensions aligned for RoPE application inside the
    attention layers, while the text frequencies stay replicated for
    dual-stream joint attention.
    """

    def __init__(self, x_embedder: nn.Linear, pos_embed: Flux2PosEmbed):
        super().__init__()
        self.x_embedder = x_embedder
        self.pos_embed = pos_embed

    def forward(
        self,
        hidden_states: torch.Tensor,
        img_ids: torch.Tensor,
        txt_ids: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # Project image tokens into the transformer's inner dimension.
        projected = self.x_embedder(hidden_states)
        # Rotary cos/sin tables for image and text position ids, respectively.
        img_cos, img_sin = self.pos_embed(img_ids)
        txt_cos, txt_sin = self.pos_embed(txt_ids)
        return projected, txt_cos, txt_sin, img_cos, img_sin


class Flux2TimestepGuidanceEmbeddings(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -580,6 +711,8 @@ def forward(self, temb: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor,
class Flux2Transformer2DModel(nn.Module):
"""
The Transformer model introduced in Flux 2.

Supports Sequence Parallelism (Ulysses and Ring) when configured via OmniDiffusionConfig.
"""

_repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
Expand All @@ -588,6 +721,16 @@ class Flux2Transformer2DModel(nn.Module):
"add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"],
}

_sp_plan = {
"rope_prepare": {
0: SequenceParallelInput(split_dim=1, expected_dims=3, split_output=True, auto_pad=True),
3: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True, auto_pad=True),
4: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True, auto_pad=True),
},
"proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
}
"""SP plan: shard hidden_states/img_freqs at rope_prepare, gather output at proj_out."""

def __init__(
self,
patch_size: int = 1,
Expand All @@ -604,6 +747,7 @@ def __init__(
rope_theta: int = 2000,
eps: float = 1e-6,
guidance_embeds: bool = True,
od_config: OmniDiffusionConfig = None,
quant_config: "QuantizationConfig | None" = None,
):
super().__init__()
Expand All @@ -626,6 +770,13 @@ def __init__(
guidance_embeds=guidance_embeds,
)

if od_config is not None:
self.parallel_config = od_config.parallel_config
else:
from vllm_omni.diffusion.data import DiffusionParallelConfig

self.parallel_config = DiffusionParallelConfig()

self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=list(axes_dims_rope))
self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
in_channels=timestep_guidance_channels,
Expand All @@ -641,6 +792,8 @@ def __init__(
self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False)
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False)

self.rope_prepare = Flux2RopePrepare(self.x_embedder, self.pos_embed)

self.transformer_blocks = nn.ModuleList(
[
Flux2TransformerBlock(
Expand Down Expand Up @@ -699,6 +852,8 @@ def forward(

num_txt_tokens = encoder_hidden_states.shape[1]

get_forward_context().split_text_embed_in_sp = False

timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
Expand All @@ -709,21 +864,41 @@ def forward(
double_stream_mod_txt = self.double_stream_modulation_txt(temb)
single_stream_mod = self.single_stream_modulation(temb)[0]

hidden_states = self.x_embedder(hidden_states)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)

if img_ids.ndim == 3:
img_ids = img_ids[0]
if txt_ids.ndim == 3:
txt_ids = txt_ids[0]

image_rotary_emb = self.pos_embed(img_ids)
text_rotary_emb = self.pos_embed(txt_ids)
hidden_states, txt_freqs_cos, txt_freqs_sin, img_freqs_cos, img_freqs_sin = self.rope_prepare(
hidden_states, img_ids, txt_ids
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)

concat_rotary_emb = (
torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
torch.cat([txt_freqs_cos, img_freqs_cos], dim=0),
torch.cat([txt_freqs_sin, img_freqs_sin], dim=0),
)

hidden_states_mask = None
ctx = get_forward_context()
if ctx.sp_active:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think ctx.sp_active should not be exposed to users or developers.

In the _sp_plan hooks design, ctx.sp_active is determined by _sp_shard_depth, which is handled automatically by the hook function. Therefore, there is no need to include ctx.sp_active here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think ctx.sp_active should not be exposed to users or developers.

In the _sp_plan hooks design, ctx.sp_active is determined by _sp_shard_depth, which is handled automatically by the hook function. Therefore, there is no need to include ctx.sp_active here.

It was redundant, I removed it

if ctx.sp_original_seq_len is not None and ctx.sp_padding_size > 0:
batch_size = hidden_states.shape[0]
img_padded_seq_len = ctx.sp_original_seq_len + ctx.sp_padding_size
full_seq_len = num_txt_tokens + img_padded_seq_len
hidden_states_mask = torch.ones(
batch_size,
full_seq_len,
dtype=torch.bool,
device=hidden_states.device,
)
hidden_states_mask[:, num_txt_tokens + ctx.sp_original_seq_len :] = False
if hidden_states_mask.all():
hidden_states_mask = None

if hidden_states_mask is not None:
joint_attention_kwargs["attention_mask"] = hidden_states_mask

for block in self.transformer_blocks:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
Expand All @@ -743,6 +918,7 @@ def forward(
temb_mod_params=single_stream_mod,
image_rotary_emb=concat_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
text_seq_len=num_txt_tokens,
)

hidden_states = hidden_states[:, num_txt_tokens:, ...]
Expand Down