[Feat] support SP for FLUX.2-klein #1250
Changes from 4 commits
```diff
@@ -28,6 +28,7 @@
 )
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.normalization import AdaLayerNormContinuous
+from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
```
```diff
@@ -39,8 +40,18 @@
 from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
 from vllm_omni.diffusion.attention.layer import Attention
+from vllm_omni.diffusion.data import OmniDiffusionConfig
+from vllm_omni.diffusion.distributed.parallel_state import (
+    get_sequence_parallel_rank,
+    get_sequence_parallel_world_size,
+    get_sp_group,
+)
+from vllm_omni.diffusion.distributed.sp_sharding import sp_shard_with_padding
+from vllm_omni.diffusion.forward_context import get_forward_context
 from vllm_omni.diffusion.layers.rope import RotaryEmbedding
 
+logger = init_logger(__name__)
+
 
 class Flux2SwiGLU(nn.Module):
     """SwiGLU activation used by Flux2."""
```
```diff
@@ -334,6 +345,12 @@ def forward(
 class Flux2SingleTransformerBlock(nn.Module):
+    """
+    Single-stream Transformer block for Flux 2 with SP (Sequence Parallelism) support.
+
+    SP handling is delegated to Flux2Attention via the forward context.
+    """
+
     def __init__(
         self,
         dim: int,
```
```diff
@@ -367,6 +384,13 @@ def forward(
         split_hidden_states: bool = False,
         text_seq_len: int | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        """
+        Forward pass for Flux2SingleTransformerBlock with SP support.
+
+        In SP mode: image hidden_states is chunked (B, img_len/SP, D),
+        text encoder_hidden_states is full (B, txt_len, D).
+        The block concatenates them for joint attention.
+        """
         if encoder_hidden_states is not None:
             text_seq_len = encoder_hidden_states.shape[1]
             hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
```
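Note (not part of the diff): the shape bookkeeping described in the new docstring can be checked standalone. All dimensions below are illustrative, not taken from the PR:

```python
import torch

# Illustrative dimensions only: batch, image tokens, text tokens,
# hidden size, SP degree.
B, img_len, txt_len, D, SP = 2, 4096, 128, 3072, 4

local_img = torch.randn(B, img_len // SP, D)  # per-rank image chunk
full_txt = torch.randn(B, txt_len, D)         # text is kept unsharded

# The block prepends the full text sequence to the local image chunk
# before joint attention, as the torch.cat in the block does.
joint = torch.cat([full_txt, local_img], dim=1)
assert joint.shape == (B, txt_len + img_len // SP, D)
```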
```diff
@@ -556,6 +580,8 @@ def forward(self, temb: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor,
 class Flux2Transformer2DModel(nn.Module):
     """
     The Transformer model introduced in Flux 2.
+
+    Supports Sequence Parallelism (Ulysses and Ring) when configured via OmniDiffusionConfig.
     """
 
     _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
```
```diff
@@ -580,6 +606,7 @@ def __init__(
         rope_theta: int = 2000,
         eps: float = 1e-6,
         guidance_embeds: bool = True,
+        od_config: OmniDiffusionConfig | None = None,
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
```
```diff
@@ -601,6 +628,13 @@ def __init__(
             guidance_embeds=guidance_embeds,
         )
 
+        if od_config is not None:
+            self.parallel_config = od_config.parallel_config
+        else:
+            from vllm_omni.diffusion.data import DiffusionParallelConfig
+
+            self.parallel_config = DiffusionParallelConfig()
+
         self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=list(axes_dims_rope))
         self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
             in_channels=timestep_guidance_channels,
```
```diff
@@ -672,6 +706,26 @@ def forward(
         num_txt_tokens = encoder_hidden_states.shape[1]
 
+        sp_size = self.parallel_config.sequence_parallel_size
+        get_forward_context().sequence_parallel_size = sp_size
+        sp_pad_size = 0
+        if sp_size > 1:
+            sp_world_size = get_sequence_parallel_world_size()
+            sp_rank = get_sequence_parallel_rank()
+            original_shape = hidden_states.shape
+            hidden_states, sp_pad_size = sp_shard_with_padding(hidden_states, dim=1)
+            get_forward_context().split_text_embed_in_sp = False
+            if not hasattr(self, "_sp_forward_logged"):
+                self._sp_forward_logged = True
+                logger.info(
+                    f"[Flux2 Transformer] SP enabled: sp_size={sp_size}, world_size={sp_world_size}, "
+                    f"rank={sp_rank}, original_shape={original_shape}, chunked_shape={hidden_states.shape}"
+                )
+        else:
+            if not hasattr(self, "_sp_forward_logged"):
+                self._sp_forward_logged = True
+                logger.info(f"[Flux2 Transformer] SP disabled: sp_size={sp_size}")
+
         timestep = timestep.to(hidden_states.dtype) * 1000
         if guidance is not None:
             guidance = guidance.to(hidden_states.dtype) * 1000
```
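Note (not part of the diff): the actual `sp_shard_with_padding` lives in `vllm_omni.diffusion.distributed.sp_sharding` and is not shown here. A minimal sketch of the usual pad-then-chunk contract, under the assumption that the helper zero-pads the sharded dimension up to a multiple of the SP world size and returns this rank's chunk plus the pad length:

```python
import torch
import torch.nn.functional as F


def shard_with_padding_sketch(
    x: torch.Tensor, dim: int, world_size: int, rank: int
) -> tuple[torch.Tensor, int]:
    """Hypothetical stand-in for sp_shard_with_padding (assumed semantics)."""
    pad = (-x.size(dim)) % world_size  # tokens needed for an even split
    if pad:
        # F.pad takes (left, right) pairs starting from the last dim;
        # pad only the trailing side of `dim`.
        spec = [0, 0] * (x.ndim - 1 - dim) + [0, pad]
        x = F.pad(x, spec)
    return x.chunk(world_size, dim=dim)[rank], pad
```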
```diff
@@ -690,11 +744,21 @@ def forward(
         if txt_ids.ndim == 3:
             txt_ids = txt_ids[0]
 
-        image_rotary_emb = self.pos_embed(img_ids)
-        text_rotary_emb = self.pos_embed(txt_ids)
+        img_freqs_cos, img_freqs_sin = self.pos_embed(img_ids)
+        txt_freqs_cos, txt_freqs_sin = self.pos_embed(txt_ids)
+
+        if sp_size > 1:
+            sp_world_size = get_sequence_parallel_world_size()
+            sp_rank = get_sequence_parallel_rank()
+            img_freqs_cos, _ = sp_shard_with_padding(img_freqs_cos, dim=0)
+            img_freqs_sin, _ = sp_shard_with_padding(img_freqs_sin, dim=0)
+            if get_forward_context().split_text_embed_in_sp:
+                txt_freqs_cos, _ = sp_shard_with_padding(txt_freqs_cos, dim=0)
+                txt_freqs_sin, _ = sp_shard_with_padding(txt_freqs_sin, dim=0)
 
         concat_rotary_emb = (
-            torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
-            torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
+            torch.cat([txt_freqs_cos, img_freqs_cos], dim=0),
+            torch.cat([txt_freqs_sin, img_freqs_sin], dim=0),
         )
 
         for block in self.transformer_blocks:
```
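Note (not part of the diff): the `dim=0` here versus `dim=1` for `hidden_states` follows from the tensor layouts. Hidden states carry a batch axis `(B, seq, D)`, while the cos/sin tables returned by `pos_embed` presumably have the sequence as their leading axis, so both shards cover the same rows of tokens. A toy shape check, where the `(seq, head_dim)` layout is an assumption rather than something the diff confirms:

```python
import torch

img_len, head_dim, SP, rank = 4096, 128, 4, 1     # illustrative values
img_freqs_cos = torch.randn(img_len, head_dim)    # assumed (seq, head_dim) layout
local_cos = img_freqs_cos.chunk(SP, dim=0)[rank]  # same rows as this rank's tokens
assert local_cos.shape == (img_len // SP, head_dim)
```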
```diff
@@ -722,6 +786,11 @@ def forward(
         hidden_states = self.norm_out(hidden_states, temb)
         output = self.proj_out(hidden_states)
 
+        if self.parallel_config.sequence_parallel_size > 1:
+            output = get_sp_group().all_gather(output, dim=1)
+            if sp_pad_size > 0:
+                output = output[:, :-sp_pad_size, ...]
+
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
```
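Note (not part of the diff): the gather-then-trim at the end is the inverse of the sharding at the top of `forward`. A single-process round-trip check using the sketch helper above, with `torch.cat` standing in for `get_sp_group().all_gather(..., dim=1)`:

```python
import torch

x = torch.randn(2, 4095, 64)  # sequence length not divisible by SP=4
shards = [shard_with_padding_sketch(x, dim=1, world_size=4, rank=r) for r in range(4)]
sp_pad_size = shards[0][1]    # every rank pads by the same amount (here, 1)

gathered = torch.cat([chunk for chunk, _ in shards], dim=1)  # all_gather stand-in
if sp_pad_size > 0:
    gathered = gathered[:, :-sp_pad_size, ...]  # drop the trailing pad tokens
assert torch.equal(gathered, x)
```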
**Reviewer:** To fetch the `sp_size`, using `self.parallel_config.sequence_parallel_size` would be sufficient. I don't see why you need to set `get_forward_context().sequence_parallel_size` here.

**Author:** I find that `LongCatImageTransformer2DModel` also edits `get_forward_context`...

**Reviewer:**

```python
od_config = get_forward_context().omni_diffusion_config
parallel_config = od_config.parallel_config
sequence_parallel_size = parallel_config.sequence_parallel_size
```

This would be my recommendation.

**Author:** Thanks, I keep only `sp_size = self.parallel_config.sequence_parallel_size`, which matches what you suggested.