 from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
 from vllm_omni.diffusion.attention.layer import Attention
 from vllm_omni.diffusion.data import OmniDiffusionConfig
-from vllm_omni.diffusion.distributed.parallel_state import (
-    get_sp_group,
+from vllm_omni.diffusion.distributed.sp_plan import (
+    SequenceParallelInput,
+    SequenceParallelOutput,
 )
-from vllm_omni.diffusion.distributed.sp_sharding import sp_shard_with_padding
 from vllm_omni.diffusion.forward_context import get_forward_context
 from vllm_omni.diffusion.layers.rope import RotaryEmbedding
 
@@ -216,33 +216,84 @@ def forward(
             encoder_query = self.norm_added_q(encoder_query)
             encoder_key = self.norm_added_k(encoder_key)
 
-            query = torch.cat([encoder_query, query], dim=1)
-            key = torch.cat([encoder_key, key], dim=1)
-            value = torch.cat([encoder_value, value], dim=1)
-
-        if image_rotary_emb is not None:
-            cos, sin = image_rotary_emb
-            cos = cos.to(query.dtype)
-            sin = sin.to(query.dtype)
-            query = self.rope(query, cos, sin)
-            key = self.rope(key, cos, sin)
-
-        attn_metadata = None
-        if attention_mask is not None:
-            if attention_mask.dim() == 3:
-                attention_mask = attention_mask.unsqueeze(1)
-            attn_metadata = AttentionMetadata(attn_mask=attention_mask)
-
-        hidden_states = self.attn(query, key, value, attn_metadata)
-        hidden_states = hidden_states.flatten(2, 3).to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            context_len = encoder_hidden_states.shape[1]
-            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
-                [context_len, hidden_states.shape[1] - context_len],
-                dim=1,
-            )
-            encoder_hidden_states = self.to_add_out(encoder_hidden_states)
+            forward_ctx = get_forward_context()
+            use_sp_joint_attention = forward_ctx.sp_active and not forward_ctx.split_text_embed_in_sp
+
+            if use_sp_joint_attention and image_rotary_emb is not None:
+                cos, sin = image_rotary_emb
+                cos = cos.to(query.dtype)
+                sin = sin.to(query.dtype)
+                txt_len = encoder_query.shape[1]
+                txt_cos, img_cos = cos[:txt_len], cos[txt_len:]
+                txt_sin, img_sin = sin[:txt_len], sin[txt_len:]
+                query = self.rope(query, img_cos, img_sin)
+                key = self.rope(key, img_cos, img_sin)
+                encoder_query = self.rope(encoder_query, txt_cos, txt_sin)
+                encoder_key = self.rope(encoder_key, txt_cos, txt_sin)
+
+                attn_metadata = AttentionMetadata(
+                    joint_query=encoder_query,
+                    joint_key=encoder_key,
+                    joint_value=encoder_value,
+                    joint_strategy="front",
+                )
+                if attention_mask is not None:
+                    if attention_mask.dim() == 3:
+                        attention_mask = attention_mask.unsqueeze(1)
+                    attn_metadata.attn_mask = attention_mask
+
+                hidden_states = self.attn(query, key, value, attn_metadata)
+                hidden_states = hidden_states.flatten(2, 3).to(query.dtype)
+
+                txt_len = encoder_hidden_states.shape[1]
+                encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+                    [txt_len, hidden_states.shape[1] - txt_len],
+                    dim=1,
+                )
+                encoder_hidden_states = self.to_add_out(encoder_hidden_states)
+            else:
+                query = torch.cat([encoder_query, query], dim=1)
+                key = torch.cat([encoder_key, key], dim=1)
+                value = torch.cat([encoder_value, value], dim=1)
+
+                if image_rotary_emb is not None:
+                    cos, sin = image_rotary_emb
+                    cos = cos.to(query.dtype)
+                    sin = sin.to(query.dtype)
+                    query = self.rope(query, cos, sin)
+                    key = self.rope(key, cos, sin)
+
+                attn_metadata = None
+                if attention_mask is not None:
+                    if attention_mask.dim() == 3:
+                        attention_mask = attention_mask.unsqueeze(1)
+                    attn_metadata = AttentionMetadata(attn_mask=attention_mask)
+
+                hidden_states = self.attn(query, key, value, attn_metadata)
+                hidden_states = hidden_states.flatten(2, 3).to(query.dtype)
+
+                context_len = encoder_hidden_states.shape[1]
+                encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+                    [context_len, hidden_states.shape[1] - context_len],
+                    dim=1,
+                )
+                encoder_hidden_states = self.to_add_out(encoder_hidden_states)
+        else:
+            if image_rotary_emb is not None:
+                cos, sin = image_rotary_emb
+                cos = cos.to(query.dtype)
+                sin = sin.to(query.dtype)
+                query = self.rope(query, cos, sin)
+                key = self.rope(key, cos, sin)
+
+            attn_metadata = None
+            if attention_mask is not None:
+                if attention_mask.dim() == 3:
+                    attention_mask = attention_mask.unsqueeze(1)
+                attn_metadata = AttentionMetadata(attn_mask=attention_mask)
+
+            hidden_states = self.attn(query, key, value, attn_metadata)
+            hidden_states = hidden_states.flatten(2, 3).to(query.dtype)
 
         hidden_states = self.to_out[0](hidden_states)
         hidden_states = self.to_out[1](hidden_states)
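
Note on the SP joint-attention path in the hunk above: the text stream is not concatenated into query/key/value; it rides along in the AttentionMetadata joint_* fields, and joint_strategy="front" indicates that text tokens precede image tokens in the combined sequence (which is why the output is later split at txt_len). A rough, illustrative sketch of that combination step follows; it ignores the cross-rank K/V exchange a real sequence-parallel backend performs, and the function name and tensor layout are assumptions, not the vllm_omni backend API.

import torch
import torch.nn.functional as F

def joint_attention_front(query, key, value, joint_query, joint_key, joint_value, attn_mask=None):
    # Tensors assumed [batch, seq, heads, head_dim]; "front" = text tokens first.
    q = torch.cat([joint_query, query], dim=1).transpose(1, 2)
    k = torch.cat([joint_key, key], dim=1).transpose(1, 2)
    v = torch.cat([joint_value, value], dim=1).transpose(1, 2)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    return out.transpose(1, 2)  # [batch, txt_len + img_len, heads, head_dim]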
@@ -333,20 +384,59 @@ def forward(
         query = self.norm_q(query)
         key = self.norm_k(key)
 
-        if image_rotary_emb is not None:
+        forward_ctx = get_forward_context()
+        text_seq_len = kwargs.get("text_seq_len", None)
+        use_sp_single_stream = (
+            forward_ctx.sp_active and not forward_ctx.split_text_embed_in_sp and text_seq_len is not None
+        )
+
+        if use_sp_single_stream and image_rotary_emb is not None:
             cos, sin = image_rotary_emb
             cos = cos.to(query.dtype)
             sin = sin.to(query.dtype)
-            query = self.rope(query, cos, sin)
-            key = self.rope(key, cos, sin)
+            txt_cos, img_cos = cos[:text_seq_len], cos[text_seq_len:]
+            txt_sin, img_sin = sin[:text_seq_len], sin[text_seq_len:]
+
+            img_query = query[:, text_seq_len:]
+            img_key = key[:, text_seq_len:]
+            img_value = value[:, text_seq_len:]
+            text_query = query[:, :text_seq_len]
+            text_key = key[:, :text_seq_len]
+            text_value = value[:, :text_seq_len]
+
+            img_query = self.rope(img_query, img_cos, img_sin)
+            img_key = self.rope(img_key, img_cos, img_sin)
+            text_query = self.rope(text_query, txt_cos, txt_sin)
+            text_key = self.rope(text_key, txt_cos, txt_sin)
+
+            attn_metadata = AttentionMetadata(
+                joint_query=text_query,
+                joint_key=text_key,
+                joint_value=text_value,
+                joint_strategy="front",
+            )
+            if attention_mask is not None:
+                if attention_mask.dim() == 3:
+                    attention_mask = attention_mask.unsqueeze(1)
+                attn_metadata.attn_mask = attention_mask
+
+            attn_output = self.attn(img_query, img_key, img_value, attn_metadata)
+        else:
+            if image_rotary_emb is not None:
+                cos, sin = image_rotary_emb
+                cos = cos.to(query.dtype)
+                sin = sin.to(query.dtype)
+                query = self.rope(query, cos, sin)
+                key = self.rope(key, cos, sin)
 
-        attn_metadata = None
-        if attention_mask is not None:
-            if attention_mask.dim() == 3:
-                attention_mask = attention_mask.unsqueeze(1)
-            attn_metadata = AttentionMetadata(attn_mask=attention_mask)
+            attn_metadata = None
+            if attention_mask is not None:
+                if attention_mask.dim() == 3:
+                    attention_mask = attention_mask.unsqueeze(1)
+                attn_metadata = AttentionMetadata(attn_mask=attention_mask)
+
+            attn_output = self.attn(query, key, value, attn_metadata)
 
-        attn_output = self.attn(query, key, value, attn_metadata)
         attn_output = attn_output.flatten(2, 3).to(query.dtype)
 
         mlp_hidden_states = self.mlp_act_fn(mlp_hidden_states)
@@ -356,12 +446,6 @@ def forward(
 
 
 class Flux2SingleTransformerBlock(nn.Module):
-    """
-    Single-stream Transformer block for Flux 2 with SP (Sequence Parallelism) support.
-
-    SP handling is delegated to Flux2Attention via the forward context.
-    """
-
     def __init__(
         self,
         dim: int,
@@ -546,6 +630,32 @@ def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         return freqs_cos, freqs_sin
 
 
+class Flux2RopePrepare(nn.Module):
+    """Prepares hidden_states and RoPE embeddings for sequence parallelism.
+
+    This module encapsulates the input projection and RoPE computation for Flux.2-klein.
+    The key insight is that hidden_states and img_freqs must be sharded together
+    to maintain dimension alignment for RoPE computation in attention layers.
+    txt_freqs is kept replicated for dual-stream joint attention.
+    """
+
+    def __init__(self, x_embedder: nn.Linear, pos_embed: Flux2PosEmbed):
+        super().__init__()
+        self.x_embedder = x_embedder
+        self.pos_embed = pos_embed
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        img_ids: torch.Tensor,
+        txt_ids: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        hidden_states = self.x_embedder(hidden_states)
+        img_freqs_cos, img_freqs_sin = self.pos_embed(img_ids)
+        txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids)
+        return hidden_states, txt_freqs_cos, txt_freqs_sin, img_freqs_cos, img_freqs_sin
+
+
 class Flux2TimestepGuidanceEmbeddings(nn.Module):
     def __init__(
         self,
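
Note on the Flux2RopePrepare hunk above: once hidden_states is split along its sequence dimension for SP, the image RoPE tables must be split at exactly the same boundaries so each rank rotates its tokens with the matching frequency rows, while the text tables stay replicated. A minimal sketch of such an aligned split with padding; shard_with_pad and its rank/world_size arguments are illustrative, not the vllm_omni helpers.

import torch

def shard_with_pad(x: torch.Tensor, dim: int, rank: int, world_size: int):
    # Pad the sharded dimension to a multiple of world_size, then take this
    # rank's contiguous chunk. Returns the local shard and the pad size.
    pad = (-x.shape[dim]) % world_size
    if pad:
        pad_shape = list(x.shape)
        pad_shape[dim] = pad
        x = torch.cat([x, x.new_zeros(pad_shape)], dim=dim)
    chunk = x.shape[dim] // world_size
    return x.narrow(dim, rank * chunk, chunk), pad

# hidden_states [B, S_img, C] would be split on dim 1 and img_freqs_* [S_img, D] on dim 0
# with the same rank/world_size, so token i on a rank lines up with frequency row i.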
@@ -611,6 +721,16 @@ class Flux2Transformer2DModel(nn.Module):
611721 "add_kv_proj" : ["add_q_proj" , "add_k_proj" , "add_v_proj" ],
612722 }
613723
724+ _sp_plan = {
725+ "rope_prepare" : {
726+ 0 : SequenceParallelInput (split_dim = 1 , expected_dims = 3 , split_output = True , auto_pad = True ),
727+ 3 : SequenceParallelInput (split_dim = 0 , expected_dims = 2 , split_output = True , auto_pad = True ),
728+ 4 : SequenceParallelInput (split_dim = 0 , expected_dims = 2 , split_output = True , auto_pad = True ),
729+ },
730+ "proj_out" : SequenceParallelOutput (gather_dim = 1 , expected_dims = 3 ),
731+ }
732+ """SP plan: shard hidden_states/img_freqs at rope_prepare, gather output at proj_out."""
733+
614734 def __init__ (
615735 self ,
616736 patch_size : int = 1 ,
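
Reading the _sp_plan above together with its docstring: with split_output=True the integer keys appear to index outputs of rope_prepare (0 = hidden_states, 3/4 = the image RoPE cos/sin tables), which get sharded across SP ranks as soon as the module returns, and proj_out's output is gathered back along the sequence dimension. The actual interpretation lives in vllm_omni.diffusion.distributed.sp_plan; the hook below is only a guess at the general shape of such a mechanism for the rope_prepare entry, and it skips the auto_pad handling.

import torch

def make_rope_prepare_hook(rank: int, world_size: int):
    # Output index -> split dim, mirroring the _sp_plan entry above.
    split_spec = {0: 1, 3: 0, 4: 0}

    def hook(module, inputs, outputs):
        outputs = list(outputs)
        for idx, dim in split_spec.items():
            # torch.chunk assumes the length divides evenly; the real plan pads first.
            outputs[idx] = torch.chunk(outputs[idx], world_size, dim=dim)[rank]
        return tuple(outputs)

    return hook

# Hypothetical usage: model.rope_prepare.register_forward_hook(make_rope_prepare_hook(rank, world_size))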
@@ -672,6 +792,8 @@ def __init__(
         self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False)
         self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False)
 
+        self.rope_prepare = Flux2RopePrepare(self.x_embedder, self.pos_embed)
+
         self.transformer_blocks = nn.ModuleList(
             [
                 Flux2TransformerBlock(
@@ -730,11 +852,7 @@ def forward(
 
         num_txt_tokens = encoder_hidden_states.shape[1]
 
-        sp_size = self.parallel_config.sequence_parallel_size
-        sp_pad_size = 0
-        if sp_size > 1:
-            hidden_states, sp_pad_size = sp_shard_with_padding(hidden_states, dim=1)
-            get_forward_context().split_text_embed_in_sp = False
+        get_forward_context().split_text_embed_in_sp = False
 
         timestep = timestep.to(hidden_states.dtype) * 1000
         if guidance is not None:
@@ -746,29 +864,41 @@ def forward(
         double_stream_mod_txt = self.double_stream_modulation_txt(temb)
         single_stream_mod = self.single_stream_modulation(temb)[0]
 
-        hidden_states = self.x_embedder(hidden_states)
-        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
-
         if img_ids.ndim == 3:
             img_ids = img_ids[0]
         if txt_ids.ndim == 3:
             txt_ids = txt_ids[0]
 
-        img_freqs_cos, img_freqs_sin = self.pos_embed(img_ids)
-        txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids)
-
-        if sp_size > 1:
-            img_freqs_cos, _ = sp_shard_with_padding(img_freqs_cos, dim=0)
-            img_freqs_sin, _ = sp_shard_with_padding(img_freqs_sin, dim=0)
-            if get_forward_context().split_text_embed_in_sp:
-                txt_freqs_cos, _ = sp_shard_with_padding(txt_freqs_cos, dim=0)
-                txt_freqs_sin, _ = sp_shard_with_padding(txt_freqs_sin, dim=0)
+        hidden_states, txt_freqs_cos, txt_freqs_sin, img_freqs_cos, img_freqs_sin = self.rope_prepare(
+            hidden_states, img_ids, txt_ids
+        )
+        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
         concat_rotary_emb = (
             torch.cat([txt_freqs_cos, img_freqs_cos], dim=0),
             torch.cat([txt_freqs_sin, img_freqs_sin], dim=0),
         )
 
+        hidden_states_mask = None
+        ctx = get_forward_context()
+        if ctx.sp_active:
+            if ctx.sp_original_seq_len is not None and ctx.sp_padding_size > 0:
+                batch_size = hidden_states.shape[0]
+                img_padded_seq_len = ctx.sp_original_seq_len + ctx.sp_padding_size
+                full_seq_len = num_txt_tokens + img_padded_seq_len
+                hidden_states_mask = torch.ones(
+                    batch_size,
+                    full_seq_len,
+                    dtype=torch.bool,
+                    device=hidden_states.device,
+                )
+                hidden_states_mask[:, num_txt_tokens + ctx.sp_original_seq_len :] = False
+                if hidden_states_mask.all():
+                    hidden_states_mask = None
+
+        if hidden_states_mask is not None:
+            joint_attention_kwargs["attention_mask"] = hidden_states_mask
+
         for block in self.transformer_blocks:
             encoder_hidden_states, hidden_states = block(
                 hidden_states=hidden_states,
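
Worked example of the padding mask built in the hunk above (numbers invented for illustration): with 128 text tokens and 1000 image tokens padded by 24 so the image sequence splits evenly across 8 SP ranks, only the trailing 24 positions of the joint sequence are masked out.

import torch

num_txt_tokens, sp_original_seq_len, sp_padding_size = 128, 1000, 24
full_seq_len = num_txt_tokens + sp_original_seq_len + sp_padding_size  # 1152
mask = torch.ones(1, full_seq_len, dtype=torch.bool)
mask[:, num_txt_tokens + sp_original_seq_len:] = False  # last 24 entries are padding
assert int(mask.sum()) == num_txt_tokens + sp_original_seq_len  # 1128 attendable positions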
@@ -788,17 +918,13 @@ def forward(
                 temb_mod_params=single_stream_mod,
                 image_rotary_emb=concat_rotary_emb,
                 joint_attention_kwargs=joint_attention_kwargs,
+                text_seq_len=num_txt_tokens,
             )
 
         hidden_states = hidden_states[:, num_txt_tokens:, ...]
         hidden_states = self.norm_out(hidden_states, temb)
         output = self.proj_out(hidden_states)
 
-        if self.parallel_config.sequence_parallel_size > 1:
-            output = get_sp_group().all_gather(output, dim=1)
-            if sp_pad_size > 0:
-                output = output[:, :-sp_pad_size, ...]
-
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)