Commit 1546ce2
[rollout, vllm] fix: make LoRA with async vLLM work properly (#3821)
### What does this PR do?

> Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

The previous #3639 addressed the **crashing issues** in `update_weights` of `vLLMAsyncRollout`. However, experiments (see **Tests** below) reveal an implicit **off-policy issue**: rollout generation still uses the **base model** instead of the updated **LoRA model**, resulting in degraded performance. We traced this to a bug in `vllm_async_server.vLLMHttpServerBase` that causes a mismatch between LoRA updates and rollout generation. Specifically:

* In `vLLMAsyncRollout`, `update_weights` correctly updates LoRA weights from the FSDP actor to the rollout `AsyncLLM` engine. However, the updated adapter is assigned a random `lora_name` and `lora_int_id` (generated from `time.time_ns()`), which are not stored, making them hard to reuse. https://github.com/volcengine/verl/blob/f209c6f656bb8444e1ecd641c1af04231a5a2dec/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py#L595-L604
* During rollout generation, the newly added LoRA adapter is **never used**, due to two issues:
  1. The `vllm_config` used to create `AsyncLLM` lacks a `LoRAConfig` (e.g., `max_lora_rank`), so `AsyncLLM` is not prepared for LoRA-based generation requests. See https://github.com/volcengine/verl/blob/f209c6f656bb8444e1ecd641c1af04231a5a2dec/verl/workers/rollout/vllm_rollout/vllm_async_server.py#L299-L304
  2. When calling `generate` in `vLLMHttpServerBase`, the request to `self.engine` (the `AsyncLLM` instance) **omits any `LoRARequest`**, so generation always uses the base model. See https://github.com/volcengine/verl/blob/f209c6f656bb8444e1ecd641c1af04231a5a2dec/verl/workers/rollout/vllm_rollout/vllm_async_server.py#L360

#### Proposed Fixes in this PR

* Standardize and persist `VLLM_LORA_INT_ID` and `VLLM_LORA_NAME` across the training process so the updated LoRA weights can be consistently located and applied.
* Inject `LoRAConfig` during `AsyncLLM` initialization and ensure `vLLMHttpServerBase` passes a proper `LoRARequest` (identified via `VLLM_LORA_NAME`) during rollout generation.
* Add utility methods to automatically validate and set `max_lora_rank` in vLLM from `config.actor_rollout_ref.model.lora_rank`, addressing issues like #3696.

#### Remarks

Special thanks to @sanxing-chen for inspiring this fix with his prior patches. His PR #3765, while also tackling an issue that hurts LoRA performance, appears to be orthogonal to the issues addressed here.

### Checklist Before Starting

* [x] Search for similar PRs. Paste at least one query link here: #3639 #3765
* [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  * `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  * If this PR involves multiple modules, separate them with `,`, e.g., `[megatron, fsdp, doc]`
  * `{type}` ∈ {`feat`, `fix`, `refactor`, `chore`, `test`}
  * If this PR breaks any API, prepend `[BREAKING]` to the title.
  * Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate with experiments and include results such as training curves or evaluation metrics.

Controlled experiments based on `examples/grpo_trainer/run_qwen2_5-3b_gsm8k_grpo_lora.sh` (see the [adapted script](https://gist.github.com/listar2000/43bb0e1d6f0d3c2503922ca2bfee0a6b)) **clearly demonstrate both the issue and the effectiveness of the fix**.

<img width="2528" height="1328" alt="kl-loss" src="https://github.com/user-attachments/assets/008cdace-fc6d-459a-8493-8ddb440c57ec" />
<img width="2528" height="1328" alt="val-reward" src="https://github.com/user-attachments/assets/aa2e13c7-25cc-41cd-a916-d98f134060e6" />

See the full [W&B training log](https://wandb.ai/listar2000/verl-latest-lora). Summary:

* **sync-lora-32**: baseline (synchronous mode).
* **async-lora-32-before-fix**: async LoRA on the `main` branch, showing degraded performance.
* **async-lora-32-no-remove**: ablation with the fixes applied **but without removing old LoRA adapters** between updates (showing the importance of removal).
* **async-lora-32-after-fix**: full fix applied, achieving the expected improvement.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

A hedged end-to-end sketch of the new LoRA rollout flow is included after this description.

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). **Not Applicable**
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: **This PR can hardly be covered by regular CI. I instead ran concrete experiments on the GSM8K dataset.**
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
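Referenced from the **API and Usage Example** section above: a hedged, end-to-end sketch of the flow this PR establishes. The function names (`sync_lora_weights`, `generate_with_lora`) and the `worker` / `engine` parameters are illustrative placeholders, not verl API; the identifiers and call patterns are taken from the diffs below.

```python
# Hedged sketch of the invariant this PR enforces: the adapter registered during
# weight sync and the adapter requested during rollout generation must share the
# same fixed identifiers. `worker` stands in for the vLLM worker wrapper and
# `engine` for the AsyncLLM instance; neither helper function exists in verl itself.
from vllm.lora.request import LoRARequest

from verl.utils.vllm import TensorLoRARequest
from verl.workers.rollout.vllm_rollout.utils import (
    VLLM_LORA_INT_ID,
    VLLM_LORA_NAME,
    VLLM_LORA_PATH,
)


def sync_lora_weights(worker, peft_config: dict, lora_tensors: dict) -> None:
    """Training side, mirroring vLLMAsyncRollout.update_weights."""
    # Drop the stale adapter first so the refreshed tensors actually take effect.
    worker.remove_lora(VLLM_LORA_INT_ID)
    worker.add_lora(
        TensorLoRARequest(
            lora_name=VLLM_LORA_NAME,
            lora_int_id=VLLM_LORA_INT_ID,
            lora_path=VLLM_LORA_PATH,
            peft_config=peft_config,
            lora_tensors=lora_tensors,
        )
    )


def generate_with_lora(engine, prompt, sampling_params, request_id: str):
    """Serving side, mirroring vLLMHttpServerBase.generate."""
    # Without an explicit LoRARequest, AsyncLLM silently generates with the base
    # model, which is the off-policy bug described above.
    lora_request = LoRARequest(
        lora_name=VLLM_LORA_NAME, lora_int_id=VLLM_LORA_INT_ID, lora_path=VLLM_LORA_PATH
    )
    return engine.generate(
        prompt=prompt, sampling_params=sampling_params, request_id=request_id, lora_request=lora_request
    )
```

The fixed `VLLM_LORA_INT_ID` / `VLLM_LORA_NAME` pair is what lets both sides refer to the same adapter across weight updates.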
1 parent f209c6f commit 1546ce2

4 files changed: +81, -9 lines changed


verl/utils/config.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -194,4 +194,8 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
             "validation gen temperature should be greater than 0 when enabling do_sample"
         )
 
+    # check LoRA rank in vLLM
+    if config.actor_rollout_ref.model.get("lora_rank", 0) > 0 and config.actor_rollout_ref.rollout.name == "vllm":
+        assert config.actor_rollout_ref.model.lora_rank <= 512, "LoRA rank in vLLM must be less than or equal to 512"
+
     print("[validate_config] All configuration checks passed successfully!")
```
verl/workers/rollout/vllm_rollout/utils.py

Lines changed: 34 additions & 0 deletions
```diff
@@ -0,0 +1,34 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# magic numbers that ensure we are using the same LoRA adapter during the rollout and training process
+VLLM_LORA_INT_ID = 123
+VLLM_LORA_NAME = "123"
+VLLM_LORA_PATH = "simon_lora_path"
+
+
+def get_vllm_max_lora_rank(lora_rank: int):
+    """
+    For vLLM, the smallest `max_lora_rank` is 8, and allowed values are (8, 16, 32, 64, 128, 256, 320, 512)
+    This function automatically adjusts the `max_lora_rank` to the nearest allowed value.
+
+    Reference: https://github.com/vllm-project/vllm/blob/8a297115e2367d463b781adb86b55ac740594cf6/vllm/config/lora.py#L27
+    """
+    assert lora_rank > 0, f"lora_rank must be greater than 0 to invoke this function, got {lora_rank}"
+    vllm_max_lora_ranks = [8, 16, 32, 64, 128, 256, 320, 512]
+    for rank in vllm_max_lora_ranks:
+        if lora_rank <= rank:
+            return rank
+
+    raise ValueError(f"lora_rank must be less than or equal to {vllm_max_lora_ranks[-1]}, but got {lora_rank}")
```
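A short usage sketch of the helper above; the expected values follow directly from its rounding logic.

```python
# Usage sketch: lora_rank values are rounded up to the nearest rank vLLM accepts
# for max_lora_rank; ranks above 512 are rejected (validate_config also blocks them).
from verl.workers.rollout.vllm_rollout.utils import get_vllm_max_lora_rank

assert get_vllm_max_lora_rank(8) == 8       # already an allowed value
assert get_vllm_max_lora_rank(20) == 32     # rounded up to the next allowed value
assert get_vllm_max_lora_rank(512) == 512   # the largest allowed value
# get_vllm_max_lora_rank(513) would raise ValueError
```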

verl/workers/rollout/vllm_rollout/vllm_async_server.py

Lines changed: 26 additions & 1 deletion
```diff
@@ -32,6 +32,7 @@
     init_app_state,
 )
 from vllm.inputs import TokensPrompt
+from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser, get_tcp_uri
@@ -46,6 +47,12 @@
 from verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput
 from verl.workers.rollout.utils import get_free_port, run_unvicorn
 from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout
+from verl.workers.rollout.vllm_rollout.utils import (
+    VLLM_LORA_INT_ID,
+    VLLM_LORA_NAME,
+    VLLM_LORA_PATH,
+    get_vllm_max_lora_rank,
+)
 
 logger = logging.getLogger(__file__)
 logger.setLevel(logging.INFO)
@@ -247,6 +254,16 @@ async def launch_server(self, master_address: str = None, master_port: int = None
             }
         )
 
+        # update lora-related args
+        if self.model_config.lora_rank > 0:
+            args.update(
+                {
+                    "enable_lora": True,
+                    "max_loras": 1,
+                    "max_lora_rank": get_vllm_max_lora_rank(self.model_config.lora_rank),
+                }
+            )
+
         server_args = ["serve", self.model_config.local_path]
         for k, v in args.items():
             if isinstance(v, bool):
@@ -357,7 +374,15 @@ async def generate(
         prompt = TokensPrompt(
             prompt_token_ids=prompt_ids, multi_modal_data={"image": image_data} if image_data else None
         )
-        generator = self.engine.generate(prompt=prompt, sampling_params=sampling_params, request_id=request_id)
+
+        # Add lora request
+        lora_request = None
+        if self.model_config.lora_rank > 0:
+            lora_request = LoRARequest(lora_name=VLLM_LORA_NAME, lora_int_id=VLLM_LORA_INT_ID, lora_path=VLLM_LORA_PATH)
+
+        generator = self.engine.generate(
+            prompt=prompt, sampling_params=sampling_params, request_id=request_id, lora_request=lora_request
+        )
 
         # Get final response
         final_res: Optional[RequestOutput] = None
```
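For context, a hedged sketch of what the LoRA-related args added in `launch_server` amount to on the vLLM server command line. The args-to-argv conversion below is an illustration only (the actual loop is only partially visible in this diff), though `--enable-lora`, `--max-loras`, and `--max-lora-rank` are real `vllm serve` flags.

```python
# Illustration only: roughly how the LoRA args injected above could be turned into
# `vllm serve` CLI flags. The real conversion lives in launch_server and may differ.
args = {"enable_lora": True, "max_loras": 1, "max_lora_rank": 32}

server_args = ["serve", "/path/to/local/model"]  # hypothetical local model path
for k, v in args.items():
    flag = "--" + k.replace("_", "-")
    if isinstance(v, bool):
        if v:
            server_args.append(flag)          # boolean flags are passed without a value
    else:
        server_args.extend([flag, str(v)])

print(server_args)
# ['serve', '/path/to/local/model', '--enable-lora', '--max-loras', '1', '--max-lora-rank', '32']
```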

verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py

Lines changed: 17 additions & 8 deletions
```diff
@@ -69,6 +69,12 @@
 from verl.utils.vllm import TensorLoRARequest, VLLMHijack, is_version_ge
 from verl.workers.config import HFModelConfig, RolloutConfig
 from verl.workers.rollout.base import BaseRollout
+from verl.workers.rollout.vllm_rollout.utils import (
+    VLLM_LORA_INT_ID,
+    VLLM_LORA_NAME,
+    VLLM_LORA_PATH,
+    get_vllm_max_lora_rank,
+)
 
 logger = logging.getLogger(__file__)
 logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
@@ -112,7 +118,7 @@ def __init__(
         model_hf_config = model_config.hf_config
         trust_remote_code = model_config.trust_remote_code
         self.lora_kwargs = (
-            {"enable_lora": True, "max_loras": 1, "max_lora_rank": model_config.lora_rank}
+            {"enable_lora": True, "max_loras": 1, "max_lora_rank": get_vllm_max_lora_rank(model_config.lora_rank)}
             if model_config.lora_rank > 0
             else {}
         )
@@ -487,7 +493,9 @@ def __init__(
         self.inference_engine: WorkerWrapperBase = None
         self.address = self._init_zeromq()
         self.lora_config = (
-            {"max_loras": 1, "max_lora_rank": model_config.lora_rank} if model_config.lora_rank > 0 else {}
+            {"max_loras": 1, "max_lora_rank": get_vllm_max_lora_rank(model_config.lora_rank)}
+            if model_config.lora_rank > 0
+            else {}
         )
 
         # https://github.com/vllm-project/vllm/issues/25171
@@ -593,15 +601,16 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None
         """
         peft_config, base_sync_done = kwargs.get("peft_config", None), kwargs.get("base_sync_done", False)
         if peft_config and base_sync_done:
-            lora_int_id = int(time.time_ns() % 0x7FFFFFFF)
-            lora_reqest = TensorLoRARequest(
-                lora_name=f"{lora_int_id}",
-                lora_int_id=lora_int_id,
-                lora_path="simon_lora_path",
+            # In async mode, make sure the old lora is removed before adding the new one
+            self.inference_engine.worker.remove_lora(VLLM_LORA_INT_ID)
+            lora_request = TensorLoRARequest(
+                lora_name=VLLM_LORA_NAME,
+                lora_int_id=VLLM_LORA_INT_ID,
+                lora_path=VLLM_LORA_PATH,
                 peft_config=asdict(peft_config),
                 lora_tensors=dict(weights),
             )
-            self.inference_engine.worker.add_lora(lora_reqest)
+            self.inference_engine.worker.add_lora(lora_request)
             logger.info(f"vLLM load weights, loaded_params: {len(weights)}")
         else:
             from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
```
