diff --git a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py
index 47098ff8a1..a3cc6cbbb6 100644
--- a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py
+++ b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py
@@ -89,6 +89,75 @@ def forward(
         return hidden_states, vid_freqs, txt_freqs
 
 
+class ModulateIndexPrepare(nn.Module):
+    """Prepares modulate_index for sequence parallel when zero_cond_t is enabled.
+
+    modulate_index selects, per token position, which of the two sets of
+    conditioning parameters (shift/scale/gate) is applied in image editing tasks.
+
+    Similar to Z-Image's UnifiedPrepare and ImageRopePrepare, this creates a module
+    boundary where _sp_plan can shard the output via split_output=True, so that
+    modulate_index is sharded along the sequence dimension to match the sharded
+    hidden_states in SP mode.
+
+    Note: Our _sp_plan corresponds to diffusers' _cp_plan (Context Parallelism).
+    """
+
+    def __init__(self, zero_cond_t: bool = False):
+        super().__init__()
+        self.zero_cond_t = zero_cond_t
+
+    def forward(
+        self,
+        timestep: torch.Tensor,
+        img_shapes: list[list[tuple[int, int, int]]],
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        """Prepare timestep and modulate_index for SP.
+
+        Args:
+            timestep: Timestep tensor [batch].
+            img_shapes: One entry per batch item; each entry is a list of
+                (frame, height, width) tuples, one per image in that item:
+                [[first_shape, other_shape1, other_shape2, ...], ...].
+
+        Returns:
+            timestep: [2 * batch] (timestep followed by zeros) if zero_cond_t,
+                else the input unchanged.
+            modulate_index: int tensor [batch, seq_len] if zero_cond_t, else None.
+                - index=0: tokens of the first image (normal timestep conditioning)
+                - index=1: tokens of the remaining images (zero timestep conditioning)
+
+        Raises:
+            ValueError: If batch items have different total token counts, since
+                they cannot be stacked into a single [batch, seq_len] tensor.
+
+        Note: _sp_plan will shard modulate_index via split_output=True when SP is enabled.
+            The modulate_index sequence dimension must match hidden_states after sharding.
+        """
+        if not self.zero_cond_t:
+            return timestep, None
+
+        # Double the timestep: [timestep, timestep * 0]
+        # This creates two sets of conditioning parameters in AdaLayerNorm
+        timestep = torch.cat([timestep, timestep * 0], dim=0)
+
+        # Token counts per batch item: first image vs. first + remaining images
+        first_lens = [prod(sample[0]) for sample in img_shapes]
+        total_lens = [n + sum(prod(s) for s in sample[1:]) for n, sample in zip(first_lens, img_shapes)]
+        if len(set(total_lens)) > 1:
+            raise ValueError(f"All batch items must have the same number of image tokens, got {total_lens}")
+
+        # Build the index on-device by comparing positions with each item's boundary
+        # instead of materializing seq_len-long Python lists on every forward call:
+        # - positions before the boundary: first image, index=0 (normal timestep)
+        # - positions at/after the boundary: remaining images, index=1 (zero timestep)
+        seq_len = total_lens[0] if total_lens else 0
+        positions = torch.arange(seq_len, device=timestep.device)
+        boundaries = torch.tensor(first_lens, device=timestep.device, dtype=torch.long)
+        modulate_index = (positions.unsqueeze(0) >= boundaries.unsqueeze(1)).to(torch.int)
+        return timestep, modulate_index
+
+
 class QwenTimestepProjEmbeddings(nn.Module):
     def __init__(self, embedding_dim, use_additional_t_cond=False):
         super().__init__()
@@ -785,6 +854,12 @@ class QwenImageTransformer2DModel(CachedTransformer):
             1: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True, auto_pad=True),
             # txt_freqs (index 2) is NOT sharded - kept replicated for dual-stream attention
         },
+        # Shard ModulateIndexPrepare output (modulate_index must be sharded to match hidden_states)
+        # This is only active when zero_cond_t=True (image editing models)
+        # Output index 1 is modulate_index [batch, seq_len], needs sharding along dim=1
+        "modulate_index_prepare": {
+            1: SequenceParallelInput(split_dim=1, expected_dims=2, split_output=True, auto_pad=True),
+        },
         # Gather output at proj_out
         "proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
     }
@@ -848,6 +923,11 @@ def __init__(
         # This ensures RoPE dimensions align with hidden_states after sharding
         self.image_rope_prepare = ImageRopePrepare(self.img_in, self.pos_embed)
 
+        # ModulateIndexPrepare module for _sp_plan to shard modulate_index
+        # This ensures modulate_index dimensions align with hidden_states after sharding
+        # Only active when zero_cond_t=True (image editing models)
+        self.modulate_index_prepare = ModulateIndexPrepare(zero_cond_t=zero_cond_t)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -906,15 +986,10 @@ def forward(
         # Ensure timestep tensor is on the same device and dtype as hidden_states
         timestep = timestep.to(device=hidden_states.device, dtype=hidden_states.dtype)
 
-        if self.zero_cond_t:
-            timestep = torch.cat([timestep, timestep * 0], dim=0)
-            modulate_index = torch.tensor(
-                [[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in img_shapes],
-                device=timestep.device,
-                dtype=torch.int,
-            )
-        else:
-            modulate_index = None
+        # Prepare timestep and modulate_index via ModulateIndexPrepare module
+        # _sp_plan will shard modulate_index via split_output=True (when zero_cond_t=True)
+        # This ensures modulate_index sequence dimension matches sharded hidden_states
+        timestep, modulate_index = self.modulate_index_prepare(timestep, img_shapes)
 
         encoder_hidden_states = self.txt_norm(encoder_hidden_states)
         encoder_hidden_states = self.txt_in(encoder_hidden_states)