-
Notifications
You must be signed in to change notification settings - Fork 518
[Feat] support SP for FLUX.2-klein #1250
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 8 commits
e8fb739
f2daa52
225ec8a
c57c866
b33ce10
7166799
780ed85
8a7e687
63c4658
219d4f7
876d492
ec5f9f2
6a1e5f2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,6 +28,7 @@ | |
| ) | ||
| from diffusers.models.modeling_outputs import Transformer2DModelOutput | ||
| from diffusers.models.normalization import AdaLayerNormContinuous | ||
| from vllm.logger import init_logger | ||
| from vllm.model_executor.layers.layernorm import RMSNorm | ||
| from vllm.model_executor.layers.linear import ( | ||
| ColumnParallelLinear, | ||
|
|
@@ -39,8 +40,15 @@ | |
|
|
||
| from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata | ||
| from vllm_omni.diffusion.attention.layer import Attention | ||
| from vllm_omni.diffusion.data import OmniDiffusionConfig | ||
| from vllm_omni.diffusion.distributed.sp_plan import ( | ||
| SequenceParallelInput, | ||
| SequenceParallelOutput, | ||
| ) | ||
| from vllm_omni.diffusion.forward_context import get_forward_context | ||
| from vllm_omni.diffusion.layers.rope import RotaryEmbedding | ||
|
|
||
| logger = init_logger(__name__) | ||
| if TYPE_CHECKING: | ||
| from vllm.model_executor.layers.quantization.base_config import QuantizationConfig | ||
|
|
||
|
|
@@ -208,33 +216,84 @@ def forward( | |
| encoder_query = self.norm_added_q(encoder_query) | ||
| encoder_key = self.norm_added_k(encoder_key) | ||
|
|
||
| query = torch.cat([encoder_query, query], dim=1) | ||
| key = torch.cat([encoder_key, key], dim=1) | ||
| value = torch.cat([encoder_value, value], dim=1) | ||
|
|
||
| if image_rotary_emb is not None: | ||
| cos, sin = image_rotary_emb | ||
| cos = cos.to(query.dtype) | ||
| sin = sin.to(query.dtype) | ||
| query = self.rope(query, cos, sin) | ||
| key = self.rope(key, cos, sin) | ||
|
|
||
| attn_metadata = None | ||
| if attention_mask is not None: | ||
| if attention_mask.dim() == 3: | ||
| attention_mask = attention_mask.unsqueeze(1) | ||
| attn_metadata = AttentionMetadata(attn_mask=attention_mask) | ||
|
|
||
| hidden_states = self.attn(query, key, value, attn_metadata) | ||
| hidden_states = hidden_states.flatten(2, 3).to(query.dtype) | ||
|
|
||
| if encoder_hidden_states is not None: | ||
| context_len = encoder_hidden_states.shape[1] | ||
| encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( | ||
| [context_len, hidden_states.shape[1] - context_len], | ||
| dim=1, | ||
| ) | ||
| encoder_hidden_states = self.to_add_out(encoder_hidden_states) | ||
| forward_ctx = get_forward_context() | ||
| use_sp_joint_attention = forward_ctx.sp_active and not forward_ctx.split_text_embed_in_sp | ||
|
|
||
| if use_sp_joint_attention and image_rotary_emb is not None: | ||
| cos, sin = image_rotary_emb | ||
| cos = cos.to(query.dtype) | ||
| sin = sin.to(query.dtype) | ||
| txt_len = encoder_query.shape[1] | ||
| txt_cos, img_cos = cos[:txt_len], cos[txt_len:] | ||
| txt_sin, img_sin = sin[:txt_len], sin[txt_len:] | ||
| query = self.rope(query, img_cos, img_sin) | ||
| key = self.rope(key, img_cos, img_sin) | ||
| encoder_query = self.rope(encoder_query, txt_cos, txt_sin) | ||
| encoder_key = self.rope(encoder_key, txt_cos, txt_sin) | ||
|
|
||
| attn_metadata = AttentionMetadata( | ||
| joint_query=encoder_query, | ||
| joint_key=encoder_key, | ||
| joint_value=encoder_value, | ||
| joint_strategy="front", | ||
| ) | ||
| if attention_mask is not None: | ||
| if attention_mask.dim() == 3: | ||
| attention_mask = attention_mask.unsqueeze(1) | ||
| attn_metadata.attn_mask = attention_mask | ||
|
|
||
| hidden_states = self.attn(query, key, value, attn_metadata) | ||
| hidden_states = hidden_states.flatten(2, 3).to(query.dtype) | ||
|
|
||
| txt_len = encoder_hidden_states.shape[1] | ||
| encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( | ||
| [txt_len, hidden_states.shape[1] - txt_len], | ||
| dim=1, | ||
| ) | ||
| encoder_hidden_states = self.to_add_out(encoder_hidden_states) | ||
| else: | ||
| query = torch.cat([encoder_query, query], dim=1) | ||
| key = torch.cat([encoder_key, key], dim=1) | ||
| value = torch.cat([encoder_value, value], dim=1) | ||
|
|
||
| if image_rotary_emb is not None: | ||
| cos, sin = image_rotary_emb | ||
| cos = cos.to(query.dtype) | ||
| sin = sin.to(query.dtype) | ||
| query = self.rope(query, cos, sin) | ||
| key = self.rope(key, cos, sin) | ||
|
|
||
| attn_metadata = None | ||
| if attention_mask is not None: | ||
| if attention_mask.dim() == 3: | ||
| attention_mask = attention_mask.unsqueeze(1) | ||
| attn_metadata = AttentionMetadata(attn_mask=attention_mask) | ||
|
|
||
| hidden_states = self.attn(query, key, value, attn_metadata) | ||
| hidden_states = hidden_states.flatten(2, 3).to(query.dtype) | ||
|
|
||
| context_len = encoder_hidden_states.shape[1] | ||
| encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( | ||
| [context_len, hidden_states.shape[1] - context_len], | ||
| dim=1, | ||
| ) | ||
| encoder_hidden_states = self.to_add_out(encoder_hidden_states) | ||
| else: | ||
| if image_rotary_emb is not None: | ||
| cos, sin = image_rotary_emb | ||
| cos = cos.to(query.dtype) | ||
| sin = sin.to(query.dtype) | ||
| query = self.rope(query, cos, sin) | ||
| key = self.rope(key, cos, sin) | ||
|
|
||
| attn_metadata = None | ||
| if attention_mask is not None: | ||
| if attention_mask.dim() == 3: | ||
| attention_mask = attention_mask.unsqueeze(1) | ||
| attn_metadata = AttentionMetadata(attn_mask=attention_mask) | ||
|
|
||
| hidden_states = self.attn(query, key, value, attn_metadata) | ||
| hidden_states = hidden_states.flatten(2, 3).to(query.dtype) | ||
|
|
||
| hidden_states = self.to_out[0](hidden_states) | ||
| hidden_states = self.to_out[1](hidden_states) | ||
|
|
@@ -325,20 +384,59 @@ def forward( | |
| query = self.norm_q(query) | ||
| key = self.norm_k(key) | ||
|
|
||
| if image_rotary_emb is not None: | ||
| forward_ctx = get_forward_context() | ||
| text_seq_len = kwargs.get("text_seq_len", None) | ||
| use_sp_single_stream = ( | ||
| forward_ctx.sp_active and not forward_ctx.split_text_embed_in_sp and text_seq_len is not None | ||
| ) | ||
|
|
||
| if use_sp_single_stream and image_rotary_emb is not None: | ||
| cos, sin = image_rotary_emb | ||
| cos = cos.to(query.dtype) | ||
| sin = sin.to(query.dtype) | ||
| query = self.rope(query, cos, sin) | ||
| key = self.rope(key, cos, sin) | ||
| txt_cos, img_cos = cos[:text_seq_len], cos[text_seq_len:] | ||
| txt_sin, img_sin = sin[:text_seq_len], sin[text_seq_len:] | ||
|
|
||
| img_query = query[:, text_seq_len:] | ||
| img_key = key[:, text_seq_len:] | ||
| img_value = value[:, text_seq_len:] | ||
| text_query = query[:, :text_seq_len] | ||
| text_key = key[:, :text_seq_len] | ||
| text_value = value[:, :text_seq_len] | ||
|
|
||
| img_query = self.rope(img_query, img_cos, img_sin) | ||
| img_key = self.rope(img_key, img_cos, img_sin) | ||
| text_query = self.rope(text_query, txt_cos, txt_sin) | ||
| text_key = self.rope(text_key, txt_cos, txt_sin) | ||
|
|
||
| attn_metadata = AttentionMetadata( | ||
| joint_query=text_query, | ||
| joint_key=text_key, | ||
| joint_value=text_value, | ||
| joint_strategy="front", | ||
| ) | ||
| if attention_mask is not None: | ||
| if attention_mask.dim() == 3: | ||
| attention_mask = attention_mask.unsqueeze(1) | ||
| attn_metadata.attn_mask = attention_mask | ||
|
|
||
| attn_output = self.attn(img_query, img_key, img_value, attn_metadata) | ||
| else: | ||
| if image_rotary_emb is not None: | ||
| cos, sin = image_rotary_emb | ||
| cos = cos.to(query.dtype) | ||
| sin = sin.to(query.dtype) | ||
| query = self.rope(query, cos, sin) | ||
| key = self.rope(key, cos, sin) | ||
|
|
||
| attn_metadata = None | ||
| if attention_mask is not None: | ||
| if attention_mask.dim() == 3: | ||
| attention_mask = attention_mask.unsqueeze(1) | ||
| attn_metadata = AttentionMetadata(attn_mask=attention_mask) | ||
|
|
||
| attn_metadata = None | ||
| if attention_mask is not None: | ||
| if attention_mask.dim() == 3: | ||
| attention_mask = attention_mask.unsqueeze(1) | ||
| attn_metadata = AttentionMetadata(attn_mask=attention_mask) | ||
| attn_output = self.attn(query, key, value, attn_metadata) | ||
|
|
||
| attn_output = self.attn(query, key, value, attn_metadata) | ||
| attn_output = attn_output.flatten(2, 3).to(query.dtype) | ||
|
|
||
| mlp_hidden_states = self.mlp_act_fn(mlp_hidden_states) | ||
|
|
@@ -383,6 +481,13 @@ def forward( | |
| split_hidden_states: bool = False, | ||
| text_seq_len: int | None = None, | ||
| ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: | ||
| """ | ||
| Forward pass for Flux2SingleTransformerBlock with SP support. | ||
|
|
||
| In SP mode: image hidden_states is chunked (B, img_len/SP, D), | ||
| text encoder_hidden_states is full (B, txt_len, D). | ||
| The block concatenates them for joint attention. | ||
| """ | ||
| if encoder_hidden_states is not None: | ||
| text_seq_len = encoder_hidden_states.shape[1] | ||
| hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) | ||
|
|
@@ -525,6 +630,32 @@ def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | |
| return freqs_cos, freqs_sin | ||
|
|
||
|
|
||
| class Flux2RopePrepare(nn.Module): | ||
| """Prepares hidden_states and RoPE embeddings for sequence parallel. | ||
|
|
||
| This module encapsulates the input projection and RoPE computation for Flux.2-klein. | ||
| The key insight is that hidden_states and img_freqs must be sharded together | ||
| to maintain dimension alignment for RoPE computation in attention layers. | ||
| txt_freqs is kept replicated for dual-stream joint attention. | ||
| """ | ||
|
|
||
| def __init__(self, x_embedder: nn.Linear, pos_embed: Flux2PosEmbed): | ||
| super().__init__() | ||
| self.x_embedder = x_embedder | ||
| self.pos_embed = pos_embed | ||
|
|
||
| def forward( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
| img_ids: torch.Tensor, | ||
| txt_ids: torch.Tensor, | ||
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | ||
| hidden_states = self.x_embedder(hidden_states) | ||
| img_freqs_cos, img_freqs_sin = self.pos_embed(img_ids) | ||
| txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids) | ||
| return hidden_states, txt_freqs_cos, txt_freqs_sin, img_freqs_cos, img_freqs_sin | ||
|
|
||
|
|
||
| class Flux2TimestepGuidanceEmbeddings(nn.Module): | ||
| def __init__( | ||
| self, | ||
|
|
@@ -580,6 +711,8 @@ def forward(self, temb: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor, | |
| class Flux2Transformer2DModel(nn.Module): | ||
| """ | ||
| The Transformer model introduced in Flux 2. | ||
|
|
||
| Supports Sequence Parallelism (Ulysses and Ring) when configured via OmniDiffusionConfig. | ||
| """ | ||
|
|
||
| _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"] | ||
|
|
@@ -588,6 +721,16 @@ class Flux2Transformer2DModel(nn.Module): | |
| "add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"], | ||
| } | ||
|
|
||
| _sp_plan = { | ||
| "rope_prepare": { | ||
| 0: SequenceParallelInput(split_dim=1, expected_dims=3, split_output=True, auto_pad=True), | ||
| 3: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True, auto_pad=True), | ||
| 4: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True, auto_pad=True), | ||
| }, | ||
| "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3), | ||
| } | ||
| """SP plan: shard hidden_states/img_freqs at rope_prepare, gather output at proj_out.""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| patch_size: int = 1, | ||
|
|
@@ -604,6 +747,7 @@ def __init__( | |
| rope_theta: int = 2000, | ||
| eps: float = 1e-6, | ||
| guidance_embeds: bool = True, | ||
| od_config: OmniDiffusionConfig = None, | ||
| quant_config: "QuantizationConfig | None" = None, | ||
| ): | ||
| super().__init__() | ||
|
|
@@ -626,6 +770,13 @@ def __init__( | |
| guidance_embeds=guidance_embeds, | ||
| ) | ||
|
|
||
| if od_config is not None: | ||
| self.parallel_config = od_config.parallel_config | ||
| else: | ||
| from vllm_omni.diffusion.data import DiffusionParallelConfig | ||
|
|
||
| self.parallel_config = DiffusionParallelConfig() | ||
|
|
||
| self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=list(axes_dims_rope)) | ||
| self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings( | ||
| in_channels=timestep_guidance_channels, | ||
|
|
@@ -641,6 +792,8 @@ def __init__( | |
| self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False) | ||
| self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False) | ||
|
|
||
| self.rope_prepare = Flux2RopePrepare(self.x_embedder, self.pos_embed) | ||
|
|
||
| self.transformer_blocks = nn.ModuleList( | ||
| [ | ||
| Flux2TransformerBlock( | ||
|
|
@@ -699,6 +852,8 @@ def forward( | |
|
|
||
| num_txt_tokens = encoder_hidden_states.shape[1] | ||
|
|
||
| get_forward_context().split_text_embed_in_sp = False | ||
|
|
||
| timestep = timestep.to(hidden_states.dtype) * 1000 | ||
| if guidance is not None: | ||
| guidance = guidance.to(hidden_states.dtype) * 1000 | ||
|
|
@@ -709,21 +864,41 @@ def forward( | |
| double_stream_mod_txt = self.double_stream_modulation_txt(temb) | ||
| single_stream_mod = self.single_stream_modulation(temb)[0] | ||
|
|
||
| hidden_states = self.x_embedder(hidden_states) | ||
| encoder_hidden_states = self.context_embedder(encoder_hidden_states) | ||
|
|
||
| if img_ids.ndim == 3: | ||
| img_ids = img_ids[0] | ||
| if txt_ids.ndim == 3: | ||
| txt_ids = txt_ids[0] | ||
|
|
||
| image_rotary_emb = self.pos_embed(img_ids) | ||
| text_rotary_emb = self.pos_embed(txt_ids) | ||
| hidden_states, txt_freqs_cos, txt_freqs_sin, img_freqs_cos, img_freqs_sin = self.rope_prepare( | ||
| hidden_states, img_ids, txt_ids | ||
| ) | ||
| encoder_hidden_states = self.context_embedder(encoder_hidden_states) | ||
|
|
||
| concat_rotary_emb = ( | ||
| torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0), | ||
| torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0), | ||
| torch.cat([txt_freqs_cos, img_freqs_cos], dim=0), | ||
| torch.cat([txt_freqs_sin, img_freqs_sin], dim=0), | ||
| ) | ||
|
|
||
| hidden_states_mask = None | ||
| ctx = get_forward_context() | ||
| if ctx.sp_active: | ||
|
||
| if ctx.sp_original_seq_len is not None and ctx.sp_padding_size > 0: | ||
| batch_size = hidden_states.shape[0] | ||
| img_padded_seq_len = ctx.sp_original_seq_len + ctx.sp_padding_size | ||
| full_seq_len = num_txt_tokens + img_padded_seq_len | ||
| hidden_states_mask = torch.ones( | ||
| batch_size, | ||
| full_seq_len, | ||
| dtype=torch.bool, | ||
| device=hidden_states.device, | ||
| ) | ||
| hidden_states_mask[:, num_txt_tokens + ctx.sp_original_seq_len :] = False | ||
| if hidden_states_mask.all(): | ||
| hidden_states_mask = None | ||
|
|
||
| if hidden_states_mask is not None: | ||
| joint_attention_kwargs["attention_mask"] = hidden_states_mask | ||
|
|
||
| for block in self.transformer_blocks: | ||
| encoder_hidden_states, hidden_states = block( | ||
| hidden_states=hidden_states, | ||
|
|
@@ -743,6 +918,7 @@ def forward( | |
| temb_mod_params=single_stream_mod, | ||
| image_rotary_emb=concat_rotary_emb, | ||
| joint_attention_kwargs=joint_attention_kwargs, | ||
| text_seq_len=num_txt_tokens, | ||
| ) | ||
|
|
||
| hidden_states = hidden_states[:, num_txt_tokens:, ...] | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are these added lines (397–410) independent of any specific diffusion model, and if so, can they be extracted into a shared helper?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've extracted the RoPE logic into a reusable helper function (verified so far on flux1, flux2, and z-image; the others, e.g. ovis-image, should be able to reuse it as well). The SP-specific logic in flux2_klein (lines 393–410) remains unchanged, since it handles model-specific sequence-parallelism splitting.