
Conversation

@imkero (Contributor) commented Mar 12, 2025

This PR optimizes the tensor generation performance of the Qwen2/2.5-VL ViT (including rot_pos_ids, window_indices, cu_seqlens, and seqlens) by introducing optimized numba / torch implementations.

What this PR does

  1. Keep image_grid_thw and video_grid_thw on the CPU at all times
    • This avoids many Device-to-Host memcpys and CUDA synchronizes; see the profiling results below
  2. Make numba a common dependency (until now it was only required by CUDA / ROCm builds)
    • Required by the following optimizations
  3. Optimize Qwen2/2.5-VL's rot_pos_ids generation by rewriting it as two implementations: a numba version (for the CPU backend) and a torch version (for other backends, e.g. CUDA); a rough sketch of the numba approach follows this list
    • Both run faster than the original torch impl
  4. Optimize Qwen2.5-VL's window_indices generation by rewriting it with numba and torch together
    • Runs faster than the original torch impl
  5. Add tests to ensure the optimized versions produce the same results as the original impl
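
As a rough illustration of item 3, the numba path could look like the sketch below. This is a minimal sketch, not the PR's actual code: the function name is illustrative, and it assumes the same merge_size tiling order as the original torch implementation.

```python
import numba
import numpy as np

@numba.jit(nopython=True)
def rot_pos_ids_numba(grid_thw: np.ndarray, merge_size: int) -> np.ndarray:
    # Total patch count across all images / video frames.
    total = 0
    for i in range(grid_thw.shape[0]):
        total += grid_thw[i, 0] * grid_thw[i, 1] * grid_thw[i, 2]
    out = np.empty((total, 2), dtype=np.int64)

    idx = 0
    for i in range(grid_thw.shape[0]):
        t, h, w = grid_thw[i, 0], grid_thw[i, 1], grid_thw[i, 2]
        for _ in range(t):
            # Emit (h, w) positions tile-by-tile in merge_size x merge_size
            # blocks, matching the reshape/permute order of the torch impl.
            for bh in range(0, h, merge_size):
                for bw in range(0, w, merge_size):
                    for dh in range(merge_size):
                        for dw in range(merge_size):
                            out[idx, 0] = bh + dh
                            out[idx, 1] = bw + dw
                            idx += 1
    return out
```

For example, rot_pos_ids_numba(np.array([[1, 8, 8]], dtype=np.int64), 2) yields a (64, 2) array with one (h, w) position pair per ViT patch.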

Benchmark and profiling

Tested on an NVIDIA A10 with an Intel(R) Xeon(R) Platinum 8372HC CPU @ 3.40GHz (12 cores), using vLLM V1 and Flash Attention 2

Profiling result

main branch Qwen2.5-VL ViT

| grid_thw | tokens num | cudaStreamSynchronize | Memcpy DtoH | Memcpy HtoD | Memcpy HtoD Tx |
| --- | --- | --- | --- | --- | --- |
| [[1, 36, 36]] | 324 | 33 times | 28 times | 5 times | 28616 bytes |
| [[10, 36, 36]] | 3240 | 222 times | 217 times | 5 times | 286124 bytes |

this PR Qwen2.5-VL ViT

| grid_thw | tokens num | cudaStreamSynchronize | Memcpy DtoH | Memcpy HtoD | Memcpy HtoD Tx |
| --- | --- | --- | --- | --- | --- |
| [[1, 36, 36]] | 324 | 0 times | 0 times | 1 time | 2816 bytes |
| [[10, 36, 36]] | 3240 | 0 times | 0 times | 1 time | 28016 bytes |

Piecewise benchmark

Generating rot_pos_ids for Qwen2/2.5-VL (GPU)

| grid_thw | tokens num | main branch | this PR (torch) | numba then move to GPU (currently not used) |
| --- | --- | --- | --- | --- |
| [[1, 8, 8]] | 16 | 0.464ms | 0.270ms | 0.163ms |
| [[1, 36, 36]] | 324 | 0.490ms | 0.274ms | 0.170ms |
| [[10, 36, 36]] | 3240 | 0.507ms | 0.274ms | 0.170ms |
| [[10, 36, 36]] * 10 | 32400 | 3.132ms | 0.277ms | 1.134ms |
| [[10, 36 ± 2, 36 ± 2]] * 10 (different frame sizes handled separately) | 28200 | 3.028ms | 0.432ms | 0.982ms |

NOTE:

  • For "this PR (torch)", much of the time goes to allocating the output tensor via torch.empty; maybe we could auto-tune by using the numba impl for smaller batch sizes (a rough sketch of that idea follows below)?
  • I also tested torch.compile on it and wrote a similar Triton JIT kernel; neither seems faster than the optimized torch impl in this PR.
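
A minimal sketch of what such auto-tuning might look like, assuming the rot_pos_ids_numba sketch above and a hypothetical rot_pos_ids_torch vectorized implementation; the cutoff value is an illustrative guess, not a measured number:

```python
import torch

# Hypothetical cutoff; it would need per-hardware tuning using a
# piecewise benchmark like the one above.
_NUMBA_CUTOFF_PATCHES = 4096

def rot_pos_ids_autotuned(grid_thw: torch.Tensor, merge_size: int,
                          device: torch.device) -> torch.Tensor:
    # grid_thw stays on the CPU (item 1 of the description), so this
    # reduction triggers no device synchronization.
    num_patches = int((grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).sum())
    if num_patches < _NUMBA_CUTOFF_PATCHES:
        # Small inputs: generate on the CPU with numba, then copy once.
        pos = rot_pos_ids_numba(grid_thw.numpy(), merge_size)
        return torch.from_numpy(pos).to(device)
    # Large inputs: the vectorized torch path (hypothetical placeholder)
    # amortizes the torch.empty allocation cost better.
    return rot_pos_ids_torch(grid_thw, merge_size, device)
```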

Generating rot_pos_ids for Qwen2/2.5-VL (CPU)

| grid_thw | tokens num | main branch | this PR (numba) |
| --- | --- | --- | --- |
| [[1, 8, 8]] | 16 | 0.151ms | 0.025ms |
| [[1, 36, 36]] | 324 | 0.175ms | 0.025ms |
| [[10, 36, 36]] | 3240 | 0.196ms | 0.033ms |
| [[10, 36, 36]] * 10 | 32400 | 1.584ms | 0.192ms |
| [[10, 36 ± 2, 36 ± 2]] * 10 (different frame sizes handled separately) | 28200 | 1.505ms | 0.714ms |

Generating window_indices and related tensors for Qwen2.5-VL (GPU)

| grid_thw | tokens num | main branch | this PR (numba + torch) |
| --- | --- | --- | --- |
| [[1, 8, 8]] | 16 | 1.107ms | 0.326ms |
| [[1, 36, 36]] | 324 | 1.130ms | 0.328ms |
| [[10, 36, 36]] | 3240 | 1.176ms | 0.344ms |
| [[10, 36, 36]] * 10 | 32400 | 7.677ms | 0.485ms |
| [[10, 36 ± 2, 36 ± 2]] * 10 (different frame sizes handled separately) | 28200 | 6.832ms | 0.450ms |
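
For reference, the cu_seqlens mentioned in the description follow the usual cumulative-sum pattern over per-frame sequence lengths. A minimal torch sketch for the full-attention case (illustrative, not the PR's code):

```python
import torch

def full_attn_cu_seqlens(grid_thw: torch.Tensor) -> torch.Tensor:
    # Each frame of an item with grid (t, h, w) contributes h * w patches
    # and forms one attention block, so repeat h * w a total of t times.
    seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                      grid_thw[:, 0])
    # Prefix with 0 so that block i spans cu_seqlens[i]:cu_seqlens[i + 1].
    return torch.nn.functional.pad(seqlens.cumsum(0, dtype=torch.int32),
                                   (1, 0))
```

For example, full_attn_cu_seqlens(torch.tensor([[2, 8, 8]])) returns tensor([0, 64, 128], dtype=torch.int32).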

ViT e2e benchmark

Qwen2-VL ViT (GPU)

| grid_thw | tokens num | main branch | this PR |
| --- | --- | --- | --- |
| [[1, 36, 36]] | 324 | 54.83ms | 53.87ms |
| [[10, 36, 36]] | 3240 | 484.79ms | 481.14ms |

Qwen2.5-VL ViT (GPU)

| grid_thw | tokens num | main branch | this PR |
| --- | --- | --- | --- |
| [[1, 36, 36]] | 324 | 48.28ms | 46.72ms |
| [[10, 36, 36]] | 3240 | 454.65ms | 446.90ms |

Benchmark script: https://gist.github.com/imkero/590b31e500443a31a07386ae8539e9d8

@imkero imkero marked this pull request as draft March 12, 2025 15:57
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the multi-modality Related to multi-modality (#4194) label Mar 12, 2025
@mergify mergify bot added the v1 label Mar 13, 2025
@imkero imkero changed the title [Misc] Optimize Qwen2-VL's M-RoPE pos calc using numba [Perf] Optimize Qwen2/2.5-VL/Omni series' rot pos compute using numba Mar 22, 2025
@imkero imkero force-pushed the feat/qwen2-vl-sched-perf branch from 18aa10d to 8ceb5c9 April 5, 2025 15:56
@imkero imkero changed the title [Perf] Optimize Qwen2/2.5-VL/Omni series' rot pos compute using numba [Perf] Optimize Qwen2/2.5-VL series' rot pos & attn performance Apr 5, 2025
@NickLucche (Collaborator) commented

Nice work!
I wonder whether it would've been easier to land this contribution in two separate PRs, one with the CUDA implementation and a follow-up with the numba one for CPU.

@imkero imkero force-pushed the feat/qwen2-vl-sched-perf branch from 97dea30 to 9159e8c April 18, 2025 15:29
@mergify mergify bot added the ci/build label Apr 18, 2025
@imkero imkero changed the title [Perf] Optimize Qwen2/2.5-VL series' rot pos & attn performance [Perf] Optimize Qwen2/2.5-VL tensor generating performance Apr 18, 2025
@imkero (Contributor, Author) commented Apr 18, 2025

> Nice work! I wonder whether it would've been easier to land this contribution in two separate PRs, one with the CUDA implementation and a follow-up with the numba one for CPU.

I tried all the approaches I know of (numba only, torch + numba, triton) for this PR to find the most efficient one, so I would like to bring the most efficient solutions I found for both the CPU and CUDA backends here. And it's quite simple: just add a dispatch function in front of them (see the sketch below).
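
Roughly, that dispatch could look like the following sketch (names illustrative, reusing the hypothetical implementations from the sketches above):

```python
import torch

def dispatch_rot_pos_ids(grid_thw: torch.Tensor, merge_size: int,
                         device: torch.device) -> torch.Tensor:
    if device.type == "cpu":
        # CPU backend: the numba-jitted loop is fastest.
        return torch.from_numpy(
            rot_pos_ids_numba(grid_thw.numpy(), merge_size))
    # CUDA and other backends: the vectorized torch implementation
    # (hypothetical placeholder here).
    return rot_pos_ids_torch(grid_thw, merge_size, device)
```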

@imkero imkero changed the title [Perf] Optimize Qwen2/2.5-VL tensor generating performance [Perf] Optimize Qwen2/2.5-VL ViT tensor generating performance Apr 18, 2025
@imkero imkero marked this pull request as ready for review April 18, 2025 16:34
@mergify mergify bot commented Jun 11, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @imkero.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added needs-rebase rocm Related to AMD ROCm labels Jun 11, 2025
@mergify mergify bot added the qwen Related to Qwen models label Jun 19, 2025
@imkero imkero closed this Jul 4, 2025