Skip to content

Conversation

@piood
Copy link
Contributor

@piood piood commented Sep 19, 2025

What does this PR do?

Fix issue where VLLM would only load base model parameters and not LoRA parameters when VLLM_SLEEP_LEVEL == 2 and not using layered_summon.

This fixes the LoRA trainer error where the first rollout would only use base model parameters, and subsequent rollouts would correctly load LoRA parameters.

Fixes: #3516
Related PR: #3461

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

# Add code snippet or script demonstrating how to use this

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request addresses a bug where LoRA weights were not correctly loaded during the first rollout when using vLLM with VLLM_SLEEP_LEVEL=2 and without layered_summon. The fix correctly identifies this scenario and forces a full model weight synchronization on the initial rollout by setting base_sync_done to False. Subsequent rollouts will then correctly use the more efficient LoRA-only weight updates. The change also simplifies the logic by removing the now-redundant force_reload flag. The implementation is clean, targeted, and effectively resolves the described issue.

@techkang
Copy link
Collaborator

Have you tried to train on your patch with layer_summon=False? When VLLM_SLEEP_LEVEL=2, the base model's weights are destroyed every iteration. After this patch, I think it cannot load weights again. A simple way to fix this is let VLLM_SLEEP_LEVEL=1 whenever lora is enabled.

@piood
Copy link
Contributor Author

piood commented Sep 19, 2025

Have you tried to train on your patch with layer_summon=False? When VLLM_SLEEP_LEVEL=2, the base model's weights are destroyed every iteration. After this patch, I think it cannot load weights again. A simple way to fix this is let VLLM_SLEEP_LEVEL=1

You are right.There are some problem, i will continue fix and test it.

@WncFht
Copy link

WncFht commented Sep 19, 2025

@techkang I think the simple way you mentioned maybe not that great. In that way, all the lora situation should set VLLM_SLEEP_LEVEL to be 1, which will make the lora mode cannot offload the base model and potentially when the rollout.load_format equals to dummy, it will never load base model.

So I think we may implement a loading strategy to first load base model and then load the lora adapter in every loading turn. If you agree with me, I will comlete the PR recently, since it may bring more change in the params collecting and model updating logic.

@techkang
Copy link
Collaborator

@WncFht Thanks for you discussion. But I still think when rollout.load_format equals to dummy, the base model will be loaded at the first iteration with these code:

else:
model = peft_model.base_model.model
orig_dev = "cpu" if "cpu" in str(next(model.parameters()).device) else get_device_name()
model = model.to("cpu")
for name, param in model.state_dict().items():
if any(x in name for x in ["_flat_param", "lora_"]):
continue
name = name.replace("_fsdp_wrapped_module.", "").replace(".base_layer", "")
lora_params[name] = (
param.full_tensor().detach().cpu()
if hasattr(param, "full_tensor")
else param.detach().cpu()
)
model = model.to(orig_dev)

@piood
Copy link
Contributor Author

piood commented Sep 20, 2025

@techkang @WncFht I have polished this PR and tested it. It now correctly loads the base model parameters before the LoRA parameters in each iteration when vllm_sleep_level==2 and layer_summon=False.

Copy link
Collaborator

@techkang techkang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@piood Thanks, LGTM! Left some minor suggestions.

@techkang
Copy link
Collaborator

cc @wuxibin89

@piood
Copy link
Contributor Author

piood commented Sep 21, 2025

@techkang I found an issue when LoRA is enabled with vllm_sleep_level=2: when using dummy load format (base_sync_done=False), both params and base_model_params in rollout_mode contain the same base model parameters, causing duplication.

@techkang
Copy link
Collaborator

techkang commented Sep 21, 2025 via email

@piood
Copy link
Contributor Author

piood commented Sep 21, 2025

@techkang Got it, thanks!

…_level=2

Fix issue where VLLM would only load base model parameters and not LoRA
parameters when VLLM_SLEEP_LEVEL == 2 and not using layered_summon.

This fixes the LoRA trainer error where the first rollout would only use
base model parameters, and subsequent rollouts would correctly load LoRA
parameters.

Fixes: volcengine#3516
Related PR: volcengine#3461
@piood piood force-pushed the fix/vllm-lora-weights-loading branch from 075f9f3 to 5c0948d Compare September 21, 2025 16:10
@techkang techkang self-requested a review September 22, 2025 03:31
@vermouth1992 vermouth1992 merged commit 96e7071 into volcengine:main Sep 22, 2025
35 of 38 checks passed
masoudhashemi pushed a commit to masoudhashemi/verl that referenced this pull request Oct 19, 2025
…_level=2 and without using layerd_summon (volcengine#3541)

### What does this PR do?

Fix issue where VLLM would only load base model parameters and not LoRA
parameters when VLLM_SLEEP_LEVEL == 2 and not using layered_summon.

This fixes the LoRA trainer error where the first rollout would only use
base model parameters, and subsequent rollouts would correctly load LoRA
parameters.

Fixes: volcengine#3516
Related PR: volcengine#3461

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
techkang pushed a commit to techkang/verl that referenced this pull request Oct 31, 2025
…_level=2 and without using layerd_summon (volcengine#3541)

### What does this PR do?

Fix issue where VLLM would only load base model parameters and not LoRA
parameters when VLLM_SLEEP_LEVEL == 2 and not using layered_summon.

This fixes the LoRA trainer error where the first rollout would only use
base model parameters, and subsequent rollouts would correctly load LoRA
parameters.

Fixes: volcengine#3516
Related PR: volcengine#3461

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
mtian8 pushed a commit to mtian8/verl that referenced this pull request Nov 1, 2025
…_level=2 and without using layerd_summon (volcengine#3541)

### What does this PR do?

Fix issue where VLLM would only load base model parameters and not LoRA
parameters when VLLM_SLEEP_LEVEL == 2 and not using layered_summon.

This fixes the LoRA trainer error where the first rollout would only use
base model parameters, and subsequent rollouts would correctly load LoRA
parameters.

Fixes: volcengine#3516
Related PR: volcengine#3461

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
wangboxiong320 pushed a commit to wangboxiong320/verl that referenced this pull request Nov 1, 2025
…_level=2 and without using layerd_summon (volcengine#3541)

### What does this PR do?

Fix issue where VLLM would only load base model parameters and not LoRA
parameters when VLLM_SLEEP_LEVEL == 2 and not using layered_summon.

This fixes the LoRA trainer error where the first rollout would only use
base model parameters, and subsequent rollouts would correctly load LoRA
parameters.

Fixes: volcengine#3516
Related PR: volcengine#3461

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

4 participants