Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.normalization import AdaLayerNormContinuous
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
Expand All @@ -39,7 +40,17 @@

from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
from vllm_omni.diffusion.attention.layer import Attention
from vllm_omni.diffusion.data import OmniDiffusionConfig
from vllm_omni.diffusion.distributed.parallel_state import (
get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group,
)
from vllm_omni.diffusion.forward_context import get_forward_context
from vllm_omni.diffusion.layers.rope import RotaryEmbedding
from vllm_omni.platforms import current_omni_platform

logger = init_logger(__name__)


class Flux2SwiGLU(nn.Module):
Expand Down Expand Up @@ -334,6 +345,12 @@ def forward(


class Flux2SingleTransformerBlock(nn.Module):
"""
Single-stream Transformer block for Flux 2 with SP (Sequence Parallelism) support.

SP handling is delegated to Flux2Attention via the forward context.
"""

def __init__(
self,
dim: int,
Expand Down Expand Up @@ -367,6 +384,13 @@ def forward(
split_hidden_states: bool = False,
text_seq_len: int | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass for Flux2SingleTransformerBlock with SP support.

In SP mode: image hidden_states is chunked (B, img_len/SP, D),
text encoder_hidden_states is full (B, txt_len, D).
The block concatenates them for joint attention.
"""
if encoder_hidden_states is not None:
text_seq_len = encoder_hidden_states.shape[1]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
Expand Down Expand Up @@ -556,6 +580,8 @@ def forward(self, temb: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor,
class Flux2Transformer2DModel(nn.Module):
"""
The Transformer model introduced in Flux 2.

Supports Sequence Parallelism (Ulysses and Ring) when configured via OmniDiffusionConfig.
"""

_repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
Expand All @@ -580,6 +606,7 @@ def __init__(
rope_theta: int = 2000,
eps: float = 1e-6,
guidance_embeds: bool = True,
od_config: OmniDiffusionConfig = None,
):
super().__init__()
self.out_channels = out_channels or in_channels
Expand All @@ -601,6 +628,13 @@ def __init__(
guidance_embeds=guidance_embeds,
)

if od_config is not None:
self.parallel_config = od_config.parallel_config
else:
from vllm_omni.diffusion.data import DiffusionParallelConfig

self.parallel_config = DiffusionParallelConfig()

self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=list(axes_dims_rope))
self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
in_channels=timestep_guidance_channels,
Expand Down Expand Up @@ -672,6 +706,25 @@ def forward(

num_txt_tokens = encoder_hidden_states.shape[1]

sp_size = self.parallel_config.sequence_parallel_size
get_forward_context().sequence_parallel_size = sp_size
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To fetch the sp_size, using self.parallel_config.sequence_parallel_size would be sufficient. I don't see why you need to set get_forward_context().sequence_parallel_size here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find that LongCatImageTransformer2DModel also edits get_forward_context...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

od_config = get_forward_context().omni_diffusion_config
parallel_config = od_config.parallel_config
sequence_parallel_size = parallel_config.sequence_parallel_size

This would be my recommendation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To fetch the sp_size, using self.parallel_config.sequence_parallel_size would be sufficient. I don't see why you need to set get_forward_context().sequence_parallel_size here.

Thanks — I kept only `sp_size = self.parallel_config.sequence_parallel_size`, which matches what you suggested.

if sp_size > 1:
sp_world_size = get_sequence_parallel_world_size()
sp_rank = get_sequence_parallel_rank()
original_shape = hidden_states.shape
hidden_states = torch.chunk(hidden_states, sp_world_size, dim=1)[sp_rank]
get_forward_context().split_text_embed_in_sp = False
if not hasattr(self, "_sp_forward_logged"):
self._sp_forward_logged = True
logger.info(
f"[Flux2 Transformer] SP enabled: sp_size={sp_size}, world_size={sp_world_size}, "
f"rank={sp_rank}, original_shape={original_shape}, chunked_shape={hidden_states.shape}"
)
else:
if not hasattr(self, "_sp_forward_logged"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is self._sp_forward_logged used for debugging only?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is self._sp_forward_logged used for debugging only?

yes, I'm removing it now.

self._sp_forward_logged = True
logger.info(f"[Flux2 Transformer] SP disabled: sp_size={sp_size}")

timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
Expand All @@ -690,11 +743,27 @@ def forward(
if txt_ids.ndim == 3:
txt_ids = txt_ids[0]

image_rotary_emb = self.pos_embed(img_ids)
text_rotary_emb = self.pos_embed(txt_ids)
if current_omni_platform.is_npu():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gcanlin do we have better ways to handle this difference? this is so awkward

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wtomin PTAL

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find that there exists npu and mps hardcode in Flux2PosEmbed, which may be from diffusers library I guess. I will take a micro-refactoring PR for removing it. We could delete this npu branch temporarily if we'd like to merge this PR first.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find that there exists npu and mps hardcode in Flux2PosEmbed, which may be from diffusers library I guess. I will take a micro-refactoring PR for removing it. We could delete this npu branch temporarily if we'd like to merge this PR first.

Got it! I'll just remove this branch right now

Copy link
Collaborator

@ZJY0516 ZJY0516 Feb 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember some of these are because torch_npu doesn't support complex numbers.

img_freqs_cos, img_freqs_sin = self.pos_embed(img_ids.cpu())
img_freqs_cos, img_freqs_sin = img_freqs_cos.npu(), img_freqs_sin.npu()
txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids.cpu())
txt_freqs_cos, txt_freqs_sin = txt_freqs_cos.npu(), txt_freqs_sin.npu()
else:
img_freqs_cos, img_freqs_sin = self.pos_embed(img_ids)
txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids)

if sp_size > 1:
sp_world_size = get_sequence_parallel_world_size()
sp_rank = get_sequence_parallel_rank()
img_freqs_cos = torch.chunk(img_freqs_cos, sp_world_size, dim=0)[sp_rank]
img_freqs_sin = torch.chunk(img_freqs_sin, sp_world_size, dim=0)[sp_rank]
if get_forward_context().split_text_embed_in_sp:
txt_freqs_cos = torch.chunk(txt_freqs_cos, sp_world_size, dim=0)[sp_rank]
txt_freqs_sin = torch.chunk(txt_freqs_sin, sp_world_size, dim=0)[sp_rank]

concat_rotary_emb = (
torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
torch.cat([txt_freqs_cos, img_freqs_cos], dim=0),
torch.cat([txt_freqs_sin, img_freqs_sin], dim=0),
)

for block in self.transformer_blocks:
Expand Down Expand Up @@ -722,6 +791,9 @@ def forward(
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)

if self.parallel_config.sequence_parallel_size > 1:
output = get_sp_group().all_gather(output, dim=1)

if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def __init__(
).to(self._execution_device)

transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, Flux2Transformer2DModel)
self.transformer = Flux2Transformer2DModel(**transformer_kwargs)
self.transformer = Flux2Transformer2DModel(od_config=od_config, **transformer_kwargs)

self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
Expand Down