[Feature] add Tensor Parallelism to SD_3.5 #1336

Merged
ZJY0516 merged 6 commits into vllm-project:main from GG-li:feature/sd-tp
Feb 12, 2026

Conversation

@GG-li
Contributor

@GG-li GG-li commented Feb 11, 2026

Purpose

Feature for #1217.
Add Tensor Parallelism to Stable Diffusion-3.5.

Test Plan

python text_to_image.py --model stabilityai/stable-diffusion-3.5-medium --prompt "a cup of coffee on the table" --negative_prompt "ugly, unclear" --cfg_scale 4.0 --num_inference_steps 50 --output "tp_enabled.png" --tensor_parallel_size 4

Test Result

Image size: 1024x1024

| TP Size | Time        | Generated Image |
|---------|-------------|-----------------|
| TP=1    | 5400.104 ms | Result 1        |
| TP=2    | 2733.772 ms | Result 2        |
| TP=4    | 2286.651 ms | Result 4        |


Signed-off-by: GG-li <3226868735@qq.com>
@GG-li GG-li requested a review from hsliuustc0106 as a code owner February 11, 2026 12:23

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: a2c2e8ab32


@GG-li
Contributor Author

GG-li commented Feb 11, 2026

@wtomin @Bounty-hunter Please take a look

Signed-off-by: GG-li <3226868735@qq.com>
self.net = nn.ModuleList([])
self.net.append(act_fn)
self.net.append(nn.Dropout(dropout))
self.net.append(RowParallelLinear(inner_dim, dim_out, bias=bias))
Contributor

Is input_is_parallel=True required for RowParallelLinear?

Contributor Author

input_is_parallel is True by default. I've fixed the implementation.
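For background on why `input_is_parallel` matters: a row-parallel linear shards the weight along the input dimension, each rank multiplies its input shard by its weight shard, and the partial products are summed (the all-reduce step). A dependency-free sketch of that math (pure Python, not vLLM's implementation; all names are illustrative):

```python
# Hedged sketch: row-parallel linear splits the weight along the input
# dimension. Each rank computes a partial matvec on its shard; summing the
# partials (the "all-reduce") recovers the full result, which is why the
# layer can accept an already-split input (input_is_parallel=True).
def matvec(W, x):
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in W]

W = [[1, 2, 3, 4], [5, 6, 7, 8]]   # full 2x4 weight
x = [1, 1, 1, 1]                   # full input vector
full = matvec(W, x)                # reference result: [10, 26]

# Shard along the input dim for tp=2: rank 0 gets cols 0-1, rank 1 gets cols 2-3.
W0, W1 = [row[:2] for row in W], [row[2:] for row in W]
x0, x1 = x[:2], x[2:]              # input arrives pre-split per rank
partial = [a + b for a, b in zip(matvec(W0, x0), matvec(W1, x1))]
assert partial == full
```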

# Compute QKV for text stream (context projections)
qkv, _ = self.add_kv_proj(encoder_hidden_states)
txt_query, txt_key, txt_value = qkv.chunk(3, dim=-1)
qkv_add, _ = self.add_kv_proj(encoder_hidden_states)
Contributor

self.add_kv_proj may not be initialized if added_kv_proj_dim is None.

A check needs to run before calling self.add_kv_proj.

Contributor Author

fixed
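A minimal sketch of the guard being requested (hypothetical class and method names, not the actual vllm-omni code; the lambda stands in for a QKV projection that returns an (output, bias) tuple the way vLLM's parallel layers do):

```python
class JointAttentionSketch:
    """Hypothetical stand-in: add_kv_proj exists only when the text stream
    has its own projection (added_kv_proj_dim is not None)."""

    def __init__(self, added_kv_proj_dim=None):
        if added_kv_proj_dim is not None:
            # Stand-in for a QKV projection; returns (output, bias).
            self.add_kv_proj = lambda h: (h, None)
        else:
            self.add_kv_proj = None

    def project_context(self, encoder_hidden_states):
        # Guard before calling: the context projection is optional.
        if self.add_kv_proj is None:
            return None
        qkv, _ = self.add_kv_proj(encoder_hidden_states)
        return qkv
```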

else:
self.to_out = None

self.norm_added_q = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity()
Contributor

Can you check whether you can replace diffusers' RMSNorm with vLLM's:

from vllm.model_executor.layers.layernorm import RMSNorm

I guess vLLM's RMSNorm is better, and it is commonly used in vllm-omni.

Contributor Author

fixed
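For reference, both implementations compute the same operation: RMSNorm scales each element by the root mean square of the vector, without subtracting the mean as LayerNorm does. A plain-Python sketch of the math (not either library's implementation):

```python
import math

def rms_norm(x, weight, eps=1e-6):
    # RMSNorm: x_i / sqrt(mean(x^2) + eps) * weight_i  (no mean centering)
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]
```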

if final_dropout:
self.net.append(nn.Dropout(dropout))

def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
Contributor

Maybe you can rewrite it as follows:
for layer in self.net:
    output = layer(hidden_states)
    hidden_states = output[0] if isinstance(output, tuple) else output
return hidden_states

Contributor Author

fixed
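The suggested loop works because vLLM's parallel linear layers can return an (output, bias) tuple while plain nn modules return a single tensor; unwrapping uniformly keeps the Sequential-style loop intact. A self-contained sketch (hypothetical helper name, plain callables standing in for layers):

```python
def run_net(net, hidden_states):
    # Layers may return either a plain value or an (output, bias) tuple
    # (as vLLM's RowParallelLinear does); unwrap both uniformly.
    for layer in net:
        output = layer(hidden_states)
        hidden_states = output[0] if isinstance(output, tuple) else output
    return hidden_states
```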

total_num_heads=num_heads,
disable_tp=True,
bias=True,
return_bias=True,
Contributor

Please check:
The returned bias is not used afterwards, which could easily lead to misunderstanding.
If skip_bias_add is False, the returned bias is None, so there is no need to return it.

Contributor Author

fixed
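A toy illustration of the skip_bias_add/return_bias contract under discussion (scalar stand-in, not the real layer): when the bias is applied inside the layer, the returned bias slot is None, so propagating it to callers is pointless.

```python
def parallel_linear(x, w, b, skip_bias_add=False):
    # Toy scalar model of the contract: y = x * w (+ b).
    y = x * w
    if skip_bias_add:
        # Bias is deferred: the caller is expected to fuse it later.
        return y, b
    # Bias already applied, so the returned bias slot is None.
    return y + b, None
```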

@wtomin
Contributor

wtomin commented Feb 12, 2026

@GG-li Can you also help support the online serving text-to-image script with tensor parallelism? So far the online serving script does not have an argument for the tensor parallel size.

GG-li and others added 2 commits February 12, 2026 14:21
Signed-off-by: GG-li <3226868735@qq.com>
Contributor

@wtomin wtomin left a comment

LGTM.

@wtomin
Contributor

wtomin commented Feb 12, 2026

@hsliuustc0106 @ZJY0516 @SamitHuang I think it is ready to merge.

Collaborator

@ZJY0516 ZJY0516 left a comment

LGTM

@ZJY0516 ZJY0516 added the "ready label to trigger buildkite CI" label Feb 12, 2026
@ZJY0516 ZJY0516 enabled auto-merge (squash) February 12, 2026 06:54
@ZJY0516 ZJY0516 merged commit 71c9e77 into vllm-project:main Feb 12, 2026
6 of 7 checks passed
YanickSchraner pushed a commit to YanickSchraner/vllm-omni that referenced this pull request Feb 20, 2026