Commit 8a7e687

upd
Signed-off-by: Lancer <[email protected]>
1 parent 780ed85 commit 8a7e687

1 file changed: vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py
193 additions, 67 deletions
@@ -41,10 +41,10 @@
 from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
 from vllm_omni.diffusion.attention.layer import Attention
 from vllm_omni.diffusion.data import OmniDiffusionConfig
-from vllm_omni.diffusion.distributed.parallel_state import (
-    get_sp_group,
+from vllm_omni.diffusion.distributed.sp_plan import (
+    SequenceParallelInput,
+    SequenceParallelOutput,
 )
-from vllm_omni.diffusion.distributed.sp_sharding import sp_shard_with_padding
 from vllm_omni.diffusion.forward_context import get_forward_context
 from vllm_omni.diffusion.layers.rope import RotaryEmbedding

@@ -216,33 +216,84 @@ def forward(
             encoder_query = self.norm_added_q(encoder_query)
             encoder_key = self.norm_added_k(encoder_key)

-            query = torch.cat([encoder_query, query], dim=1)
-            key = torch.cat([encoder_key, key], dim=1)
-            value = torch.cat([encoder_value, value], dim=1)
-
-        if image_rotary_emb is not None:
-            cos, sin = image_rotary_emb
-            cos = cos.to(query.dtype)
-            sin = sin.to(query.dtype)
-            query = self.rope(query, cos, sin)
-            key = self.rope(key, cos, sin)
-
-        attn_metadata = None
-        if attention_mask is not None:
-            if attention_mask.dim() == 3:
-                attention_mask = attention_mask.unsqueeze(1)
-            attn_metadata = AttentionMetadata(attn_mask=attention_mask)
-
-        hidden_states = self.attn(query, key, value, attn_metadata)
-        hidden_states = hidden_states.flatten(2, 3).to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            context_len = encoder_hidden_states.shape[1]
-            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
-                [context_len, hidden_states.shape[1] - context_len],
-                dim=1,
-            )
-            encoder_hidden_states = self.to_add_out(encoder_hidden_states)
+            forward_ctx = get_forward_context()
+            use_sp_joint_attention = forward_ctx.sp_active and not forward_ctx.split_text_embed_in_sp
+
+            if use_sp_joint_attention and image_rotary_emb is not None:
+                cos, sin = image_rotary_emb
+                cos = cos.to(query.dtype)
+                sin = sin.to(query.dtype)
+                txt_len = encoder_query.shape[1]
+                txt_cos, img_cos = cos[:txt_len], cos[txt_len:]
+                txt_sin, img_sin = sin[:txt_len], sin[txt_len:]
+                query = self.rope(query, img_cos, img_sin)
+                key = self.rope(key, img_cos, img_sin)
+                encoder_query = self.rope(encoder_query, txt_cos, txt_sin)
+                encoder_key = self.rope(encoder_key, txt_cos, txt_sin)
+
+                attn_metadata = AttentionMetadata(
+                    joint_query=encoder_query,
+                    joint_key=encoder_key,
+                    joint_value=encoder_value,
+                    joint_strategy="front",
+                )
+                if attention_mask is not None:
+                    if attention_mask.dim() == 3:
+                        attention_mask = attention_mask.unsqueeze(1)
+                    attn_metadata.attn_mask = attention_mask
+
+                hidden_states = self.attn(query, key, value, attn_metadata)
+                hidden_states = hidden_states.flatten(2, 3).to(query.dtype)
+
+                txt_len = encoder_hidden_states.shape[1]
+                encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+                    [txt_len, hidden_states.shape[1] - txt_len],
+                    dim=1,
+                )
+                encoder_hidden_states = self.to_add_out(encoder_hidden_states)
+            else:
+                query = torch.cat([encoder_query, query], dim=1)
+                key = torch.cat([encoder_key, key], dim=1)
+                value = torch.cat([encoder_value, value], dim=1)
+
+                if image_rotary_emb is not None:
+                    cos, sin = image_rotary_emb
+                    cos = cos.to(query.dtype)
+                    sin = sin.to(query.dtype)
+                    query = self.rope(query, cos, sin)
+                    key = self.rope(key, cos, sin)
+
+                attn_metadata = None
+                if attention_mask is not None:
+                    if attention_mask.dim() == 3:
+                        attention_mask = attention_mask.unsqueeze(1)
+                    attn_metadata = AttentionMetadata(attn_mask=attention_mask)
+
+                hidden_states = self.attn(query, key, value, attn_metadata)
+                hidden_states = hidden_states.flatten(2, 3).to(query.dtype)
+
+                context_len = encoder_hidden_states.shape[1]
+                encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+                    [context_len, hidden_states.shape[1] - context_len],
+                    dim=1,
+                )
+                encoder_hidden_states = self.to_add_out(encoder_hidden_states)
+        else:
+            if image_rotary_emb is not None:
+                cos, sin = image_rotary_emb
+                cos = cos.to(query.dtype)
+                sin = sin.to(query.dtype)
+                query = self.rope(query, cos, sin)
+                key = self.rope(key, cos, sin)
+
+            attn_metadata = None
+            if attention_mask is not None:
+                if attention_mask.dim() == 3:
+                    attention_mask = attention_mask.unsqueeze(1)
+                attn_metadata = AttentionMetadata(attn_mask=attention_mask)
+
+            hidden_states = self.attn(query, key, value, attn_metadata)
+            hidden_states = hidden_states.flatten(2, 3).to(query.dtype)

         hidden_states = self.to_out[0](hidden_states)
         hidden_states = self.to_out[1](hidden_states)
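
A note on the new dual-stream path above: when SP is active and the text embeddings are not split across ranks, the text q/k/v are no longer concatenated into the image sequence locally; they reach the attention backend through the joint_query/joint_key/joint_value fields of AttentionMetadata with joint_strategy="front". The sketch below shows one plausible way a backend could combine them — a minimal illustration assuming plain SDPA and no cross-rank k/v exchange, not vllm_omni's actual backend:

    import torch
    import torch.nn.functional as F

    def joint_attention_front(img_q, img_k, img_v, txt_q, txt_k, txt_v):
        # "front" strategy: prepend the replicated text tokens to this rank's image
        # tokens before attention. A real SP backend would also exchange image k/v
        # across ranks (ring / Ulysses style); that step is omitted here.
        q = torch.cat([txt_q, img_q], dim=1)  # [B, S_txt + S_img, H, D]
        k = torch.cat([txt_k, img_k], dim=1)
        v = torch.cat([txt_v, img_v], dim=1)
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # -> [B, H, S, D]
        out = F.scaled_dot_product_attention(q, k, v)
        return out.transpose(1, 2)  # back to [B, S, H, D]
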
@@ -333,20 +384,59 @@ def forward(
         query = self.norm_q(query)
         key = self.norm_k(key)

-        if image_rotary_emb is not None:
+        forward_ctx = get_forward_context()
+        text_seq_len = kwargs.get("text_seq_len", None)
+        use_sp_single_stream = (
+            forward_ctx.sp_active and not forward_ctx.split_text_embed_in_sp and text_seq_len is not None
+        )
+
+        if use_sp_single_stream and image_rotary_emb is not None:
             cos, sin = image_rotary_emb
             cos = cos.to(query.dtype)
             sin = sin.to(query.dtype)
-            query = self.rope(query, cos, sin)
-            key = self.rope(key, cos, sin)
+            txt_cos, img_cos = cos[:text_seq_len], cos[text_seq_len:]
+            txt_sin, img_sin = sin[:text_seq_len], sin[text_seq_len:]
+
+            img_query = query[:, text_seq_len:]
+            img_key = key[:, text_seq_len:]
+            img_value = value[:, text_seq_len:]
+            text_query = query[:, :text_seq_len]
+            text_key = key[:, :text_seq_len]
+            text_value = value[:, :text_seq_len]
+
+            img_query = self.rope(img_query, img_cos, img_sin)
+            img_key = self.rope(img_key, img_cos, img_sin)
+            text_query = self.rope(text_query, txt_cos, txt_sin)
+            text_key = self.rope(text_key, txt_cos, txt_sin)
+
+            attn_metadata = AttentionMetadata(
+                joint_query=text_query,
+                joint_key=text_key,
+                joint_value=text_value,
+                joint_strategy="front",
+            )
+            if attention_mask is not None:
+                if attention_mask.dim() == 3:
+                    attention_mask = attention_mask.unsqueeze(1)
+                attn_metadata.attn_mask = attention_mask
+
+            attn_output = self.attn(img_query, img_key, img_value, attn_metadata)
+        else:
+            if image_rotary_emb is not None:
+                cos, sin = image_rotary_emb
+                cos = cos.to(query.dtype)
+                sin = sin.to(query.dtype)
+                query = self.rope(query, cos, sin)
+                key = self.rope(key, cos, sin)

-        attn_metadata = None
-        if attention_mask is not None:
-            if attention_mask.dim() == 3:
-                attention_mask = attention_mask.unsqueeze(1)
-            attn_metadata = AttentionMetadata(attn_mask=attention_mask)
+            attn_metadata = None
+            if attention_mask is not None:
+                if attention_mask.dim() == 3:
+                    attention_mask = attention_mask.unsqueeze(1)
+                attn_metadata = AttentionMetadata(attn_mask=attention_mask)
+
+            attn_output = self.attn(query, key, value, attn_metadata)

-        attn_output = self.attn(query, key, value, attn_metadata)
         attn_output = attn_output.flatten(2, 3).to(query.dtype)

         mlp_hidden_states = self.mlp_act_fn(mlp_hidden_states)
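
The single-stream path makes the same move but has to slice the packed [text | image] sequence itself: only the image tokens are sharded per SP rank, while the text tokens (and their RoPE frequencies) stay replicated. A tiny runnable illustration of that slicing, with hypothetical sizes:

    import torch

    batch, heads, head_dim = 1, 4, 64
    text_seq_len, img_tokens_per_rank = 512, 2048  # hypothetical sizes

    query = torch.randn(batch, text_seq_len + img_tokens_per_rank, heads, head_dim)
    text_query, img_query = query[:, :text_seq_len], query[:, text_seq_len:]
    # img_query stays local to this rank; text_query is replicated on every rank and
    # is passed to the backend via AttentionMetadata's joint_* fields.
    print(text_query.shape, img_query.shape)  # [1, 512, 4, 64] and [1, 2048, 4, 64]
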
@@ -356,12 +446,6 @@ def forward(


 class Flux2SingleTransformerBlock(nn.Module):
-    """
-    Single-stream Transformer block for Flux 2 with SP (Sequence Parallelism) support.
-
-    SP handling is delegated to Flux2Attention via the forward context.
-    """
-
     def __init__(
         self,
         dim: int,
@@ -546,6 +630,32 @@ def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         return freqs_cos, freqs_sin


+class Flux2RopePrepare(nn.Module):
+    """Prepares hidden_states and RoPE embeddings for sequence parallelism.
+
+    This module encapsulates the input projection and RoPE computation for Flux.2-klein.
+    The key insight is that hidden_states and img_freqs must be sharded together
+    to maintain dimension alignment for RoPE computation in attention layers.
+    txt_freqs is kept replicated for dual-stream joint attention.
+    """
+
+    def __init__(self, x_embedder: nn.Linear, pos_embed: Flux2PosEmbed):
+        super().__init__()
+        self.x_embedder = x_embedder
+        self.pos_embed = pos_embed
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        img_ids: torch.Tensor,
+        txt_ids: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        hidden_states = self.x_embedder(hidden_states)
+        img_freqs_cos, img_freqs_sin = self.pos_embed(img_ids)
+        txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids)
+        return hidden_states, txt_freqs_cos, txt_freqs_sin, img_freqs_cos, img_freqs_sin
+
+
 class Flux2TimestepGuidanceEmbeddings(nn.Module):
     def __init__(
         self,
@@ -611,6 +721,16 @@ class Flux2Transformer2DModel(nn.Module):
         "add_kv_proj": ["add_q_proj", "add_k_proj", "add_v_proj"],
     }

+    _sp_plan = {
+        "rope_prepare": {
+            0: SequenceParallelInput(split_dim=1, expected_dims=3, split_output=True, auto_pad=True),
+            3: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True, auto_pad=True),
+            4: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True, auto_pad=True),
+        },
+        "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
+    }
+    """SP plan: shard hidden_states/img_freqs at rope_prepare, gather output at proj_out."""
+
     def __init__(
         self,
         patch_size: int = 1,
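
Reading the _sp_plan above: the integer keys appear to index the positional outputs of rope_prepare (0 = hidden_states, 3 = img_freqs_cos, 4 = img_freqs_sin), each split along split_dim with padding so the sequence divides the SP world size, while proj_out's result is gathered back along dim 1. A minimal sketch of that sharding behaviour under those assumptions, using plain dicts instead of the SequenceParallelInput/SequenceParallelOutput API:

    import torch

    def shard_with_pad(t: torch.Tensor, dim: int, rank: int, world: int) -> torch.Tensor:
        # auto_pad=True: zero-pad the split dimension so it divides evenly, then keep
        # this rank's chunk. Illustration only, not vllm_omni's implementation.
        pad = (-t.shape[dim]) % world
        if pad:
            pad_shape = list(t.shape)
            pad_shape[dim] = pad
            t = torch.cat([t, t.new_zeros(pad_shape)], dim=dim)
        return t.chunk(world, dim=dim)[rank]

    # output index -> split dim, mirroring the rope_prepare entry of _sp_plan
    rope_prepare_plan = {0: 1, 3: 0, 4: 0}

    def shard_rope_prepare_outputs(outputs, rank, world):
        outs = list(outputs)
        for idx, dim in rope_prepare_plan.items():
            outs[idx] = shard_with_pad(outs[idx], dim, rank, world)
        return tuple(outs)
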
@@ -672,6 +792,8 @@ def __init__(
         self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False)
         self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False)

+        self.rope_prepare = Flux2RopePrepare(self.x_embedder, self.pos_embed)
+
         self.transformer_blocks = nn.ModuleList(
             [
                 Flux2TransformerBlock(
@@ -730,11 +852,7 @@ def forward(

         num_txt_tokens = encoder_hidden_states.shape[1]

-        sp_size = self.parallel_config.sequence_parallel_size
-        sp_pad_size = 0
-        if sp_size > 1:
-            hidden_states, sp_pad_size = sp_shard_with_padding(hidden_states, dim=1)
-            get_forward_context().split_text_embed_in_sp = False
+        get_forward_context().split_text_embed_in_sp = False

         timestep = timestep.to(hidden_states.dtype) * 1000
         if guidance is not None:
@@ -746,29 +864,41 @@
         double_stream_mod_txt = self.double_stream_modulation_txt(temb)
         single_stream_mod = self.single_stream_modulation(temb)[0]

-        hidden_states = self.x_embedder(hidden_states)
-        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
-
         if img_ids.ndim == 3:
             img_ids = img_ids[0]
         if txt_ids.ndim == 3:
             txt_ids = txt_ids[0]

-        img_freqs_cos, img_freqs_sin = self.pos_embed(img_ids)
-        txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids)
-
-        if sp_size > 1:
-            img_freqs_cos, _ = sp_shard_with_padding(img_freqs_cos, dim=0)
-            img_freqs_sin, _ = sp_shard_with_padding(img_freqs_sin, dim=0)
-            if get_forward_context().split_text_embed_in_sp:
-                txt_freqs_cos, _ = sp_shard_with_padding(txt_freqs_cos, dim=0)
-                txt_freqs_sin, _ = sp_shard_with_padding(txt_freqs_sin, dim=0)
+        hidden_states, txt_freqs_cos, txt_freqs_sin, img_freqs_cos, img_freqs_sin = self.rope_prepare(
+            hidden_states, img_ids, txt_ids
+        )
+        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

         concat_rotary_emb = (
             torch.cat([txt_freqs_cos, img_freqs_cos], dim=0),
             torch.cat([txt_freqs_sin, img_freqs_sin], dim=0),
         )

+        hidden_states_mask = None
+        ctx = get_forward_context()
+        if ctx.sp_active:
+            if ctx.sp_original_seq_len is not None and ctx.sp_padding_size > 0:
+                batch_size = hidden_states.shape[0]
+                img_padded_seq_len = ctx.sp_original_seq_len + ctx.sp_padding_size
+                full_seq_len = num_txt_tokens + img_padded_seq_len
+                hidden_states_mask = torch.ones(
+                    batch_size,
+                    full_seq_len,
+                    dtype=torch.bool,
+                    device=hidden_states.device,
+                )
+                hidden_states_mask[:, num_txt_tokens + ctx.sp_original_seq_len :] = False
+                if hidden_states_mask.all():
+                    hidden_states_mask = None
+
+        if hidden_states_mask is not None:
+            joint_attention_kwargs["attention_mask"] = hidden_states_mask
+
         for block in self.transformer_blocks:
             encoder_hidden_states, hidden_states = block(
                 hidden_states=hidden_states,
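
The attention mask built above only exists to hide tokens appended so the image sequence divides evenly across SP ranks; every real token stays visible. A worked example with hypothetical sizes (512 text tokens, 4095 image tokens padded by 1 so sp_size = 2 divides them):

    import torch

    num_txt_tokens, sp_original_seq_len, sp_padding_size = 512, 4095, 1
    full_seq_len = num_txt_tokens + sp_original_seq_len + sp_padding_size  # 4608

    mask = torch.ones(1, full_seq_len, dtype=torch.bool)
    mask[:, num_txt_tokens + sp_original_seq_len:] = False  # only the padded tail is masked
    print(int(mask.sum()), full_seq_len)  # 4607 4608
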
@@ -788,17 +918,13 @@ def forward(
                 temb_mod_params=single_stream_mod,
                 image_rotary_emb=concat_rotary_emb,
                 joint_attention_kwargs=joint_attention_kwargs,
+                text_seq_len=num_txt_tokens,
             )

         hidden_states = hidden_states[:, num_txt_tokens:, ...]
         hidden_states = self.norm_out(hidden_states, temb)
         output = self.proj_out(hidden_states)

-        if self.parallel_config.sequence_parallel_size > 1:
-            output = get_sp_group().all_gather(output, dim=1)
-            if sp_pad_size > 0:
-                output = output[:, :-sp_pad_size, ...]
-
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
