[Perf] Optimize the Qwen2.5-Omni Model thinker-to-talker-proj with nn.Linear#825

Merged
hsliuustc0106 merged 1 commit into vllm-project:main from kechengliu97:Qwen2.5-Optimize
Jan 17, 2026

Conversation

@kechengliu97
Contributor

@kechengliu97 kechengliu97 commented Jan 17, 2026

Summary of Changes

In this submission, we replaced the thinker_to_talker_proj layer, swapping the original ColumnParallelLinear for a plain nn.Linear. This change yields a significant performance improvement in the forward pass of the Qwen2_5OmniTalkerForConditionalGeneration model, reducing overall execution time by more than 200 microseconds.
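As a hedged sketch of the change (plain PyTorch, not the actual vLLM code; `TalkerProjSketch` and the dimensions are illustrative, not the model's real ones):

```python
import torch
import torch.nn as nn


class TalkerProjSketch(nn.Module):
    """Minimal stand-in for the thinker-to-talker projection."""

    def __init__(self, embedding_size: int, hidden_size: int) -> None:
        super().__init__()
        # Previously a vLLM ColumnParallelLinear; now a plain nn.Linear.
        self.thinker_to_talker_proj = nn.Linear(embedding_size, hidden_size, bias=True)

    def forward(self, thinker_hidden: torch.Tensor) -> torch.Tensor:
        # nn.Linear returns the output tensor directly, whereas vLLM's
        # parallel linear layers return an (output, output_bias) tuple,
        # so the caller no longer needs to unpack a tuple.
        return self.thinker_to_talker_proj(thinker_hidden)


proj = TalkerProjSketch(embedding_size=3584, hidden_size=896)
out = proj(torch.randn(2, 3584))
print(tuple(out.shape))  # (2, 896)
```

The simpler output signature is also why the forward pass had to be updated alongside the layer swap.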

Performance Improvements

  • thinker_to_talker_proj latency: reduced from 283.39 µs to 223.84 µs, a decrease of approximately 21%.
  • Overall forward time: total execution time decreased by more than 200 µs, noticeably improving inference latency.

Configuration Comparison & Reference

Comparing against the Hugging Face Transformers implementation, we observed that the corresponding linear layer there also uses nn.Linear. The updated implementation therefore aligns with the mainstream architecture, improving compatibility and leaving room for further optimization.

Precision Verification

Using the same prompt, "What is the origin of the United States?", we obtained matching outputs: the generated text was identical and the generated wav files were almost identical, indicating that accuracy is not affected by this change.

The output text (wav file read the following paragraph):

Prompt:
<|im_start|>system
You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.<|im_end|>
<|im_start|>user
What is the origin of the United States?<|im_end|>
<|im_start|>assistant

vllm_text_output:
Well, you know, the United States has its roots in European colonization. The first Europeans to settle there were mainly from England. They started colonizing the eastern part of North America in the 17th century. People like the Pilgrims landed at Plymouth in 1620. Over time, these colonies grew and developed their own governments. Eventually, they declared independence from Britain in 1776. And then they formed the United States as we know it today. If you want to know more about specific events during that colonial period or how the states were formed, just let me know.

Conclusion

By replacing ColumnParallelLinear with nn.Linear, we improve performance while staying consistent with the Transformers implementation. This optimization delivers lower latency and higher throughput across inference tasks, with no loss in precision.


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 308583dda7


Comment on lines +71 to 74
        self.thinker_to_talker_proj = nn.Linear(
            self.config.embedding_size,
            self.config.hidden_size,
            bias=True,
        )


P2: Keep tensor-parallel/quantized linear for projector

Replacing ColumnParallelLinear with nn.Linear drops both tensor-parallel sharding and quant_config handling. In runs with tensor_parallel_size > 1 or quantized checkpoints, this projector will now instantiate full‑precision weights on every rank (no partitioning/quantized kernels), which can inflate memory or prevent quantized weights from loading correctly. Consider keeping the parallel/quantized linear wrapper or providing an equivalent vLLM-aware linear here.
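To make the sharding concern concrete, here is a toy, dependency-free illustration (not vLLM's actual implementation) of what a column-parallel linear does: each "rank" holds a slice of the weight columns, computes a partial output, and the slices are gathered back together (an all-gather in practice). All names below are illustrative.

```python
def matvec(weight_cols, x):
    """y_j = sum_i x_i * W[i][j] for the given list of weight columns."""
    return [sum(x[i] * col[i] for i in range(len(x))) for col in weight_cols]


def column_parallel_forward(weight_cols, x, tp_size):
    """Toy column-parallel forward: shard output columns across tp_size ranks."""
    per_rank = len(weight_cols) // tp_size
    shards = [weight_cols[r * per_rank:(r + 1) * per_rank] for r in range(tp_size)]
    # Each rank computes only its slice of the output...
    partials = [matvec(shard, x) for shard in shards]
    # ...and gather_output=True concatenates the slices back together.
    return [v for part in partials for v in part]


# Weight stored column-major: 4 output columns over a 3-dim input.
W_cols = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 1]]
x = [2.0, 3.0, 4.0]
print(column_parallel_forward(W_cols, x, tp_size=2))  # [2.0, 3.0, 4.0, 9.0]
```

A plain nn.Linear keeps the full weight on every rank, which is the memory/quantization concern raised above when tensor_parallel_size > 1.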


Contributor Author

@kechengliu97 kechengliu97 Jan 17, 2026


Also, Qwen2.5-Omni-7B is a small model that almost any single NPU/GPU can hold, so multi-device sharding is rarely needed here.

@hsliuustc0106
Collaborator

any accuracy difference?

@hsliuustc0106 hsliuustc0106 added the ready label to trigger buildkite CI label Jan 17, 2026
@kechengliu97
Contributor Author

any accuracy difference?

With a pure text prompt as input, no difference has been observed so far. More importantly, this matches what is used in the corresponding part of the Transformers implementation.

Swapped out the custom ColumnParallelLinear layer for a standard nn.Linear in Qwen2_5OmniTalkerForConditionalGeneration. Updated the forward pass to match the new layer's output signature, simplifying the projection step.

Signed-off-by: John Liu BUAA <[email protected]>
Collaborator

@hsliuustc0106 hsliuustc0106 left a comment


lgtm

@hsliuustc0106 hsliuustc0106 merged commit bb24e07 into vllm-project:main Jan 17, 2026
7 checks passed
erfgss pushed a commit to erfgss/vllm-omni that referenced this pull request Jan 19, 2026
@ZJY0516
Collaborator

ZJY0516 commented Jan 19, 2026

@kechengliu97 Do you have any idea where the performance improvement comes from? ColumnParallelLinear actually calls the same nn.Linear kernel internally.

@kechengliu97
Contributor Author

@kechengliu97 Do you have any idea where the performance improvement comes from? ColumnParallelLinear actually calls the same nn.Linear kernel internally.

Our device cannot export the profiling timeline to the public Internet, but when using layers like ColumnParallelLinear or RowParallelLinear we observed some communication cost in the profiles, which may account for the higher latency.
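One way to follow up on this question locally is a small microbenchmark. The sketch below (plain PyTorch, no vLLM; shapes are illustrative, not the model's real dimensions) times the bare nn.Linear path, so its average latency can be compared against the wrapped parallel layer on the same device.

```python
import time

import torch
import torch.nn as nn

# Illustrative shapes only; substitute the real projector dimensions.
proj = nn.Linear(1024, 1024, bias=True)
x = torch.randn(8, 1024)

with torch.inference_mode():
    for _ in range(10):  # warm-up iterations
        proj(x)
    iters = 100
    start = time.perf_counter()
    for _ in range(iters):
        out = proj(x)
    avg_us = (time.perf_counter() - start) / iters * 1e6

print(f"avg nn.Linear forward: {avg_us:.1f} us")
```

On a GPU/NPU one would additionally synchronize the device before reading the clock; CPU timing here is only meant to show the measurement pattern.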

with1015 pushed a commit to with1015/vllm-omni that referenced this pull request Jan 20, 2026