
Conversation

@andylolu2 (Contributor) commented Sep 30, 2025

Purpose

Currently, enabling LoRA (with CUDA graphs, which are necessary for reasonable speed) adds overhead to the normal inference path, even if there are no active LoRA adapters. This is because we currently only capture CUDA graphs with the LoRA operations included.

In this PR, I make some small changes to capture a different set of CUDA graphs for when there are no active LoRA adapters, so we get exactly the same speed as normal inference when there are no active LoRA requests.

Implementation

  • Added a new has_lora attribute to BatchDescriptor.
  • Capture two sets of CUDA graphs during graph capture: one with LoRA ops and one without.
  • At runtime, dispatch to the graphs with or without LoRA ops based on len(self.input_batch.lora_id_to_lora_request) > 0 (see the sketch after this list).
  • Move the .zero() of the intermediate LoRA buffer inside lora_shrink, so it is skipped when there are no active LoRAs.
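
A minimal sketch of the capture/dispatch idea described above, with simplified names and illustrative capture sizes (the real changes touch BatchDescriptor, the cudagraph dispatcher, and the GPU model runner, and differ in detail):

    from dataclasses import dataclass
    from itertools import product

    @dataclass(frozen=True)
    class BatchDescriptor:
        num_tokens: int
        uniform_decode: bool = False
        has_lora: bool = False  # new field: selects the LoRA vs. non-LoRA graph set

    # Capture time: one graph per (batch size, has_lora) combination.
    capture_keys = [
        BatchDescriptor(num_tokens=bs, has_lora=has_lora)
        for bs, has_lora in product([1, 2, 4, 8], [True, False])
    ]

    # Runtime: dispatch to the LoRA-specialized graphs only when adapters are active.
    def descriptor_for(num_tokens: int, lora_id_to_lora_request: dict) -> BatchDescriptor:
        return BatchDescriptor(
            num_tokens=num_tokens,
            has_lora=len(lora_id_to_lora_request) > 0,
        )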

Test Plan

Show that LoRA still works functionally, but has zero overhead when there are no active LoRAs.

Test Result

--enable-lora overhead reduced from 10.5% to 1.4%. I compared the kernels launched and they are identical to the --no-enable-lora case when there are no active LoRAs, so I suspect the remaining 1.4% overhead is just from additional CPU-side logic (see the quick calculation after the numbers below).

Baseline

$ vllm bench latency --model meta-llama/Llama-2-7b-hf
Avg latency: 0.8444958463311195 seconds
10% percentile latency: 0.8407110057771205 seconds
25% percentile latency: 0.8423566690180451 seconds
50% percentile latency: 0.8437296429183334 seconds
75% percentile latency: 0.8471306945430115 seconds
90% percentile latency: 0.8484105261275545 seconds
99% percentile latency: 0.851516973322723 seconds

Before PR

$ vllm bench latency --model meta-llama/Llama-2-7b-hf --enable-lora
Avg latency: 0.9335442642603691 seconds
10% percentile latency: 0.9302553807385265 seconds
25% percentile latency: 0.9311426649801433 seconds
50% percentile latency: 0.9327256239484996 seconds
75% percentile latency: 0.9367578914971091 seconds
90% percentile latency: 0.9377115628449246 seconds
99% percentile latency: 0.9390150624723174 seconds

After PR

$ vllm bench latency --model meta-llama/Llama-2-7b-hf --enable-lora
Avg latency: 0.856378769610698 seconds
10% percentile latency: 0.8511904217069969 seconds
25% percentile latency: 0.8531599651905708 seconds
50% percentile latency: 0.8564562875544652 seconds
75% percentile latency: 0.8592897556954995 seconds
90% percentile latency: 0.8603642771951854 seconds
99% percentile latency: 0.8627826083195396 seconds
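
For reference, the overhead figures quoted in the test result follow directly from the average latencies above (a quick check, not part of the PR):

    baseline = 0.8444958463311195   # --no-enable-lora
    before_pr = 0.9335442642603691  # --enable-lora, before this PR
    after_pr = 0.856378769610698    # --enable-lora, after this PR

    print(f"overhead before PR: {(before_pr / baseline - 1) * 100:.1f}%")  # ~10.5%
    print(f"overhead after PR:  {(after_pr / baseline - 1) * 100:.1f}%")   # ~1.4%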

Functionality test

pytest -svx tests/lora/test_llama_tp.py

still passes (it uses CUDA graphs).



@andylolu2 andylolu2 changed the title Specialized LoRA cuda graph [Feature] Specialized LoRA cuda graph Sep 30, 2025
@mergify mergify bot added the v1 label Sep 30, 2025
mergify bot commented Sep 30, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @andylolu2.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 30, 2025
@andylolu2 andylolu2 force-pushed the andy/lora-cuda-graphs-specialization branch from a0ed199 to 00d1328 on October 1, 2025 02:36
@mergify mergify bot removed the needs-rebase label Oct 1, 2025
@andylolu2 andylolu2 changed the title [Feature] Specialized LoRA cuda graph [Feature] LoRA cuda graph specialization Oct 1, 2025
@andylolu2 andylolu2 marked this pull request as ready for review October 1, 2025 02:41
@andylolu2 (Contributor, Author) commented:

/gemini review

@andylolu2 (Contributor, Author) commented:

@gemini-code-assist

gemini-code-assist bot left a comment:

Code Review

This pull request introduces a specialization for CUDA graphs with LoRA to optimize performance. The key idea is to capture two distinct sets of CUDA graphs: one for batches with active LoRA adapters and another for batches without. This avoids the overhead associated with LoRA operations when no LoRA adapters are in use. The changes include adding a has_lora attribute to BatchDescriptor for dispatching, modifying the CUDA graph capture logic to handle both scenarios, and moving the zeroing of the intermediate LoRA buffer to be conditional on LoRA activation. The implementation appears solid and the provided test results demonstrate a significant reduction in overhead, from 10.5% down to 1.4%, when LoRA is enabled but not active. The code changes are consistent and correctly implement the intended optimization. I have not found any critical or high-severity issues.

gemini-code-assist bot commented:

Thank you for submitting this pull request, @andylolu2! I've reviewed the changes and appreciate the clear problem statement and detailed implementation notes.

This feature significantly improves efficiency by specializing CUDA graphs for LoRA and non-LoRA scenarios. The reported reduction in overhead from 10.5% to 1.4% is a substantial gain, directly addressing the stated purpose of minimizing overhead when LoRA is enabled but not actively used.

Key observations from the review:

  • Correctness: The logic for introducing has_lora to BatchDescriptor and using it for CUDA graph dispatching is sound. The conditional zeroing of the LoRA buffer within lora_shrink_op.py is a correct and efficient optimization (a sketch of this idea follows this comment).
  • Efficiency: This PR is a direct improvement in efficiency. By capturing separate CUDA graphs and optimizing buffer initialization, it effectively eliminates most of the overhead associated with LoRA when no adapters are active.
  • Maintainability: The changes are well-encapsulated within the relevant modules (forward_context.py, lora_shrink_op.py, punica_gpu.py, cudagraph_dispatcher.py, gpu_model_runner.py). The use of itertools.product for generating compilation cases is a clean and Pythonic approach. The inline comment explaining the torch.empty change is also helpful for future maintainers.
  • Security: No security concerns were identified with these changes.

Overall, this is a valuable enhancement that improves the performance characteristics of LoRA integration. Great work!
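
A hedged sketch of the conditional-zeroing point above; the actual lora_shrink wrapper in vLLM differs, and the signature and buffer names here are simplified assumptions:

    import torch

    def lora_shrink(x: torch.Tensor, lora_a_stacked: torch.Tensor,
                    buffer: torch.Tensor) -> None:
        # The intermediate buffer is allocated with torch.empty by the caller
        # and zeroed here, immediately before it is written to. Because the
        # zeroing lives inside lora_shrink, the non-LoRA graphs contain
        # neither this zeroing nor the shrink kernel itself.
        buffer.zero_()
        # ... the actual shrink (x @ lora_a) kernel launch would go here ...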

@andylolu2 andylolu2 changed the title [Feature] LoRA cuda graph specialization [LoRA] LoRA cuda graph specialization Oct 2, 2025
@varun-sundar-rabindranath (Contributor) commented:

Thanks @andylolu2. Left some comments regarding some refactors.

cc @jeejeelee
cc @ProExpertProg @LucasWilkinson for cudagraph dispatching changes

self.add_cudagraph_key(
    cudagraph_mode.mixed_mode(),
    BatchDescriptor(num_tokens=bs, uniform_decode=False))
for has_lora in [True, False]:
A collaborator commented:

QQ: Will this increase the memory consumption of the CUDA graph?

@andylolu2 (Contributor, Author) replied Oct 7, 2025:

Yes it does, but not by too much.

Before PR: (with --enable-lora)

Free memory ... 1.62 GiB for CUDAGraph memory.

After PR: (with --enable-lora)

Free memory ... 2.38 GiB for CUDAGraph memory.

A contributor replied:
@andylolu2 can you also try this with just the base model (i.e. not enabling LoRA) to see that it doesn't affect the CUDA graph memory? Thanks.

@andylolu2 (Contributor, Author) replied Oct 9, 2025:

Baseline is: (without --enable-lora)

Free memory ... 1.14 GiB for CUDAGraph memory.

@andylolu2 andylolu2 force-pushed the andy/lora-cuda-graphs-specialization branch from 00d1328 to b027fd2 on October 5, 2025 18:12
@ProExpertProg (Collaborator) commented:

Cc @fhl2000 can you take a look?

@fhl2000 (Contributor) left a comment:

LGTM for the cudagraph dispatching stuff. Only some small thoughts on the LoRA path.

@jeejeelee (Collaborator) left a comment:

LGTM, @ProExpertProg please take another look

@jeejeelee jeejeelee added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 16, 2025
@jeejeelee jeejeelee enabled auto-merge (squash) October 16, 2025 14:49
@dcmaddix dcmaddix mentioned this pull request Oct 18, 2025
@andylolu2 (Contributor, Author) commented:

@ProExpertProg @jeejeelee I see CI failures, but they seem unrelated. I scanned through them and all failures appear to be caused by:

 RuntimeError: _moe_C::topk_softmax() is missing value for argument 'renormalize'. Declaration: _moe_C::topk_softmax(Tensor($0! -> ) topk_weights, Tensor($1! -> ) topk_indices, Tensor($2! -> ) token_expert_indices, Tensor gating_output, bool renormalize) -> ()

@jeejeelee (Collaborator) commented:

Let me sync with main and test again

@jeejeelee jeejeelee merged commit b63f214 into vllm-project:main Oct 20, 2025
51 checks passed
Ther-LF pushed a commit to Ther-LF/vllm that referenced this pull request Oct 20, 2025
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Oct 20, 2025
Merged 8 commits from origin/main including:
- PR vllm-project#26586: Eagle rejection sampler fix (previously cherry-picked)
- LoRA CUDA graph specialization (vllm-project#25914)
- Bee-8B VLM model support (vllm-project#27012)
- Utilities reorganization (network_utils, async_utils, etc.)
- Multiple bug fixes and improvements

In-Tree Modifications:
- Removed Eagle rejection sampler cherry-pick (now in upstream)
- Kept Qwen3 tool parser fix (still needed, line 523)
- Only 1 active in-tree modification remaining

Plugin Compatibility:
- All 10 plugin patches load successfully
- No target class changes required
- Clean merge with no conflicts

Documentation Updates:
- Updated IN_TREE_MODIFICATIONS.md (moved Eagle fix to Removed/Obsolete)
- Updated CLAUDE.md merge history
- Verified clean diff with origin/main (3 files, all documented)

Signed-off-by: Pradyun Ramadorai <[email protected]>
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
adabeyta pushed a commit to adabeyta/vllm that referenced this pull request Oct 20, 2025
albertoperdomo2 pushed a commit to albertoperdomo2/vllm that referenced this pull request Oct 23, 2025
Signed-off-by: Andy Lo <[email protected]>
Co-authored-by: Jee Jee Li <[email protected]>
Signed-off-by: Alberto Perdomo <[email protected]>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
Signed-off-by: Andy Lo <[email protected]>
Co-authored-by: Jee Jee Li <[email protected]>
Signed-off-by: 0xrushi <[email protected]>
ilmarkov pushed a commit to neuralmagic/vllm that referenced this pull request Nov 7, 2025
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025

Labels

llama (Related to Llama models), ready (ONLY add when PR is ready to merge/full CI is needed), v1
