Merged
87 commits
e3bf6e8
cp plan framework in vllm-omni
mxuax Jan 14, 2026
e02e841
add partial context split hook
mxuax Jan 15, 2026
4db34f3
fix licenses
mxuax Jan 15, 2026
6e3063c
add test file and modify z-image for cp_plan
mxuax Jan 15, 2026
73882ae
modify z-image cp_plan
mxuax Jan 15, 2026
a039bc8
enable hybrid ulysses and ring, add apply_context_paralle in registry…
mxuax Jan 15, 2026
1abd69a
modify z-image-transformer, created UnifiedPrepare to put all the pre…
mxuax Jan 15, 2026
7f94550
support cp_plan for qwen-image
mxuax Jan 16, 2026
50ccdd1
modify test
mxuax Jan 16, 2026
b640762
add cp_plan doc
mxuax Jan 16, 2026
d1ede83
reduction wan
mxuax Jan 16, 2026
a5cd982
reduction wan from test
mxuax Jan 16, 2026
505d774
fix doc warning
mxuax Jan 16, 2026
d816ec0
Merge branch 'Non-Intrusive-SP' of https://github.com/mxuax/vllm-omni…
mxuax Jan 16, 2026
a52093b
Delete Untitled
mxuax Jan 16, 2026
0714a44
refactor context parallel to sequence parallel and add some sp_plan i…
mxuax Jan 19, 2026
b369786
Merge branch 'Non-Intrusive-SP' of https://github.com/mxuax/vllm-omni…
mxuax Jan 19, 2026
722c9f6
Add attention mask support to _sp_plan framework for variable sequenc…
mxuax Jan 19, 2026
ddbf89e
fix wrongly chunked attention mask issue
mxuax Jan 19, 2026
9e76fc2
fix chunked attention mask bug
mxuax Jan 19, 2026
b723476
Merge branch 'main' into Non-Intrusive-SP
mxuax Jan 19, 2026
9f760a2
handle mask in attention metadata
mxuax Jan 19, 2026
d7f4157
Merge branch 'Non-Intrusive-SP' of https://github.com/mxuax/vllm-omni…
mxuax Jan 19, 2026
d9aeca8
remove some declarational comments
mxuax Jan 19, 2026
81516a6
remove some declarational comments
mxuax Jan 19, 2026
9c77428
refactor the sp_plan and sp_config file, removed the training related…
mxuax Jan 19, 2026
ec37009
modified the parallelism_acceleration.md to give a clearer sp_plan in…
mxuax Jan 20, 2026
c09171b
add test for sequence_parallel.py
mxuax Jan 20, 2026
470b737
fix error
mxuax Jan 20, 2026
66e85e8
fix error
mxuax Jan 20, 2026
7976e0f
fix file name error and comment error in test_sequence_parallel.py
mxuax Jan 20, 2026
5502c1b
fix multiple instance error in test_sequence_parallel.py
mxuax Jan 20, 2026
2dcd1a2
fix multiple instance error in test_sequence_parallel.py
mxuax Jan 20, 2026
e6c8533
fix multiple instance error in test_sequence_parallel.py
mxuax Jan 20, 2026
42d012f
fix multiple instance error in test_sequence_parallel.py
mxuax Jan 20, 2026
f8882ec
Merge remote-tracking branch 'origin/main' into workingbranch
mxuax Jan 21, 2026
892e556
remove log print debug codes
mxuax Jan 21, 2026
f63d52c
Merge branch 'main' into Non-Intrusive-SP
mxuax Jan 21, 2026
6589099
sync main
mxuax Jan 26, 2026
cfde39c
Merge branch 'Non-Intrusive-SP' of https://github.com/mxuax/vllm-omni…
mxuax Jan 26, 2026
6e7e198
Merge branch 'vllm-project:main' into Non-Intrusive-SP
mxuax Jan 26, 2026
3f09000
add sp support for t2v wan2.2
mxuax Jan 26, 2026
6a0e9ec
add doc
mxuax Jan 26, 2026
4819e56
Merge branch 'main' into Non-Intrusive-SP
mxuax Jan 26, 2026
613da52
Apply suggestion from @hsliuustc0106
mxuax Jan 27, 2026
2cef526
fix typo
mxuax Jan 27, 2026
0cd417e
fix typo
mxuax Jan 27, 2026
fa83fc9
Merge branch 'main' into Non-Intrusive-SP
mxuax Jan 27, 2026
f0255f1
fix flash-attn backend selection logic
mxuax Jan 27, 2026
97234c2
Merge branch 'Non-Intrusive-SP' of https://github.com/mxuax/vllm-omni…
mxuax Jan 27, 2026
23a6930
Merge branch 'vllm-project:main' into Non-Intrusive-SP
mxuax Jan 27, 2026
2cb9c22
fix corner case lse is None
mxuax Jan 28, 2026
9a04503
Merge branch 'main' into Non-Intrusive-SP
mxuax Jan 28, 2026
a9a260d
Merge branch 'Non-Intrusive-SP' of https://github.com/mxuax/vllm-omni…
mxuax Jan 28, 2026
d58dcfc
fix ruff error
mxuax Jan 28, 2026
39e2330
fix ruff error
mxuax Jan 28, 2026
1e06e8f
fix wrong dropout parameter error
mxuax Jan 28, 2026
ac9f4af
fix parameter name error
mxuax Jan 28, 2026
65db163
fix except ImportError
mxuax Jan 28, 2026
738729d
fix except ImportError
mxuax Jan 28, 2026
c796e97
Merge branch 'main' into Non-Intrusive-SP
mxuax Jan 28, 2026
47d351b
refactor the backends name
mxuax Jan 28, 2026
26c44ad
Merge branch 'main' into Non-Intrusive-SP
mxuax Jan 28, 2026
2a1d87e
Merge branch 'Non-Intrusive-SP' of https://github.com/mxuax/vllm-omni…
mxuax Jan 28, 2026
f4f554f
remove redundant checking
mxuax Jan 28, 2026
4207188
Merge branch 'main' into Non-Intrusive-SP
mxuax Jan 28, 2026
c790eb8
remove redundant checking
mxuax Jan 28, 2026
b0c75d0
Merge branch 'Non-Intrusive-SP' of https://github.com/mxuax/vllm-omni…
mxuax Jan 28, 2026
d8b78e5
recover
mxuax Jan 28, 2026
48f55f0
Merge branch 'main' into Non-Intrusive-SP
mxuax Jan 28, 2026
d278a30
remove outdated NPU FA backends implementation
mxuax Jan 29, 2026
98649be
Merge branch 'main' into Non-Intrusive-SP
mxuax Jan 29, 2026
4f28f44
update the fall back chain in fa.py
mxuax Jan 29, 2026
8694d80
only use fa3_fwd_func in ring_attn to make sure LSE is returned
mxuax Jan 29, 2026
1b4598e
fall back to sdpa if no fa
mxuax Jan 29, 2026
def2549
fix warning caused by insufficient check in env.py
mxuax Jan 29, 2026
9c46ab4
move backends check to platform.py
mxuax Jan 29, 2026
06a0706
fix typo
mxuax Jan 29, 2026
9dee299
Merge branch 'main' into Non-Intrusive-SP
mxuax Jan 29, 2026
d1da324
Merge branch 'main' into Non-Intrusive-SP
mxuax Jan 29, 2026
9166732
fix modulate_index handling by adding new submodule
mxuax Jan 30, 2026
05a1d17
Merge branch 'Non-Intrusive-SP' of https://github.com/mxuax/vllm-omni…
mxuax Jan 30, 2026
a5d0234
Update ring_globals.py
mxuax Jan 30, 2026
2c9e3b4
Merge branch 'vllm-project:main' into Non-Intrusive-SP
mxuax Jan 30, 2026
4841b90
Remove empty line at the beginning of ring_globals.py
mxuax Jan 30, 2026
5752a9a
Merge branch 'main' into Non-Intrusive-SP
mxuax Jan 30, 2026
7420b52
Merge branch 'main' into Non-Intrusive-SP
mxuax Jan 30, 2026
84 changes: 75 additions & 9 deletions vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py
@@ -89,6 +89,66 @@ def forward(
return hidden_states, vid_freqs, txt_freqs


class ModulateIndexPrepare(nn.Module):
"""Prepares modulate_index for sequence parallel when zero_cond_t is enabled.

This module encapsulates the creation of modulate_index tensor, which is used
to select different conditioning parameters (shift/scale/gate) for different
token positions in image editing tasks.

Similar to Z-Image's UnifiedPrepare and ImageRopePrepare, this creates a module
boundary where _sp_plan can shard the output via split_output=True.

The modulate_index must be sharded along the sequence dimension to match the
sharded hidden_states in SP mode.

Note: Our _sp_plan corresponds to diffusers' _cp_plan (Context Parallelism).
"""

def __init__(self, zero_cond_t: bool = False):
super().__init__()
self.zero_cond_t = zero_cond_t

def forward(
self,
timestep: torch.Tensor,
img_shapes: list[list[tuple[int, int, int]]],
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Prepare timestep and modulate_index for SP.

Args:
timestep: Timestep tensor [batch]
img_shapes: List of image shape tuples per batch item.
Each item is a list of (frame, height, width) tuples.
For edit models: [[source_shape], [target_shape1, target_shape2, ...]]

Returns:
timestep: Doubled timestep [2*batch] if zero_cond_t, else original [batch]
modulate_index: Token condition index [batch, seq_len] if zero_cond_t, else None
- index=0: source image tokens (use normal timestep conditioning)
- index=1: target image tokens (use zero timestep conditioning)

Note: _sp_plan will shard modulate_index via split_output=True when SP is enabled.
The modulate_index sequence dimension must match hidden_states after sharding.
"""
if self.zero_cond_t:
# Double the timestep: [timestep, timestep * 0]
# This creates two sets of conditioning parameters in AdaLayerNorm
timestep = torch.cat([timestep, timestep * 0], dim=0)

# Create modulate_index to select conditioning per token position
# - First image (sample[0]): source image, use index=0 (normal timestep)
# - Remaining images (sample[1:]): target images, use index=1 (zero timestep)
modulate_index = torch.tensor(
[[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in img_shapes],
device=timestep.device,
dtype=torch.int,
)
return timestep, modulate_index

return timestep, None


class QwenTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, use_additional_t_cond=False):
super().__init__()
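The modulate_index construction in ModulateIndexPrepare.forward can be checked in isolation. A minimal sketch of the same list comprehension (plain Python, no torch; the img_shapes value below is hypothetical):

```python
from math import prod

def build_modulate_index(img_shapes):
    # Mirrors ModulateIndexPrepare.forward: index 0 for every token of the
    # source image (sample[0]), index 1 for all target-image tokens.
    return [
        [0] * prod(sample[0]) + [1] * sum(prod(s) for s in sample[1:])
        for sample in img_shapes
    ]

# Hypothetical batch: one item with a 1x2x2 source image (4 tokens)
# followed by two target patches of 2 tokens each.
img_shapes = [[(1, 2, 2), (1, 1, 2), (1, 2, 1)]]
print(build_modulate_index(img_shapes))  # [[0, 0, 0, 0, 1, 1, 1, 1]]
```

With no target images (sample[1:] empty) the sum is 0 and the row is all zeros, matching the non-edit case.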
@@ -785,6 +845,12 @@ class QwenImageTransformer2DModel(CachedTransformer):
1: SequenceParallelInput(split_dim=0, expected_dims=2, split_output=True, auto_pad=True),
# txt_freqs (index 2) is NOT sharded - kept replicated for dual-stream attention
},
# Shard ModulateIndexPrepare output (modulate_index must be sharded to match hidden_states)
# This is only active when zero_cond_t=True (image editing models)
# Output index 1 is modulate_index [batch, seq_len], needs sharding along dim=1
"modulate_index_prepare": {
1: SequenceParallelInput(split_dim=1, expected_dims=2, split_output=True, auto_pad=True),
},
# Gather output at proj_out
"proj_out": SequenceParallelOutput(gather_dim=1, expected_dims=3),
}
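The SequenceParallelInput(split_dim=1, split_output=True, auto_pad=True) entry above shards a [batch, seq_len] modulate_index along the sequence dimension. A rough sketch of what that sharding amounts to per sample — an assumption about the framework's behavior, written over plain lists; shard_sequence is a hypothetical helper, not the framework API:

```python
def shard_sequence(tokens, world_size, rank, pad_value=0):
    # auto_pad=True: pad seq_len up to a multiple of world_size, then give
    # each rank one contiguous chunk. For a [batch, seq_len] tensor,
    # split_dim=1 corresponds to slicing each per-sample token list this way.
    pad = (-len(tokens)) % world_size
    padded = tokens + [pad_value] * pad
    chunk = len(padded) // world_size
    return padded[rank * chunk : (rank + 1) * chunk]

modulate_index_row = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]  # seq_len = 10
print(shard_sequence(modulate_index_row, 4, 0))  # [0, 0, 0]
print(shard_sequence(modulate_index_row, 4, 3))  # [1, 0, 0]
```

Because hidden_states is sharded the same way, each rank's modulate_index chunk stays aligned with its hidden_states chunk, which is the invariant the plan entry exists to preserve.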
@@ -848,6 +914,11 @@ def __init__(
# This ensures RoPE dimensions align with hidden_states after sharding
self.image_rope_prepare = ImageRopePrepare(self.img_in, self.pos_embed)

# ModulateIndexPrepare module for _sp_plan to shard modulate_index
# This ensures modulate_index dimensions align with hidden_states after sharding
# Only active when zero_cond_t=True (image editing models)
self.modulate_index_prepare = ModulateIndexPrepare(zero_cond_t=zero_cond_t)

def forward(
self,
hidden_states: torch.Tensor,
@@ -906,15 +977,10 @@ def forward(
# Ensure timestep tensor is on the same device and dtype as hidden_states
timestep = timestep.to(device=hidden_states.device, dtype=hidden_states.dtype)

if self.zero_cond_t:
timestep = torch.cat([timestep, timestep * 0], dim=0)
modulate_index = torch.tensor(
[[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in img_shapes],
device=timestep.device,
dtype=torch.int,
)
else:
modulate_index = None
# Prepare timestep and modulate_index via ModulateIndexPrepare module
# _sp_plan will shard modulate_index via split_output=True (when zero_cond_t=True)
# This ensures modulate_index sequence dimension matches sharded hidden_states
timestep, modulate_index = self.modulate_index_prepare(timestep, img_shapes)

encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
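The torch.cat([timestep, timestep * 0], dim=0) doubling that this refactor moves into ModulateIndexPrepare can be sketched without torch: the batch of timesteps is concatenated with zeros, so AdaLayerNorm produces two conditioning sets (normal vs. zero timestep) that modulate_index then selects between per token:

```python
def double_timestep(timestep, zero_cond_t):
    # List equivalent of torch.cat([timestep, timestep * 0], dim=0) on a
    # 1-D batch: first half keeps the real timesteps, second half is zeros.
    if zero_cond_t:
        return timestep + [0.0] * len(timestep)
    return timestep

print(double_timestep([5.0, 7.0], zero_cond_t=True))   # [5.0, 7.0, 0.0, 0.0]
print(double_timestep([5.0, 7.0], zero_cond_t=False))  # [5.0, 7.0]
```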