
[Feat] support SP for FLUX.2-klein #1250

Open
RuixiangMa wants to merge 13 commits into vllm-project:main from RuixiangMa:spforflux2klein

Conversation

@RuixiangMa
Contributor

@RuixiangMa RuixiangMa commented Feb 6, 2026

Purpose

support SP (Ulysses & Ring) for FLUX.2-klein

Test Plan

Test Result

  • Target image:
  • Hardware: tp = 1, 4 × NVIDIA 4090 (24 GB)

curl -s -X POST "http://localhost:8004/v1/images/edits" -F "image=@test.jpg" -F "prompt=Change the sky to orange sunset." -F "guidance_scale=1.0" -F "num_inference_steps=50" -F "n=1" -F "size=1024x1024" -F "output_format=png" | jq -r '.data[0].b64_json' | base64 --decode > output.png

| Configuration | Ulysses degree | Ring degree | Generation Time | Speedup |
| --- | --- | --- | --- | --- |
| Baseline | 1 | 1 | 25.503s | 1.00x |
| Ulysses | 4 | 1 | 13.173s | 1.94x |
| Ring | 1 | 4 | 16.866s | 1.51x |
| Hybrid (Ulysses + Ring) | 2 | 2 | 14.812s | 1.72x |


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: f3436b8532


Signed-off-by: Lancer <maruixiang6688@gmail.com>

image_rotary_emb = self.pos_embed(img_ids)
text_rotary_emb = self.pos_embed(txt_ids)
if current_omni_platform.is_npu():
Collaborator

@gcanlin do we have better ways to handle this difference? this is so awkward

Collaborator

@wtomin PTAL

Collaborator

I find that there are npu and mps hardcodes in Flux2PosEmbed, which I guess may come from the diffusers library. I will open a micro-refactoring PR to remove them. We could delete this npu branch temporarily if we'd like to merge this PR first.

Contributor Author

> I find that there are npu and mps hardcodes in Flux2PosEmbed, which I guess may come from the diffusers library. I will open a micro-refactoring PR to remove them. We could delete this npu branch temporarily if we'd like to merge this PR first.

Got it! I'll just remove this branch right now

Collaborator

@ZJY0516 ZJY0516 commented Feb 10, 2026

I remember some of these are because torch_npu doesn't support complex numbers.

Signed-off-by: Lancer <maruixiang6688@gmail.com>
Signed-off-by: Lancer <maruixiang6688@gmail.com>
@hsliuustc0106
Collaborator

update the docs as well

Signed-off-by: Lancer <maruixiang6688@gmail.com>
@RuixiangMa
Contributor Author

> update the docs as well

Done

num_txt_tokens = encoder_hidden_states.shape[1]

sp_size = self.parallel_config.sequence_parallel_size
get_forward_context().sequence_parallel_size = sp_size
Collaborator

To fetch the sp_size, using self.parallel_config.sequence_parallel_size would be sufficient. I don't see why you need to set get_forward_context().sequence_parallel_size here.

Collaborator

I find that LongCatImageTransformer2DModel also edits get_forward_context...

Collaborator

od_config = get_forward_context().omni_diffusion_config
parallel_config = od_config.parallel_config
sequence_parallel_size = parallel_config.sequence_parallel_size

This would be my recommendation.

Contributor Author

> To fetch the sp_size, using self.parallel_config.sequence_parallel_size would be sufficient. I don't see why you need to set get_forward_context().sequence_parallel_size here.

Thanks, I kept only `sp_size = self.parallel_config.sequence_parallel_size`, which matches what you suggested.

f"rank={sp_rank}, original_shape={original_shape}, chunked_shape={hidden_states.shape}"
)
else:
if not hasattr(self, "_sp_forward_logged"):
Collaborator

Is self._sp_forward_logged used for debugging only?

Contributor Author

> Is self._sp_forward_logged used for debugging only?

Yes, I'm removing it now.

Signed-off-by: Lancer <maruixiang6688@gmail.com>
sp_size = self.parallel_config.sequence_parallel_size
sp_pad_size = 0
if sp_size > 1:
hidden_states, sp_pad_size = sp_shard_with_padding(hidden_states, dim=1)
Collaborator

I am a little concerned about the padding behavior here. It looks like flux2_klein does not support hidden_states_mask (the way qwen_image does); therefore, the padded tokens will participate in the attention computation, which is wrong.

If using Non-intrusive _sp_plan, sp_padding_size will be automatically set in get_forward_context(), however, hidden_states_mask is required to exclude the padded tokens.

# vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py
        hidden_states_mask = None  # default
        if self.parallel_config is not None and self.parallel_config.sequence_parallel_size > 1:
            ctx = get_forward_context()
            if ctx.sp_original_seq_len is not None and ctx.sp_padding_size > 0:
                # Create mask for the full (padded) sequence
                # valid positions = True, padding positions = False
                batch_size = hidden_states.shape[0]
                padded_seq_len = ctx.sp_original_seq_len + ctx.sp_padding_size
                hidden_states_mask = torch.ones(
                    batch_size,
                    padded_seq_len,
                    dtype=torch.bool,
                    device=hidden_states.device,
                )
                hidden_states_mask[:, ctx.sp_original_seq_len :] = False
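To illustrate why the mask matters, here is a minimal standalone sketch (hypothetical helper name, not the project's API) showing how a boolean hidden_states_mask can be turned into an additive attention bias so that padded tokens receive zero attention weight:

```python
import torch

def mask_to_attention_bias(hidden_states_mask: torch.Tensor,
                           dtype=torch.float32) -> torch.Tensor:
    """Convert a (batch, seq) boolean mask (True = valid token) into an
    additive (batch, 1, 1, seq) bias: 0 for valid keys, -inf for padding,
    so softmax assigns padded keys zero attention weight."""
    bias = torch.zeros(hidden_states_mask.shape, dtype=dtype)
    bias.masked_fill_(~hidden_states_mask, float("-inf"))
    return bias[:, None, None, :]  # broadcast over heads and query positions

# Example: original seq len 3 padded to 4, as after sharding with padding
mask = torch.tensor([[True, True, True, False]])
bias = mask_to_attention_bias(mask)
scores = torch.zeros(1, 1, 1, 4) + bias  # uniform logits plus the bias
weights = torch.softmax(scores, dim=-1)  # padded key gets zero weight
```

Without such a bias, the zero-filled padding tokens would be attended to like real tokens, which is the concern raised above.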

Collaborator

@RuixiangMa I would like to invite you to join the discussion here #1324.

Collaborator

I have two suggestions:

  1. Refer to [Feature] Support Wan2.2 output with irregular shapes #1279: you may use auto_pad=True and use hidden_states_mask to exclude the padded tokens.
  2. You may set SequenceParallelInput(auto_pad=False); this will raise an error when seq_len is not divisible by sp_size, which I think might be rare for Flux.2-klein. We can take care of rare cases in the future, after we make our discussion clear in [RFC]: Ulysses-SP Constraints Solution #1324.
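The two suggested behaviors can be sketched as follows (illustrative standalone code; the function name and signature are assumptions, not the project's sp_shard_with_padding or SequenceParallelInput API):

```python
import torch

def shard_seq(hidden_states: torch.Tensor, sp_size: int, sp_rank: int,
              auto_pad: bool = True):
    """Shard dim=1 across sp_size ranks. With auto_pad=True, zero-pad the
    sequence up to the next multiple of sp_size and return the pad size;
    with auto_pad=False, raise when seq_len is not divisible."""
    seq_len = hidden_states.shape[1]
    pad = (-seq_len) % sp_size
    if pad and not auto_pad:
        raise ValueError(f"seq_len {seq_len} not divisible by sp_size {sp_size}")
    if pad:
        # pad the end of dim 1 (the sequence dimension)
        hidden_states = torch.nn.functional.pad(hidden_states, (0, 0, 0, pad))
    chunk = hidden_states.shape[1] // sp_size
    return hidden_states[:, sp_rank * chunk:(sp_rank + 1) * chunk], pad

# seq_len=10 with sp_size=4 pads to 12, so each rank holds 3 tokens;
# the last rank's chunk contains the two zero padding tokens
x = torch.randn(1, 10, 8)
local, pad = shard_seq(x, sp_size=4, sp_rank=3)
```

With auto_pad=False the same call would raise, which matches suggestion 2: fail loudly on irregular shapes until the constraints discussion in #1324 settles.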

@hsliuustc0106
Collaborator

@vllm-omni-reviewer

Signed-off-by: Lancer <maruixiang6688@gmail.com>
Signed-off-by: Lancer <maruixiang6688@gmail.com>
@wtomin
Collaborator

wtomin commented Mar 9, 2026

Do you have a plan to support SP with _sp_plan hooks?


Signed-off-by: Lancer <maruixiang6688@gmail.com>
@RuixiangMa
Contributor Author

> Do you have a plan to support SP with _sp_plan hooks?

Sorry, just submitted a revision.

@wtomin Can you take a look?

@wtomin
Collaborator

wtomin commented Mar 9, 2026

There is a bug related to SP in #1556; I have a bugfix in #1704. Please check whether it affects your PR.

sin = sin.to(query.dtype)
query = self.rope(query, cos, sin)
key = self.rope(key, cos, sin)
txt_cos, img_cos = cos[:text_seq_len], cos[text_seq_len:]
Collaborator

Are these added lines (397 to 410) independent of any diffusion model, and can they be extracted?

Contributor Author

@RuixiangMa RuixiangMa commented Mar 10, 2026

I've extracted the RoPE logic into a reusable helper function (verified only on flux1, flux2, and z-image so far, but others such as ovis-image should be able to reuse it as well). The SP-specific logic in flux2_klein (lines 393-410) remains unchanged, as it handles model-specific sequence-parallel splitting.
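A reusable RoPE helper of the kind described presumably looks something like this sketch (illustrative only; the names and the interleaved-pair convention are assumptions, not the PR's actual code). It applies rotary embeddings to a tensor and shows how the shared cos/sin tables can then be split between text and image tokens, as in the hunk above:

```python
import torch

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate interleaved channel pairs of x by per-position angles.
    x: (..., seq, dim); cos/sin: (seq, dim // 2). Rotation preserves
    the norm of each token vector."""
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

seq_len, dim, text_seq_len = 6, 4, 2
pos = torch.arange(seq_len, dtype=torch.float32)[:, None]
freqs = pos * torch.tensor([1.0, 0.5])  # toy per-channel frequencies
cos, sin = freqs.cos(), freqs.sin()

q = torch.randn(1, seq_len, dim)
q_rot = apply_rope(q, cos, sin)
# split the shared tables between text and image tokens, as in the model code
txt_cos, img_cos = cos[:text_seq_len], cos[text_seq_len:]
txt_sin, img_sin = sin[:text_seq_len], sin[text_seq_len:]
```

The helper itself is model-agnostic; only the text/image split of the cos/sin tables stays model-specific, which matches the division of responsibilities described above.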


hidden_states_mask = None
ctx = get_forward_context()
if ctx.sp_active:
Collaborator

I think ctx.sp_active should not be exposed to users or developers.

In the _sp_plan hooks design, ctx.sp_active is determined by _sp_shard_depth, which is handled automatically by the hook function. Therefore, there is no need to check ctx.sp_active here.

Contributor Author

> I think ctx.sp_active should not be exposed to users or developers. In the _sp_plan hooks design, ctx.sp_active is determined by _sp_shard_depth, which is handled automatically by the hook function. Therefore, there is no need to check ctx.sp_active here.

It was redundant; I removed it.

Signed-off-by: Lancer <maruixiang6688@gmail.com>
@RuixiangMa
Contributor Author

> There is a bug related to SP in #1556; I have a bugfix in #1704. Please check whether it affects your PR.

Flux2's hook-based implementation seems unaffected. I'll merge it and verify.

@wtomin
Collaborator

wtomin commented Mar 10, 2026

I would recommend using _sp_plan hooks for the SP implementation instead of intrusive modification.

Taking PR #1772 as an example, one major reason intrusive modification is problematic is that it affects the teacache extractor function, making it overly complicated.

Signed-off-by: Lancer <maruixiang6688@gmail.com>
@RuixiangMa
Contributor Author

> There is a bug related to SP in #1556; I have a bugfix in #1704. Please check whether it affects your PR.

Merged the PR and tested; it works well.

@wtomin
Collaborator

wtomin commented Mar 11, 2026

Have you run offline inference with an irregular shape that involves auto padding and hidden_states_mask?

Signed-off-by: Lancer <maruixiang6688@gmail.com>
@RuixiangMa
Contributor Author

> Have you run offline inference with an irregular shape that involves auto padding and hidden_states_mask?

Both irregular-shape and offload tests passed.

Collaborator

@wtomin wtomin left a comment


LGTM.


5 participants