-
Notifications
You must be signed in to change notification settings - Fork 2.4k
[rollout,vllm] fix: Add LoRA Loading to Async vLLM #3639
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[rollout,vllm] fix: Add LoRA Loading to Async vLLM #3639
Conversation
|
Hi, I’m running GRPO + vLLM + Qwen3-8B + tool_agent_loop training base on the old code and hit two consecutive errors. Thanks in advance! Below is my exact training config: ulimit -n 65535 PROJECT_DIR="$(pwd)" TRAIN_DATA="/data/verl_train_search10times.parquet" TOOL_CONFIG="$CONFIG_PATH/tool_config/codebase_search_tool_config.yaml" export VLLM_ATTENTION_BACKEND=XFORMERS nohup python3 -m verl.trainer.main_ppo |
### What does this PR do?
Currently, async vLLM with AgentWorkerLoop throws an error when
`update_weights` with LoRA weights. This expands support for
AgentWorkerLoop with LoRAs.
### Checklist Before Starting
- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
- `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
- Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`
### Test
> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.
### API and Usage Example
> Demonstrate how the API changes if any, and provide usage example(s)
if possible.
```python
# Add code snippet or script demonstrating how to use this
```
### Design & Code Changes
> Demonstrate the high-level design if this PR is complex, and list the
specific changes.
### Checklist Before Submitting
> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.
- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
### What does this PR do? > Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review. The previous #3639 addressed the **crashing issues** in `update_weights` of `vLLMAsyncRollout`. However, experiments (see **Tests** below) reveal an implicit **off-policy issue**: the rollout generation still uses the **base model** instead of the updated **LoRA model**, resulting in degraded performance. We traced this to a bug in `vllm_async_server.vLLMHttpServerBase` causing a mismatch between LoRA updates and rollout generation. Specifically: * In `vLLMAsyncRollout`, `update_weights` correctly updates LoRA weights from the FSDP actor to the rollout `AsyncLLM` engine. However, the updated adapter is assigned a random `lora_name` and `lora_int_id` (generated from `time.ns()`), which are not stored—making them hard to reuse. https://github.com/volcengine/verl/blob/f209c6f656bb8444e1ecd641c1af04231a5a2dec/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py#L595-L604 * During rollout generation, the newly added LoRA adapter is **never used** due to two issues: 1. The `vllm_config` used to create `AsyncLLM` lacks a `LoRAConfig` (e.g., `max_lora_rank`), so `AsyncLLM` is not prepared for LoRA-based generation requests. See https://github.com/volcengine/verl/blob/f209c6f656bb8444e1ecd641c1af04231a5a2dec/verl/workers/rollout/vllm_rollout/vllm_async_server.py#L299-L304 2. When calling `generate` in `vLLMHttpServerBase`, the request to `self.engine` (the `AsyncLLM` instance) **omits any `LoRARequest`**, meaning generation always uses the base model. See https://github.com/volcengine/verl/blob/f209c6f656bb8444e1ecd641c1af04231a5a2dec/verl/workers/rollout/vllm_rollout/vllm_async_server.py#L360 #### Proposed Fixes in this PR * Standardize and persist `VLLM_LORA_INT_ID` and `VLLM_LORA_NAME` across the training process to consistently locate and apply updated LoRA weights. * Inject `LoRAConfig` during `AsyncLLM` initialization and ensure `vLLMHttpServerBase` passes a proper `LoRARequest` (identified via `VLLM_LORA_NAME`) during rollout generation. * Add utility methods to automatically validate and set `max_lora_rank` in vLLM from `config.actor_rollout_ref.model.lora_rank`, addressing issues like #3696 #### Remarks Special thanks to @sanxing-chen for inspiring this fix with his prior patches. Also his PR #3765 -- while also tackling an issue hurting LoRA performance -- seems to be orthogonal to the issues addressed here. ### Checklist Before Starting * [x] Search for similar PRs. Paste at least one query link here: #3639 #3765 * [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) * `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` * If this PR involves multiple modules, separate them with `,`, e.g., `[megatron, fsdp, doc]` * `{type}` ∈ {`feat`, `fix`, `refactor`, `chore`, `test`} * If this PR breaks any API, prepend `[BREAKING]` to the title. * Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate with experiments and include results such as training curves or evaluation metrics. Controlled experiments based on `examples/grpo_trainer/run_qwen2_5-3b_gsm8k_grpo_lora.sh` (see [adapted script](https://gist.github.com/listar2000/43bb0e1d6f0d3c2503922ca2bfee0a6b)) **clearly demonstrate both the issue and the effectiveness of the fix**. <img width="2528" height="1328" alt="kl-loss" src="https://github.com/user-attachments/assets/008cdace-fc6d-459a-8493-8ddb440c57ec" /> <img width="2528" height="1328" alt="val-reward" src="https://github.com/user-attachments/assets/aa2e13c7-25cc-41cd-a916-d98f134060e6" /> See the full [W&B training log](https://wandb.ai/listar2000/verl-latest-lora). Summary: * **sync-lora-32** — baseline (synchronous mode). * **async-lora-32-before-fix** — async LoRA on `main` branch, showing degraded performance. * **async-lora-32-no-remove** — ablation variant with fixes applied **but without removing old LoRA adapters** between updates (showing the importance of removal). * **async-lora-32-after-fix** — full fix applied, achieving expected improvement. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). **Not Applicable** - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: **This PR can hardly be covered by regular CI. I instead run concrete experiments with GSM8K dataset.** - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
### What does this PR do?
Currently, async vLLM with AgentWorkerLoop throws an error when
`update_weights` with LoRA weights. This expands support for
AgentWorkerLoop with LoRAs.
### Checklist Before Starting
- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
- `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
- Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`
### Test
> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.
### API and Usage Example
> Demonstrate how the API changes if any, and provide usage example(s)
if possible.
```python
# Add code snippet or script demonstrating how to use this
```
### Design & Code Changes
> Demonstrate the high-level design if this PR is complex, and list the
specific changes.
### Checklist Before Submitting
> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.
- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
### What does this PR do?
Currently, async vLLM with AgentWorkerLoop throws an error when
`update_weights` with LoRA weights. This expands support for
AgentWorkerLoop with LoRAs.
### Checklist Before Starting
- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
- `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
- Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`
### Test
> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.
### API and Usage Example
> Demonstrate how the API changes if any, and provide usage example(s)
if possible.
```python
# Add code snippet or script demonstrating how to use this
```
### Design & Code Changes
> Demonstrate the high-level design if this PR is complex, and list the
specific changes.
### Checklist Before Submitting
> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.
- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
### What does this PR do?
Currently, async vLLM with AgentWorkerLoop throws an error when
`update_weights` with LoRA weights. This expands support for
AgentWorkerLoop with LoRAs.
### Checklist Before Starting
- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
- `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
- Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`
### Test
> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.
### API and Usage Example
> Demonstrate how the API changes if any, and provide usage example(s)
if possible.
```python
# Add code snippet or script demonstrating how to use this
```
### Design & Code Changes
> Demonstrate the high-level design if this PR is complex, and list the
specific changes.
### Checklist Before Submitting
> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.
- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
…ine#3821) ### What does this PR do? > Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review. The previous volcengine#3639 addressed the **crashing issues** in `update_weights` of `vLLMAsyncRollout`. However, experiments (see **Tests** below) reveal an implicit **off-policy issue**: the rollout generation still uses the **base model** instead of the updated **LoRA model**, resulting in degraded performance. We traced this to a bug in `vllm_async_server.vLLMHttpServerBase` causing a mismatch between LoRA updates and rollout generation. Specifically: * In `vLLMAsyncRollout`, `update_weights` correctly updates LoRA weights from the FSDP actor to the rollout `AsyncLLM` engine. However, the updated adapter is assigned a random `lora_name` and `lora_int_id` (generated from `time.ns()`), which are not stored—making them hard to reuse. https://github.com/volcengine/verl/blob/e94366d46a027d38e48e3a859b745387f131b0ad/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py#L595-L604 * During rollout generation, the newly added LoRA adapter is **never used** due to two issues: 1. The `vllm_config` used to create `AsyncLLM` lacks a `LoRAConfig` (e.g., `max_lora_rank`), so `AsyncLLM` is not prepared for LoRA-based generation requests. See https://github.com/volcengine/verl/blob/e94366d46a027d38e48e3a859b745387f131b0ad/verl/workers/rollout/vllm_rollout/vllm_async_server.py#L299-L304 2. When calling `generate` in `vLLMHttpServerBase`, the request to `self.engine` (the `AsyncLLM` instance) **omits any `LoRARequest`**, meaning generation always uses the base model. See https://github.com/volcengine/verl/blob/e94366d46a027d38e48e3a859b745387f131b0ad/verl/workers/rollout/vllm_rollout/vllm_async_server.py#L360 #### Proposed Fixes in this PR * Standardize and persist `VLLM_LORA_INT_ID` and `VLLM_LORA_NAME` across the training process to consistently locate and apply updated LoRA weights. * Inject `LoRAConfig` during `AsyncLLM` initialization and ensure `vLLMHttpServerBase` passes a proper `LoRARequest` (identified via `VLLM_LORA_NAME`) during rollout generation. * Add utility methods to automatically validate and set `max_lora_rank` in vLLM from `config.actor_rollout_ref.model.lora_rank`, addressing issues like volcengine#3696 #### Remarks Special thanks to @sanxing-chen for inspiring this fix with his prior patches. Also his PR volcengine#3765 -- while also tackling an issue hurting LoRA performance -- seems to be orthogonal to the issues addressed here. ### Checklist Before Starting * [x] Search for similar PRs. Paste at least one query link here: volcengine#3639 volcengine#3765 * [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) * `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` * If this PR involves multiple modules, separate them with `,`, e.g., `[megatron, fsdp, doc]` * `{type}` ∈ {`feat`, `fix`, `refactor`, `chore`, `test`} * If this PR breaks any API, prepend `[BREAKING]` to the title. * Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate with experiments and include results such as training curves or evaluation metrics. Controlled experiments based on `examples/grpo_trainer/run_qwen2_5-3b_gsm8k_grpo_lora.sh` (see [adapted script](https://gist.github.com/listar2000/43bb0e1d6f0d3c2503922ca2bfee0a6b)) **clearly demonstrate both the issue and the effectiveness of the fix**. <img width="2528" height="1328" alt="kl-loss" src="https://github.com/user-attachments/assets/008cdace-fc6d-459a-8493-8ddb440c57ec" /> <img width="2528" height="1328" alt="val-reward" src="https://github.com/user-attachments/assets/aa2e13c7-25cc-41cd-a916-d98f134060e6" /> See the full [W&B training log](https://wandb.ai/listar2000/verl-latest-lora). Summary: * **sync-lora-32** — baseline (synchronous mode). * **async-lora-32-before-fix** — async LoRA on `main` branch, showing degraded performance. * **async-lora-32-no-remove** — ablation variant with fixes applied **but without removing old LoRA adapters** between updates (showing the importance of removal). * **async-lora-32-after-fix** — full fix applied, achieving expected improvement. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). **Not Applicable** - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: **This PR can hardly be covered by regular CI. I instead run concrete experiments with GSM8K dataset.** - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
…ine#3821) ### What does this PR do? > Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review. The previous volcengine#3639 addressed the **crashing issues** in `update_weights` of `vLLMAsyncRollout`. However, experiments (see **Tests** below) reveal an implicit **off-policy issue**: the rollout generation still uses the **base model** instead of the updated **LoRA model**, resulting in degraded performance. We traced this to a bug in `vllm_async_server.vLLMHttpServerBase` causing a mismatch between LoRA updates and rollout generation. Specifically: * In `vLLMAsyncRollout`, `update_weights` correctly updates LoRA weights from the FSDP actor to the rollout `AsyncLLM` engine. However, the updated adapter is assigned a random `lora_name` and `lora_int_id` (generated from `time.ns()`), which are not stored—making them hard to reuse. https://github.com/volcengine/verl/blob/f209c6f656bb8444e1ecd641c1af04231a5a2dec/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py#L595-L604 * During rollout generation, the newly added LoRA adapter is **never used** due to two issues: 1. The `vllm_config` used to create `AsyncLLM` lacks a `LoRAConfig` (e.g., `max_lora_rank`), so `AsyncLLM` is not prepared for LoRA-based generation requests. See https://github.com/volcengine/verl/blob/f209c6f656bb8444e1ecd641c1af04231a5a2dec/verl/workers/rollout/vllm_rollout/vllm_async_server.py#L299-L304 2. When calling `generate` in `vLLMHttpServerBase`, the request to `self.engine` (the `AsyncLLM` instance) **omits any `LoRARequest`**, meaning generation always uses the base model. See https://github.com/volcengine/verl/blob/f209c6f656bb8444e1ecd641c1af04231a5a2dec/verl/workers/rollout/vllm_rollout/vllm_async_server.py#L360 #### Proposed Fixes in this PR * Standardize and persist `VLLM_LORA_INT_ID` and `VLLM_LORA_NAME` across the training process to consistently locate and apply updated LoRA weights. * Inject `LoRAConfig` during `AsyncLLM` initialization and ensure `vLLMHttpServerBase` passes a proper `LoRARequest` (identified via `VLLM_LORA_NAME`) during rollout generation. * Add utility methods to automatically validate and set `max_lora_rank` in vLLM from `config.actor_rollout_ref.model.lora_rank`, addressing issues like volcengine#3696 #### Remarks Special thanks to @sanxing-chen for inspiring this fix with his prior patches. Also his PR volcengine#3765 -- while also tackling an issue hurting LoRA performance -- seems to be orthogonal to the issues addressed here. ### Checklist Before Starting * [x] Search for similar PRs. Paste at least one query link here: volcengine#3639 volcengine#3765 * [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) * `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` * If this PR involves multiple modules, separate them with `,`, e.g., `[megatron, fsdp, doc]` * `{type}` ∈ {`feat`, `fix`, `refactor`, `chore`, `test`} * If this PR breaks any API, prepend `[BREAKING]` to the title. * Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate with experiments and include results such as training curves or evaluation metrics. Controlled experiments based on `examples/grpo_trainer/run_qwen2_5-3b_gsm8k_grpo_lora.sh` (see [adapted script](https://gist.github.com/listar2000/43bb0e1d6f0d3c2503922ca2bfee0a6b)) **clearly demonstrate both the issue and the effectiveness of the fix**. <img width="2528" height="1328" alt="kl-loss" src="https://github.com/user-attachments/assets/008cdace-fc6d-459a-8493-8ddb440c57ec" /> <img width="2528" height="1328" alt="val-reward" src="https://github.com/user-attachments/assets/aa2e13c7-25cc-41cd-a916-d98f134060e6" /> See the full [W&B training log](https://wandb.ai/listar2000/verl-latest-lora). Summary: * **sync-lora-32** — baseline (synchronous mode). * **async-lora-32-before-fix** — async LoRA on `main` branch, showing degraded performance. * **async-lora-32-no-remove** — ablation variant with fixes applied **but without removing old LoRA adapters** between updates (showing the importance of removal). * **async-lora-32-after-fix** — full fix applied, achieving expected improvement. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). **Not Applicable** - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: **This PR can hardly be covered by regular CI. I instead run concrete experiments with GSM8K dataset.** - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
…ine#3821) ### What does this PR do? > Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review. The previous volcengine#3639 addressed the **crashing issues** in `update_weights` of `vLLMAsyncRollout`. However, experiments (see **Tests** below) reveal an implicit **off-policy issue**: the rollout generation still uses the **base model** instead of the updated **LoRA model**, resulting in degraded performance. We traced this to a bug in `vllm_async_server.vLLMHttpServerBase` causing a mismatch between LoRA updates and rollout generation. Specifically: * In `vLLMAsyncRollout`, `update_weights` correctly updates LoRA weights from the FSDP actor to the rollout `AsyncLLM` engine. However, the updated adapter is assigned a random `lora_name` and `lora_int_id` (generated from `time.ns()`), which are not stored—making them hard to reuse. https://github.com/volcengine/verl/blob/f209c6f656bb8444e1ecd641c1af04231a5a2dec/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py#L595-L604 * During rollout generation, the newly added LoRA adapter is **never used** due to two issues: 1. The `vllm_config` used to create `AsyncLLM` lacks a `LoRAConfig` (e.g., `max_lora_rank`), so `AsyncLLM` is not prepared for LoRA-based generation requests. See https://github.com/volcengine/verl/blob/f209c6f656bb8444e1ecd641c1af04231a5a2dec/verl/workers/rollout/vllm_rollout/vllm_async_server.py#L299-L304 2. When calling `generate` in `vLLMHttpServerBase`, the request to `self.engine` (the `AsyncLLM` instance) **omits any `LoRARequest`**, meaning generation always uses the base model. See https://github.com/volcengine/verl/blob/f209c6f656bb8444e1ecd641c1af04231a5a2dec/verl/workers/rollout/vllm_rollout/vllm_async_server.py#L360 #### Proposed Fixes in this PR * Standardize and persist `VLLM_LORA_INT_ID` and `VLLM_LORA_NAME` across the training process to consistently locate and apply updated LoRA weights. * Inject `LoRAConfig` during `AsyncLLM` initialization and ensure `vLLMHttpServerBase` passes a proper `LoRARequest` (identified via `VLLM_LORA_NAME`) during rollout generation. * Add utility methods to automatically validate and set `max_lora_rank` in vLLM from `config.actor_rollout_ref.model.lora_rank`, addressing issues like volcengine#3696 #### Remarks Special thanks to @sanxing-chen for inspiring this fix with his prior patches. Also his PR volcengine#3765 -- while also tackling an issue hurting LoRA performance -- seems to be orthogonal to the issues addressed here. ### Checklist Before Starting * [x] Search for similar PRs. Paste at least one query link here: volcengine#3639 volcengine#3765 * [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) * `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` * If this PR involves multiple modules, separate them with `,`, e.g., `[megatron, fsdp, doc]` * `{type}` ∈ {`feat`, `fix`, `refactor`, `chore`, `test`} * If this PR breaks any API, prepend `[BREAKING]` to the title. * Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate with experiments and include results such as training curves or evaluation metrics. Controlled experiments based on `examples/grpo_trainer/run_qwen2_5-3b_gsm8k_grpo_lora.sh` (see [adapted script](https://gist.github.com/listar2000/43bb0e1d6f0d3c2503922ca2bfee0a6b)) **clearly demonstrate both the issue and the effectiveness of the fix**. <img width="2528" height="1328" alt="kl-loss" src="https://github.com/user-attachments/assets/008cdace-fc6d-459a-8493-8ddb440c57ec" /> <img width="2528" height="1328" alt="val-reward" src="https://github.com/user-attachments/assets/aa2e13c7-25cc-41cd-a916-d98f134060e6" /> See the full [W&B training log](https://wandb.ai/listar2000/verl-latest-lora). Summary: * **sync-lora-32** — baseline (synchronous mode). * **async-lora-32-before-fix** — async LoRA on `main` branch, showing degraded performance. * **async-lora-32-no-remove** — ablation variant with fixes applied **but without removing old LoRA adapters** between updates (showing the importance of removal). * **async-lora-32-after-fix** — full fix applied, achieving expected improvement. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). **Not Applicable** - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: **This PR can hardly be covered by regular CI. I instead run concrete experiments with GSM8K dataset.** - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
What does this PR do?
Currently, async vLLM with AgentWorkerLoop throws an error when
update_weightswith LoRA weights. This expands support for AgentWorkerLoop with LoRAs.Checklist Before Starting
[{modules}] {type}: {description}(This will be checked by the CI){modules}includefsdp,megatron,sglang,vllm,rollout,trainer,ci,training_utils,recipe,hardware,deployment,ray,worker,single_controller,misc,perf,model,algo,env,tool,ckpt,doc,data,like[megatron, fsdp, doc]{type}is infeat,fix,refactor,chore,test[BREAKING]to the beginning of the title.[BREAKING][fsdp, megatron] feat: dynamic batchingTest
API and Usage Example
# Add code snippet or script demonstrating how to use thisDesign & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=alwaysci-requestchannel in theverlSlack workspace. (If not accessible, please try the Feishu group (飞书群).)