32 changes: 25 additions & 7 deletions verl/workers/fsdp_workers.py
@@ -48,7 +48,6 @@
from verl.models.transformers.monkey_patch import apply_monkey_patch
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register
from verl.third_party.vllm import VLLM_SLEEP_LEVEL
from verl.utils import hf_processor, hf_tokenizer
from verl.utils.activation_offload import enable_activation_offloading
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
@@ -612,11 +611,6 @@ def _build_rollout(self, trust_remote_code=False):
# used for LoRA
self.base_sync_done: bool = "dummy" not in self.config.rollout.load_format
self.layered_summon = self.config.rollout.get("layered_summon", False)
if VLLM_SLEEP_LEVEL == 2 and not self.layered_summon:
self.force_reload = True
self.base_sync_done = False
else:
self.force_reload = False

# 5. switch to trainer mode
# NOTE: It's critical that the hybrid engine starts in trainer mode so that the checkpoint can be loaded.
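
With the constant-based check removed, the decision to fully re-sync the base model is now made at sync time from the rollout instance itself (see the `rollout_mode` hunks below). A minimal sketch of that condition; `needs_full_base_resync` is a purely illustrative helper name, and the diff simply inlines the check:

```python
# Illustrative helper (not part of the diff); rollout_mode inlines this condition.
def needs_full_base_resync(rollout, peft_config) -> bool:
    # The full base model only needs re-uploading when LoRA is active and vLLM's
    # sleep level 2 has released the base weights during the sleep cycle.
    return peft_config is not None and getattr(rollout, "sleep_level", None) == 2
```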
@@ -653,6 +647,21 @@ async def rollout_mode(self):
params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp)
)

# Special handling for LoRA with sleep_level=2:
# when sleep_level=2, the base model weights are destroyed during each sleep cycle,
# so LoRA weights and base model weights are collected and updated separately through
# their respective interfaces.
# Here, params contains the LoRA weights and base_model_params contains the base model weights.
if peft_config is not None and getattr(self.rollout, "sleep_level", None) == 2:
base_model_params = collect_lora_params(
module=self.actor_module_fsdp,
layered_summon=self.layered_summon,
base_sync_done=False,
)
base_model_params = {replace_lora_wrapper(k, peft_config): v for k, v in base_model_params.items()}
base_model_params = convert_weight_keys(
base_model_params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp)
)

log_gpu_memory_usage("Before offload_fsdp_model_to_cpu", logger=logger)
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
@@ -672,6 +681,15 @@
if self.config.rollout.free_cache_engine:
await self.rollout.resume(tags=["weights"])
log_gpu_memory_usage("After resume weights", logger=logger)

if peft_config is not None and getattr(self.rollout, "sleep_level", None) == 2:
per_tensor_base_params = (
(name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)
for name, param in base_model_params.items()
)
await self.rollout.update_weights(per_tensor_base_params, base_sync_done=False)
del base_model_params, per_tensor_base_params

await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done)
log_gpu_memory_usage("After update_weights", logger=logger)
del params, per_tensor_param
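
Both update paths stream weights through a generator instead of materializing a full state dict on the GPU. A minimal sketch of that pattern; `stream_full_tensors` is an illustrative name (the diff uses inline generator expressions), and the `DTensor` import path shown assumes a recent torch release:

```python
from torch.distributed.tensor import DTensor  # older torch versions: torch.distributed._tensor

def stream_full_tensors(named_params, device):
    # Yield (name, tensor) pairs one at a time so only a single gathered tensor
    # is resident at once; DTensor shards are all-gathered via full_tensor().
    for name, param in named_params.items():
        if isinstance(param, DTensor):
            yield name, param.to(device, non_blocking=True).full_tensor()
        else:
            yield name, param
```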
@@ -680,7 +698,7 @@
await self.rollout.resume(tags=["kv_cache"])
log_gpu_memory_usage("After resume kv_cache", logger=logger)

self.base_sync_done = not self.force_reload
self.base_sync_done = True
# important: need to manually set the random states of each tp to be identical.
self.torch_random_states = get_torch_device().get_rng_state()
get_torch_device().set_rng_state(self.gen_random_states)
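
Taken together, the rollout-mode switch for LoRA with `sleep_level=2` now follows a fixed order. The sketch below only condenses the calls visible in this diff into one place (`_sync_weights_for_rollout` is a hypothetical name; logging, parameter offload, and error handling are omitted):

```python
# Condensed sketch of the sync order shown above (LoRA + vLLM sleep_level == 2).
# Hypothetical method name; the real code lives inline in rollout_mode.
async def _sync_weights_for_rollout(self, per_tensor_base_params, per_tensor_param, peft_config):
    if self.config.rollout.free_cache_engine:
        await self.rollout.resume(tags=["weights"])        # re-allocate weight buffers

    # Phase 1: restore the base model weights that sleep level 2 released.
    await self.rollout.update_weights(per_tensor_base_params, base_sync_done=False)

    # Phase 2: apply the LoRA adapter weights on top of the restored base.
    await self.rollout.update_weights(
        per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done
    )

    if self.config.rollout.free_cache_engine:
        await self.rollout.resume(tags=["kv_cache"])       # bring the KV cache back

    # The base model stays in sync until the next sleep cycle releases it again.
    self.base_sync_done = True
```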