Commit 96e7071

[trainer,rollout] fix: ensure LoRA weights are loaded when vllm_sleep_level=2 without using layered_summon (#3541)
### What does this PR do?

Fix an issue where vLLM would load only the base model parameters and not the LoRA parameters when `VLLM_SLEEP_LEVEL == 2` and `layered_summon` is not used. This resolves the LoRA trainer error in which the first rollout used only base model parameters, while subsequent rollouts loaded the LoRA parameters correctly.

Fixes: #3516
Related PR: #3461

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes, if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) of [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
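For context on the failure mode described under "What does this PR do?", here is a minimal sketch of the conditions that trigger it. The config fields mirror those referenced in this commit's diff (`load_format`, `layered_summon`, `free_cache_engine`); the concrete values are illustrative assumptions, not taken from the PR.

```python
# Illustrative only: config values are assumptions for the bug scenario.
rollout_cfg = {
    "name": "vllm",
    "load_format": "safetensors",  # not a "dummy" format
    "layered_summon": False,       # layered summon disabled
    "free_cache_engine": True,     # engine sleeps/wakes between rollouts
}

# Mirrors `self.base_sync_done = "dummy" not in self.config.rollout.load_format`
# in _build_rollout: base weights are assumed to be already synced.
base_sync_done = "dummy" not in rollout_cfg["load_format"]
print(base_sync_done)  # True

# With vLLM sleep level 2, however, the sleeping engine discards the base weights,
# so both base and LoRA weights must be re-pushed on wake-up. Before this fix,
# the first rollout re-pushed only the base weights and skipped the LoRA adapters.
```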
1 parent 7e4eec7 commit 96e7071

File tree: 1 file changed (+25, -7 lines)

verl/workers/fsdp_workers.py

Lines changed: 25 additions & 7 deletions
```diff
@@ -48,7 +48,6 @@
 from verl.models.transformers.monkey_patch import apply_monkey_patch
 from verl.single_controller.base import Worker
 from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register
-from verl.third_party.vllm import VLLM_SLEEP_LEVEL
 from verl.utils import hf_processor, hf_tokenizer
 from verl.utils.activation_offload import enable_activation_offloading
 from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
```
```diff
@@ -612,11 +611,6 @@ def _build_rollout(self, trust_remote_code=False):
         # used for LoRA
         self.base_sync_done: bool = "dummy" not in self.config.rollout.load_format
         self.layered_summon = self.config.rollout.get("layered_summon", False)
-        if VLLM_SLEEP_LEVEL == 2 and not self.layered_summon:
-            self.force_reload = True
-            self.base_sync_done = False
-        else:
-            self.force_reload = False

         # 5. switch to trainer mode
         # NOTE: It's critical that hybrid engine in trainer mode initially to load checkpoint.
```
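The removed block latched a `force_reload` flag from the import-time `VLLM_SLEEP_LEVEL` constant when the rollout was built; after this change the decision is made inside `rollout_mode` from the live rollout handle. A hedged before/after sketch of that decision (the helper function names are illustrative, not verl APIs):

```python
# Illustrative contrast between the old build-time decision and the new runtime check.
def needs_base_push_old(vllm_sleep_level: int, layered_summon: bool) -> bool:
    # Old: fixed once in _build_rollout from the imported VLLM_SLEEP_LEVEL constant.
    return vllm_sleep_level == 2 and not layered_summon


def needs_base_push_new(rollout, peft_config) -> bool:
    # New: evaluated on every rollout_mode() entry, only for LoRA (PEFT) runs,
    # using the rollout engine's own sleep_level attribute.
    return peft_config is not None and getattr(rollout, "sleep_level", None) == 2
```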
```diff
@@ -653,6 +647,21 @@ async def rollout_mode(self):
             params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp)
         )

+        # Special handling for LoRA with sleep_level=2:
+        # When sleep_level=2, base model weights are destroyed during each sleep cycle.
+        # separately collect and update LoRA weights and base model weights through their respective interfaces.
+        # Here: params contains LoRA weights, base_model_params contains base model weights.
+        if peft_config is not None and getattr(self.rollout, "sleep_level", None) == 2:
+            base_model_params = collect_lora_params(
+                module=self.actor_module_fsdp,
+                layered_summon=self.layered_summon,
+                base_sync_done=False,
+            )
+            base_model_params = {replace_lora_wrapper(k, peft_config): v for k, v in base_model_params.items()}
+            base_model_params = convert_weight_keys(
+                base_model_params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp)
+            )
+
         log_gpu_memory_usage("Before offload_fsdp_model_to_cpu", logger=logger)
         if self._is_offload_param:
             offload_fsdp_model_to_cpu(self.actor_module_fsdp)
```
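Per the in-code comment above, the `base_sync_done=False` call collects the base model weights, and the dict comprehension then rewrites PEFT-wrapped parameter names into the plain model's key space before they are handed to the rollout engine. A hedged illustration of the kind of key normalization involved; the prefixes below follow PEFT's usual wrapper naming and are assumptions, not verl's exact mapping, which lives in `replace_lora_wrapper` and `convert_weight_keys`:

```python
# Hedged illustration of PEFT-style key normalization (prefixes are assumptions).
import torch


def strip_peft_prefix(name: str) -> str:
    # Drop the PEFT wrapper segments so the key matches the plain model's state dict.
    return name.replace("base_model.model.", "").replace(".base_layer", "")


wrapped = {
    "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight": torch.zeros(4, 4),
}
plain = {strip_peft_prefix(k): v for k, v in wrapped.items()}
print(list(plain))  # ['model.layers.0.self_attn.q_proj.weight']
```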
```diff
@@ -672,6 +681,15 @@ async def rollout_mode(self):
         if self.config.rollout.free_cache_engine:
             await self.rollout.resume(tags=["weights"])
         log_gpu_memory_usage("After resume weights", logger=logger)
+
+        if peft_config is not None and getattr(self.rollout, "sleep_level", None) == 2:
+            per_tensor_base_params = (
+                (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)
+                for name, param in base_model_params.items()
+            )
+            await self.rollout.update_weights(per_tensor_base_params, base_sync_done=False)
+            del base_model_params, per_tensor_base_params
+
         await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done)
         log_gpu_memory_usage("After update_weights", logger=logger)
         del params, per_tensor_param
```
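The generator expression above streams the base weights tensor by tensor, gathering each FSDP `DTensor` into a full tensor only at the moment it is consumed, so a complete state dict never has to be materialized at once. A standalone sketch of the same pattern; the `DTensor` import path varies across PyTorch versions, and `device` plus the consuming `update_weights` call are assumed from the surrounding worker code:

```python
# Sketch of the lazy per-tensor streaming pattern used for the base-weight push.
from torch.distributed.tensor import DTensor  # torch.distributed._tensor on older PyTorch


def per_tensor_stream(named_params, device):
    """Yield (name, tensor) pairs, gathering sharded DTensors one at a time."""
    for name, param in named_params.items():
        if isinstance(param, DTensor):
            # full_tensor() all-gathers the shards into a regular torch.Tensor.
            yield name, param.to(device, non_blocking=True).full_tensor()
        else:
            yield name, param

# The consumer (e.g. the rollout engine's weight update) iterates this generator,
# so only one gathered tensor is resident at a time before being handed off.
```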
```diff
@@ -680,7 +698,7 @@ async def rollout_mode(self):
             await self.rollout.resume(tags=["kv_cache"])
         log_gpu_memory_usage("After resume kv_cache", logger=logger)

-        self.base_sync_done = not self.force_reload
+        self.base_sync_done = True
         # important: need to manually set the random states of each tp to be identical.
         self.torch_random_states = get_torch_device().get_rng_state()
         get_torch_device().set_rng_state(self.gen_random_states)
```
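Taken together, the LoRA path in `rollout_mode` for sleep level 2 now proceeds roughly as follows. This is a condensed sketch using the attribute and method names from the diff; memory logging, parameter offloading, the DTensor gathering, and the non-LoRA branch are omitted, and `worker` stands in for the FSDP worker instance:

```python
# Condensed sketch of the post-fix ordering; not the verbatim verl implementation.
async def rollout_mode_sketch(worker, peft_config, per_tensor_param, base_model_params):
    if worker.config.rollout.free_cache_engine:
        await worker.rollout.resume(tags=["weights"])

    if peft_config is not None and getattr(worker.rollout, "sleep_level", None) == 2:
        # Base weights were destroyed during the sleep cycle: push them first,
        # without the LoRA config, as a fresh base-model sync.
        await worker.rollout.update_weights(iter(base_model_params.items()), base_sync_done=False)

    # Then push the LoRA adapter weights as before.
    await worker.rollout.update_weights(
        per_tensor_param, peft_config=peft_config, base_sync_done=worker.base_sync_done
    )

    if worker.config.rollout.free_cache_engine:
        await worker.rollout.resume(tags=["kv_cache"])

    # From now on the base weights are known to be resident in the rollout engine.
    worker.base_sync_done = True
```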
