[Feat] support SP for FLUX.2-klein #1250
RuixiangMa wants to merge 13 commits into vllm-project:main
Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: f3436b8532
vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py
Signed-off-by: Lancer <maruixiang6688@gmail.com>
Force-pushed f3436b8 to e8fb739
```python
image_rotary_emb = self.pos_embed(img_ids)
text_rotary_emb = self.pos_embed(txt_ids)
if current_omni_platform.is_npu():
```
@gcanlin do we have better ways to handle this difference? this is so awkward
I find that there exists npu and mps hardcode in Flux2PosEmbed, which may be from diffusers library I guess. I will take a micro-refactoring PR for removing it. We could delete this npu branch temporarily if we'd like to merge this PR first.
Got it! I'll just remove this branch right now
I remember some of these exist because torch_npu doesn't support complex numbers.
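For context, the complex-number limitation can be worked around with a real-valued rotation: the complex multiply used by RoPE decomposes into two real multiplies. A minimal sketch (the function name is illustrative, not the actual `Flux2PosEmbed` API):

```python
import torch

def apply_rope_real(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate interleaved pairs (x0, x1) without torch.complex:
    (x0 + i*x1) * (cos + i*sin) == (x0*cos - x1*sin) + i*(x0*sin + x1*cos)."""
    x0, x1 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x0 * cos - x1 * sin
    out[..., 1::2] = x0 * sin + x1 * cos
    return out

q = torch.randn(1, 8, 16, 64)   # (batch, heads, seq, head_dim)
cos = torch.ones(16, 32)        # one angle per rotated pair
sin = torch.zeros(16, 32)
assert torch.allclose(apply_rope_real(q, cos, sin), q)  # zero angle is the identity
```

Since only real tensor ops are used, this form runs on backends that lack complex-dtype kernels.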
Force-pushed 482940d to 225ec8a
update the docs as well
Done
```python
num_txt_tokens = encoder_hidden_states.shape[1]

sp_size = self.parallel_config.sequence_parallel_size
get_forward_context().sequence_parallel_size = sp_size
```
To fetch the sp_size, using self.parallel_config.sequence_parallel_size would be sufficient. I don't see why you need to set get_forward_context().sequence_parallel_size here.
I find that LongCatImageTransformer2DModel also edits get_forward_context...
```python
od_config = get_forward_context().omni_diffusion_config
parallel_config = od_config.parallel_config
sequence_parallel_size = parallel_config.sequence_parallel_size
```
This would be my recommendation.
Thanks, I kept only `sp_size = self.parallel_config.sequence_parallel_size`, which matches what you suggested.
```python
    f"rank={sp_rank}, original_shape={original_shape}, chunked_shape={hidden_states.shape}"
)
else:
    if not hasattr(self, "_sp_forward_logged"):
```
Is self._sp_forward_logged used for debugging only?
yes, I'm removing it now.
Force-pushed 0b50742 to b33ce10
```python
sp_size = self.parallel_config.sequence_parallel_size
sp_pad_size = 0
if sp_size > 1:
    hidden_states, sp_pad_size = sp_shard_with_padding(hidden_states, dim=1)
```
I am a little concerned about the padding behavior here. It looks like flux2_klein does not support hidden_states_mask (unlike qwen_image); therefore, the padded tokens will participate in the attention computation, which is wrong.
If using the non-intrusive _sp_plan, sp_padding_size will be set automatically in get_forward_context(); however, hidden_states_mask is required to exclude the padded tokens.
```python
# vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py
hidden_states_mask = None  # default
if self.parallel_config is not None and self.parallel_config.sequence_parallel_size > 1:
    ctx = get_forward_context()
    if ctx.sp_original_seq_len is not None and ctx.sp_padding_size > 0:
        # Create mask for the full (padded) sequence:
        # valid positions = True, padding positions = False
        batch_size = hidden_states.shape[0]
        padded_seq_len = ctx.sp_original_seq_len + ctx.sp_padding_size
        hidden_states_mask = torch.ones(
            batch_size,
            padded_seq_len,
            dtype=torch.bool,
            device=hidden_states.device,
        )
        hidden_states_mask[:, ctx.sp_original_seq_len:] = False
```
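To make the effect of the mask concrete, here is a hedged sketch of how a boolean `hidden_states_mask` is typically turned into an additive attention bias so padded tokens receive zero attention weight (illustrative only; the actual attention path in vllm_omni may differ):

```python
import torch

def mask_to_attn_bias(hidden_states_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # (batch, seq) bool -> (batch, 1, 1, seq) additive bias:
    # 0 for valid key positions, -inf for padded ones.
    bias = torch.zeros(hidden_states_mask.shape, dtype=dtype)
    bias.masked_fill_(~hidden_states_mask, float("-inf"))
    return bias[:, None, None, :]

mask = torch.ones(2, 12, dtype=torch.bool)
mask[:, 10:] = False                     # last 2 positions are padding
bias = mask_to_attn_bias(mask)
scores = torch.zeros(2, 1, 12, 12) + bias
probs = scores.softmax(dim=-1)
assert torch.all(probs[..., 10:] == 0)   # padded keys get zero attention weight
```

Without this bias, the softmax spreads probability mass onto the padded positions, which is exactly the correctness issue raised above.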
@RuixiangMa I would like to invite you to join the discussion here #1324.
I have two suggestions:
- Refer to [Feature] Support Wan2.2 output with irregular shapes #1279; you may use `auto_pad=True` and use `hidden_states_mask` to exclude the padded tokens.
- You may set `SequenceParallelInput(auto_pad=False)`; this will raise an error when `seq_len` is not divisible by `sp_size`, which I think might be rare for Flux.2-klein. We can take care of rare cases in the future, after we make our discussion clear in [RFC]: Ulysses-SP Constraints Solution #1324.
@vllm-omni-reviewer
Signed-off-by: Lancer <maruixiang6688@gmail.com>
Do you have a plan to support SP with
Sorry, just submitted a revision
Force-pushed 823dfcf to 8a7e687
@wtomin Can you take a look?
```python
sin = sin.to(query.dtype)
query = self.rope(query, cos, sin)
key = self.rope(key, cos, sin)
txt_cos, img_cos = cos[:text_seq_len], cos[text_seq_len:]
```
Are these added lines (397 to 410) independent of any diffusion model, and can they be extracted?
I've extracted the RoPE logic into a reusable helper function (only verified on flux1, flux2, and z-image so far, but the others, e.g. ovis-image, should be able to reuse it as well). The SP-specific logic in flux2_klein (lines 393-410) remains unchanged, as it handles model-specific sequence-parallel splitting.
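As a rough illustration of what such an extracted helper could look like (the name, signature, and identity `rope_fn` below are assumptions for the sketch, not the actual code): apply RoPE once to the concatenated text+image sequence, then slice the shared cos/sin tables into text and image parts for model-specific handling.

```python
import torch

def apply_rotary_and_split(query, key, cos, sin, text_seq_len, rope_fn):
    # Rotate the full (text + image) sequence once, then split the
    # cos/sin tables so callers can treat the two streams separately.
    query = rope_fn(query, cos, sin)
    key = rope_fn(key, cos, sin)
    txt_cos, img_cos = cos[:text_seq_len], cos[text_seq_len:]
    txt_sin, img_sin = sin[:text_seq_len], sin[text_seq_len:]
    return query, key, (txt_cos, txt_sin), (img_cos, img_sin)

# identity rope_fn keeps the sketch self-contained and runnable
identity = lambda x, cos, sin: x
q = k = torch.randn(1, 4, 20, 64)
cos = sin = torch.randn(20, 32)
q2, k2, (tc, ts), (ic, isn) = apply_rotary_and_split(q, k, cos, sin, text_seq_len=8, rope_fn=identity)
# tc covers the 8 text positions, ic the remaining 12 image positions
```

Keeping the rotation behind a pluggable `rope_fn` is one way the helper could stay model-agnostic while the split sizes stay model-specific.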
```python
hidden_states_mask = None
ctx = get_forward_context()
if ctx.sp_active:
```
I think ctx.sp_active should not be exposed to users or developers.
In the _sp_plan hooks design, ctx.sp_active is determined by _sp_shard_depth, and _sp_shard_depth is handled automatically by the hook function. Therefore, there is no need to check ctx.sp_active here.
It was redundant; I removed it.
I would recommend you to use the non-intrusive `_sp_plan`. Taking PR #1772 as an example, one major reason intrusive modification is not good is that it affects the teacache extractor function, making it over-complicated.
Have you run offline inference with an irregular shape that involves auto padding and
Both irregular and offload tests passed.
Purpose
support SP (Ulysses & Ring) for FLUX.2-klein
Test Plan
Test Result
```shell
curl -s -X POST "http://localhost:8004/v1/images/edits" -F "image=@test.jpg" -F "prompt=Change the sky to orange sunset." -F "guidance_scale=1.0" -F "num_inference_steps=50" -F "n=1" -F "size=1024x1024" -F "output_format=png" | jq -r '.data[0].b64_json' | base64 --decode > output.png
```