[Feature] Diffusion LoRA Adapter Support (PEFT compatible) for vLLM alignment #758

Merged: david6666666 merged 65 commits into vllm-project:main from AndyZhou952:peft_lora on Jan 27, 2026
Conversation

@AndyZhou952 (Contributor) commented Jan 13, 2026

This is a joint work by @AndyZhou952 and @dongbo910220.

Design doc here.

Purpose:

Following Issue #281 and PR #657, this PR adds diffusion LoRA Adapter Support (PEFT compatible) for vLLM alignment.

This PR reuses the LoRA RPC logic from #657 (thanks to @dongbo910220's implementation), leverages vLLM's self-defined LoRA layers for LoRA support, and adopts the PEFT format so adapters can be incorporated into the reinforcement-learning training pipeline.

How vLLM adds LoRA support:

Three steps: (1) initialization; (2) per-request adapter management; (3) inference (via vLLM's self-defined LoRA layers; the LoRA calculation happens in forward).

[two figures: vLLM LoRA support workflow]

vLLM-Omni PEFT LoRA integration Design:

[figure: vLLM-Omni PEFT LoRA integration design]

Besides add_lora and remove_lora, we also support pin_lora and list_lora as public APIs to be consistent with the vLLM base behavior.
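As a rough illustration of how such a public adapter API with LRU eviction and pin support could be organized (hypothetical class and method names, not the PR's actual DiffusionLoRAManager):

```python
from collections import OrderedDict


class LoRAAdapterCacheSketch:
    """Sketch of an LRU-managed adapter registry with add/remove/pin/list,
    mirroring the public API names above. Illustrative only."""

    def __init__(self, capacity: int = 4):
        self.capacity = capacity
        self._adapters: OrderedDict[int, dict] = OrderedDict()
        self._pinned: set[int] = set()

    def add_lora(self, adapter_id: int, weights: dict) -> bool:
        if adapter_id in self._adapters:
            self._adapters.move_to_end(adapter_id)  # refresh LRU position
            return False
        self._adapters[adapter_id] = weights
        self._evict_if_needed()
        return True

    def remove_lora(self, adapter_id: int) -> bool:
        self._pinned.discard(adapter_id)
        return self._adapters.pop(adapter_id, None) is not None

    def pin_lora(self, adapter_id: int) -> bool:
        if adapter_id not in self._adapters:
            return False
        self._pinned.add(adapter_id)  # pinned adapters are never evicted
        return True

    def list_loras(self) -> list[int]:
        return list(self._adapters)

    def _evict_if_needed(self) -> None:
        # evict least-recently-used adapters first, skipping pinned ones
        for aid in list(self._adapters):
            if len(self._adapters) <= self.capacity:
                break
            if aid not in self._pinned:
                del self._adapters[aid]
```

Pinning matters for the RLHF flow discussed below: a pinned adapter survives eviction until it is explicitly removed and replaced by updated weights.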

Design principles:

  • Preserve vLLM variable and function naming conventions; keep the functionality consistent with vLLM
  • Reuse helper functions from vLLM whenever possible; add minimal new code for diffusion-specific wrappers
  • Extensibility: easy to add support for more layer adapters and for multiple LoRA loading

Design choices:

  • Support the PEFT LoRA format and reuse vLLM LoRA layers (to stay consistent with vLLM behavior). We reuse PEFTHelper from vLLM, which looks for the file adapter_config.json when loading LoRAs.
  • DiffusionLoRAManager does not inherit from LoRAManager, since (1) LoRAManager is LLM-centric, with variables in __init__ that are redundant for diffusion models, and (2) the component-based nature of diffusion models requires separate treatment.
  • LRU cache management is incorporated within DiffusionLoRAManager to keep things compact; there is also no need for a separate WorkerLoRAManager, as vllm-omni's gpu_worker already does that job.
  • In vLLM's BaseLinearLayerWithLoRA, the calculation is done in self.punica_wrapper.add_linear_layer(). punica_wrapper exists to manage multiple LoRAs, but in most diffusion use cases a single LoRA is sufficient. One issue with vLLM's BaseLinearLayerWithLoRA is that it is too closely tied to punica_wrapper, whose current implementation is not well suited to the diffusion use case. As a temporary workaround, we define class DiffusionBaseLinearWithLoRA(BaseLinearLayerWithLoRA) and override apply (where the LoRA calculation happens) to eliminate the dependency on punica_wrapper.
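The idea behind overriding apply can be sketched like this (hypothetical class, numpy instead of torch, same shrink+expand semantics as Punica; not the PR's actual implementation):

```python
import numpy as np


class DiffusionLinearWithLoRASketch:
    """Sketch of a linear layer whose apply() adds a single-LoRA delta via
    direct matmuls, with no punica_wrapper:

        y = x @ W^T + scale * (x @ A^T) @ B^T
    """

    def __init__(self, base_weight):
        self.base_weight = base_weight  # [out_dim, in_dim]
        self.lora_a = None              # [rank, in_dim]
        self.lora_b = None              # [out_dim, rank]
        self.scale = 1.0

    def set_lora(self, lora_a, lora_b, scale=1.0):
        self.lora_a, self.lora_b, self.scale = lora_a, lora_b, scale

    def reset_lora(self):
        # fast path: with no active adapter, apply() skips the LoRA matmuls
        self.lora_a = self.lora_b = None

    def apply(self, x):
        out = x @ self.base_weight.T
        if self.lora_a is not None:
            # "shrink" to rank dims, then "expand" back to the output dims
            out = out + self.scale * (x @ self.lora_a.T) @ self.lora_b.T
        return out
```

Since only one adapter is active per execution, the delta is two small matmuls per layer; no batched multi-LoRA gather/scatter kernel is needed.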

Functions/variables/classes reused from vLLM:

  • LoRARequest request structure
  • get_supported_lora_modules, get_adapter_absolute_path, PEFTHelper.from_local_dir, LoRAModel.from_local_checkpoint, and LoRALayerWeights.optimize (for scaling), used in _load_adapter.
  • LoRAConfig, from_layer, replace_submodule, and BaseLayerWithLoRA.set_mapping in _replace_layers_with_lora, to substitute in vLLM's self-defined LoRA layers.

Current limitations:

  • Currently, only one LoRA adapter can be loaded at a time for one batch

Test Plan:

Start server with SD3.5 and LoRA:

python -m vllm_omni.entrypoints.openai.api_server \
    --model stabilityai/stable-diffusion-3.5-medium \
    --lora-dirs /path/to/lora-test

Send request with LoRA:

curl -X POST http://localhost:8000/v1/images/generations \
    -H "Content-Type: application/json" \
    -d '{"prompt": "A whimsical hand-drawn animation still of a small countryside train station at sunset, warm golden light, lush greenery, soft watercolor textures, highly detailed, sharp focus", "lora": {"name": "rafadan", "local_path": "/path/to/lora.safetensors", "scale": 1.0}}'
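The same request can be built client-side in Python; this is a hypothetical stdlib-only equivalent of the curl command above (the endpoint and payload shape follow this test plan, not a stable documented API):

```python
import json
import urllib.request

# Payload mirrors the curl example from the test plan.
payload = {
    "prompt": ("A whimsical hand-drawn animation still of a small countryside "
               "train station at sunset, warm golden light, lush greenery, "
               "soft watercolor textures, highly detailed, sharp focus"),
    "lora": {
        "name": "rafadan",
        "local_path": "/path/to/lora.safetensors",
        "scale": 1.0,
    },
}


def build_request(url="http://localhost:8000/v1/images/generations"):
    """Build the POST request without sending it."""
    data = json.dumps(payload).encode("utf-8")
    return urllib.request.Request(
        url,
        data=data,
        headers={"Content-Type": "application/json"},
        method="POST",
    )

# To actually send it (requires a running server):
# with urllib.request.urlopen(build_request()) as resp:
#     print(resp.status)
```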

Test Result:

No LoRA (1024x1024, steps=30, seed=42) vs. LoRA (1024x1024, steps=30, seed=42, lora scale = 1.0):

[two figures: side-by-side generated images, without and with the LoRA applied]

Co-authored-by: dongbo910220



Signed-off-by: AndyZhou952 <[email protected]>
Co-authored-by: dongbo910220 <[email protected]>
@knlnguyen1802 (Contributor) commented:

@AndyZhou952 Thank you for the work. For RLHF in verl, it needs to support loading and removing a static LoRA (after a LoRA weight update, it removes the old LoRA and adds the new LoRA weights), so vllm-omni also needs to support this. I think the current design only supports deactivating a dynamic LoRA.

@AndyZhou952 (Contributor, Author) commented Jan 13, 2026:

> @AndyZhou952 Thank you for the work. I think for RLHF in verl, it need to support load and remove static lora (as after lora weight update it will remove old lora and add new lora weight). So I think vllm-omni also need to support it. I think current design only support deactivate dynamic lora.

Thanks for your interest! For this part, I think we can keep it consistent with the base vLLM design by adding add_lora and remove_lora to the public API.

Also, I think it makes sense to unify the static/dynamic support. Essentially, static support just means loading the LoRA weights at the very start; everything else remains the same as dynamic LoRA support.

Will update the design workflow and code base shortly to reflect the changes.

Update 01/13: design & code updated to reflect the changes above. PTAL.

dongbo910220 and others added 4 commits January 13, 2026

Fix diffusion weight index path for subfolders
@SamitHuang SamitHuang added the RL Related to Reinforcement Learning label Jan 14, 2026
    return self.pipeline.load_weights(weights)

def remove_lora(self, adapter_id: int) -> bool:
    return self.lora_manager.remove_adapter(adapter_id) if self.lora_manager else False

Do we need to guarantee self.lora_manager is not None?

AndyZhou952 (Contributor, Author) replied:

Not needed here; removed the condition checks. Thanks!

max_num_batched_tokens=max_num_batched_tokens,
max_batches=1, # single request
device=self.device,
max_loras=1, # single lora

QQ:

  • Do SD models generally only have one concurrent LoRA?
  • Do we need to develop a new punica wrapper for SD?

AndyZhou952 (Contributor, Author) replied:

  • For diffusion, in most scenarios it suffices to use only one LoRA. May leave that to future work.
  • Since the punica wrapper is mostly used for multiple-LoRA management, I think we can leave it out of this PR. One caveat is that punica_wrapper is too closely tied to BaseLinearLayerWithLoRA (it handles the LoRA calculation, which is the reason we init punica_wrapper in the first place), but punica_wrapper has quite a few LLM-specific components.

Please check this implementation (commit 955a2cf) and see if it makes sense. TL;DR: we still inherit from BaseLinearLayerWithLoRA in DiffusionBaseLayerWithLoRA but rewrite the apply function to eliminate the need for punica_wrapper. The current diffusion linear layer design inherits from both DiffusionBaseLayerWithLoRA and the self-defined layers in vLLM.

For a better design, we could probably decouple punica_wrapper and BaseLinearLayerWithLoRA, but this works as a temporary solution for now. Let me know what you think. Thanks!

@dongbo910220 (Contributor) commented Jan 18, 2026:

SD users do sometimes stack multiple LoRAs. For this initial PR, we intentionally support a single active LoRA per diffusion execution (max_loras=1), though we do allow multiple adapters to be cached on CPU and swapped per request.

Also note that our current diffusion runner effectively executes one request per model execution (no cross-request batching), so max_loras=1 matches the existing execution semantics; when multiple requests are passed in one call, we require the same LoRARequest and lora_scale across the batch to avoid silently applying the wrong adapter.

Multi-LoRA composition (weighted mixing, per-sample different adapters, etc.) would require an explicit API for multiple adapters plus weights and a more complex kernel/memory-management path, so we'd prefer to follow up in a separate PR if/when needed.

Re punica_wrapper: we don’t introduce a diffusion-specific punica wrapper here. We still inherit from vLLM’s BaseLinearLayerWithLoRA for weight/buffer management, but in DiffusionBaseLinearLayerWithLoRA we override apply() to compute the single‑LoRA delta via direct matmuls (same shrink+expand semantics as Punica) and handle packed projections per-slice (e.g. fused QKV), avoiding the LLM-specific dependencies in punica_wrapper.
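The per-slice handling for packed projections described above could be sketched like this (assumed shapes and a hypothetical helper name; illustrative, not the PR's exact code):

```python
import numpy as np


def apply_packed_lora(x, base_out, slices, loras, scale=1.0):
    """Add per-slice LoRA deltas onto the output of a fused projection
    such as to_qkv.

    x:        [tokens, in_dim] input to the fused layer
    base_out: [tokens, sum of slice sizes] output of the fused base layer
    slices:   list of (offset, size) per sub-projection, e.g. q, k, v
    loras:    per-slice (A, B) pair or None; A: [rank, in_dim], B: [size, rank]
    """
    out = base_out.copy()
    for (offset, size), lora in zip(slices, loras):
        if lora is None:
            continue  # inactive slice: skip the LoRA matmuls entirely
        A, B = lora
        # same shrink+expand semantics as Punica, one slice at a time
        out[:, offset:offset + size] += scale * (x @ A.T) @ B.T
    return out
```

With max_loras=1 there is a single (A, B) pair per slice, so this reduces to a handful of small matmuls rather than a batched multi-adapter kernel.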


if static_lora_path is not None:
logger.info("Loading static LoRA from %s with scale %.2f", static_lora_path, static_lora_scale)
static_request = LoRARequest(


QQ: what's static_lora?

@AndyZhou952 (Contributor, Author) commented Jan 14, 2026:

This is a follow-up from @SamitHuang's comment under #657 to support both static/dynamic LoRA. Static here means to load the LoRA during the init stage when providing the path in od_config.

Now that this PR unifies the processing flow of static/dynamic LoRA (all via LoRARequest), we can probably update the variable naming here as well to avoid confusion.

AndyZhou952 (Contributor, Author) added:

Update 01/16: updated the variable naming for clarity.

lora_request: LoRARequest,
) -> tuple[LoRAModel, PEFTHelper]:

supported_lora_modules = set(get_supported_lora_modules(self.pipeline))


Does SD have any special layers that need to support LoRA? Can get_supported_lora_modules in vLLM be used directly?

AndyZhou952 (Contributor, Author) replied:

See _expand_expected_modules_for_merged_projections at L40 (called at L163), which handles additional cases like add_kv_proj and to_qkv.

This PR has only been tested on SD. Might still need to investigate to see if we need to further expand to support other models.

Contributor:

Good question. We do reuse vLLM’s get_supported_lora_modules() as the baseline, but we snapshot it before injecting LoRA wrappers — after replacement the original LinearBase modules live under .base_layer, which would make the helper return base_layer and break adapter matching across reloads.

For diffusion/SD we also need to cover merged/packed projections (e.g. to_qkv, add_kv_proj), so we expand the expected module set via _expand_expected_modules_for_merged_projections() and treat packed projections as multi-slice when replacing/activating. This has been validated on SD; we can extend the expansion map as we add more pipelines.

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm_omni.diffusion.lora.manager import DiffusionLoRAManager
@zhtmike commented Jan 15, 2026:

Perhaps add LoRARequest in lora/request.py for external package imports.

AndyZhou952 (Contributor, Author) replied:

Done, thanks for the suggestion. This makes sense, since LoRARequest is a user-facing class. One can now use from vllm_omni.lora.request import LoRARequest.


there are other scripts using the LoRARequest from vllm, such as input_prcessor.py, async_omni.py and serving_chat.py, may unify them as well

AndyZhou952 (Contributor, Author) replied:

> there are other scripts using the LoRARequest from vllm, such as input_prcessor.py, async_omni.py and serving_chat.py, may unify them as well

Done, thanks for the observation.

Also, per discussion, added an in-house LoRAConfig within vllm-omni as well.


In config/__init__.py, better to add from vllm_omni.config.lora import LoRAConfig.

AndyZhou952 (Contributor, Author) replied:

> in config/__init__.py, better add from vllm_omni.config.lora import LoRAConfig

done


looks nice :)

Signed-off-by: AndyZhou952 <[email protected]>

Peft lora wrapper

Signed-off-by: Andy Zhou <[email protected]>
@david6666666 david6666666 added this to the v0.14.0 milestone Jan 23, 2026
@david6666666 david6666666 added the high priority high priority issue, needs to be done asap label Jan 23, 2026
@AndyZhou952 AndyZhou952 marked this pull request as ready for review January 23, 2026 07:15
@chatgpt-codex-connector (bot) left a comment:

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 950e388e96


@AndyZhou952 (Contributor, Author) commented Jan 23, 2026:

PR is ready @SamitHuang, thanks :-)

@ZJY0516 ZJY0516 requested a review from jeejeelee January 25, 2026 09:15
import pytest
import torch

pytest.importorskip("flash_attn", reason="flash_attn is not installed")
Collaborator:

Remove this. I'll enable this test later.

Contributor:

Removed

@@ -0,0 +1,376 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
Collaborator:

Contributor:

Good suggestion. Added

@@ -0,0 +1,152 @@
# SPDX-License-Identifier: Apache-2.0
Collaborator:

Could you add a test for this file?

Contributor:

Added unit tests in tests/diffusion/lora/test_base_linear.py (multi-slice apply, reset fast-path to skip matmuls when inactive, and inactive-slice behavior).

# Known packed projections: accept their separate counterparts.
packed_expansions: dict[str, list[str]] = {
# diffusion: fused QKV
"to_qkv": ["to_q", "to_k", "to_v"],
Collaborator:
The reason will be displayed to describe this comment to others. Learn more.

Why do we need this hard-coding here?

Contributor:

The hard-coded packed→submodule expansion is needed to support diffusion models with fused projections (e.g. to_qkv, w13) while many PEFT/diffusers LoRA checkpoints are saved against the logical sub-projections (e.g. to_q/to_k/to_v, w1/w3).
We pass expected_lora_modules into LoRAModel.from_local_checkpoint to filter loaded weights; without expanding these names, those submodule keys would be dropped at load time and the LoRA would never be applied. The mapping is intentionally small and only takes effect when the packed module exists in the model, so the impact is contained.

Collaborator:

My concern is that this solution won't scale well or remain transparent when we encounter a new packed layer in a future model.

Contributor:

I refactored this to follow vLLM’s packed_modules_mapping pattern: the packed→sublayer mapping now lives with each diffusion transformer implementation (e.g. to_qkv -> [to_q, to_k, to_v], add_kv_proj -> [...], w13 -> [w1, w3]), instead of being hard-coded in the LoRA framework. DiffusionLoRAManager collects packed_modules_mapping from the pipeline modules at init and uses it to:

  1. expand expected_lora_modules so LoRA keys saved against sub-projections are not dropped at load time, and
  2. map per-sublayer LoRA weights onto packed LoRA layers during target-module matching.

This makes new packed layers explicit and transparent: adding support is done next to the model code (similar to how we already maintain stacked_params_mapping in load_weights()), without changing the LoRA core logic.
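The expansion step described in point 1 above can be sketched as a small pure function (hypothetical name and signature; the PR's actual helper lives in DiffusionLoRAManager):

```python
def expand_expected_modules(expected, packed_modules_mapping):
    """If a packed module (e.g. to_qkv) is a LoRA target, also accept its
    logical sub-projections, since PEFT/diffusers checkpoints often save
    LoRA weights against the unfused names (to_q/to_k/to_v).

    expected:               set of module names supported for LoRA
    packed_modules_mapping: packed name -> list of sub-projection names
    """
    expanded = set(expected)
    for packed, subs in packed_modules_mapping.items():
        if packed in expanded:
            # accept the separate counterparts so their LoRA keys are
            # not dropped when filtering the loaded checkpoint
            expanded.update(subs)
    return expanded
```

The mapping only takes effect when the packed module actually exists in the model, so unrelated pipelines are unaffected.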

@SamitHuang (Collaborator) commented:

@AndyZhou952 Final question: can we also add test results on QwenImageLightning? It's a typical timestep-distilled LoRA model and is important in the Q1 roadmap. https://huggingface.co/lightx2v/Qwen-Image-Lightning

@AndyZhou952 (Contributor, Author) replied:

> @AndyZhou952 Final question, can we also add test results on QwenImageLightning? since it's a typical timestep distilled lora model and is important in Q1 roadmap. https://huggingface.co/lightx2v/Qwen-Image-Lightning

Currently, the LoRA support in vLLM-Omni only supports PEFT loading (with adapter_config.json). Loading Qwen-Image-Lightning (a distilled-style LoRA) is not straightforward under the current design: it may require manual inspection of the safetensors on the fly, and the exact implementation requires further investigation.

I think we can make this available in a separate PR if such support is needed (quite a bit of refactoring may be required). For this PR, we keep the behavior consistent with base vLLM.

@SamitHuang (Collaborator) replied:

Okay, let's support it in the next PR.

@SamitHuang SamitHuang added the ready label to trigger buildkite CI label Jan 26, 2026
@david6666666 (Collaborator) commented:

LGTM, thanks for the contribution!

@david6666666 david6666666 merged commit 5037af1 into vllm-project:main Jan 27, 2026
7 checks passed
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Collaborator:

The intention is to delete only the newly added line, as opposed to removing the entire file.

Contributor (Author):

Please see #972 for the reversion

# -- typically a transformer layer
# used for torch compile optimizations
_repeated_blocks = ["QwenImageTransformerBlock"]
packed_modules_mapping = {
Collaborator:

I don't think it's a good idea to put it here, because we also have something like this in the load_weights function for every model.

Contributor:

I'll open a new PR to do it.

nussejzz pushed a commit to nussejzz/vllm-omni that referenced this pull request Jan 27, 2026
…lignment (vllm-project#758)

Signed-off-by: AndyZhou952 <[email protected]>
Signed-off-by: Andy Zhou <[email protected]>
Signed-off-by: dongbo910220 <[email protected]>
Signed-off-by: Andy Zhou <[email protected]>
Signed-off-by: Samit <[email protected]>
Co-authored-by: dongbo910220 <[email protected]>
Co-authored-by: Samit <[email protected]>
Signed-off-by: jzz <[email protected]>

Labels

high priority (high priority issue, needs to be done asap), ready (label to trigger buildkite CI), RL (Related to Reinforcement Learning)


9 participants