From c3b82b2d434fcf8db29075175f6db80290d69010 Mon Sep 17 00:00:00 2001
From: Chen Haiquan
Date: Fri, 8 Aug 2025 14:48:16 +0800
Subject: [PATCH] offload param before wake up inference engine

---
 verl/workers/sharding_manager/fsdp_sglang.py | 23 +++++++++++---------
 verl/workers/sharding_manager/fsdp_vllm.py   |  5 +++--
 2 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/verl/workers/sharding_manager/fsdp_sglang.py b/verl/workers/sharding_manager/fsdp_sglang.py
index b4f97968ef4..61032c68b76 100644
--- a/verl/workers/sharding_manager/fsdp_sglang.py
+++ b/verl/workers/sharding_manager/fsdp_sglang.py
@@ -126,14 +126,6 @@ async def release_memory(self):
 
     async def wake_up(self):
         get_torch_device().empty_cache()
-        if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
-            if self.multi_stage_wake_up:
-                await self.inference_engine.resume_memory_occupation(tags=["weights"])
-                log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger)
-            else:
-                await self.inference_engine.resume_memory_occupation()
-                log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger)
-
         log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger)
         if self.offload_param:
             load_fsdp_model_to_gpu(self.module)
@@ -147,13 +139,24 @@ async def wake_up(self):
         # convert weight keys to match the model config
         params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))
 
+        if self.offload_param:
+            offload_fsdp_model_to_cpu(self.module)
+
+        log_gpu_memory_usage("After offload_param in sharding manager memory", logger=logger)
+
+        if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
+            if self.multi_stage_wake_up:
+                await self.inference_engine.resume_memory_occupation(tags=["weights"])
+                log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger)
+            else:
+                await self.inference_engine.resume_memory_occupation()
+                log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger)
+
         # Copy, not share memory
         await self.update_weights(params)
         log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)
 
         del params
-        if self.offload_param:
-            offload_fsdp_model_to_cpu(self.module)
         get_torch_device().empty_cache()
         log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)
 
diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py
index c9b163a0692..73866ebb21c 100644
--- a/verl/workers/sharding_manager/fsdp_vllm.py
+++ b/verl/workers/sharding_manager/fsdp_vllm.py
@@ -205,6 +205,9 @@ def __collect_lora_params() -> OrderedDict:
         else:
             params = self.module.state_dict()
             params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))
+
+        if self.offload_param:
+            offload_fsdp_model_to_cpu(self.module)
         log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)
 
         if self.rollout_config.free_cache_engine:
@@ -217,8 +220,6 @@ def __collect_lora_params() -> OrderedDict:
         self.update_params(params, peft_config=peft_config)
         log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)
         del params
-        if self.offload_param:
-            offload_fsdp_model_to_cpu(self.module)
         get_torch_device().empty_cache()
 
         if (
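
Note: below is a minimal, self-contained sketch (not part of the patch) of the reordered wake_up flow. FakeTrainModule, FakeInferenceEngine, load_model_to_gpu, and offload_model_to_cpu are hypothetical stand-ins for verl's FSDP module, the rollout engine, and the FSDP offload helpers; only the call ordering, offloading training params to CPU before resuming the inference engine's memory, is taken from the diff above.

# Sketch of the wake_up ordering introduced by this patch, with fake stand-ins.
import asyncio


class FakeTrainModule:
    """Hypothetical stand-in for the FSDP-wrapped training module."""

    def state_dict(self):
        return {"w": [1.0, 2.0, 3.0]}  # pretend these tensors were gathered on GPU


class FakeInferenceEngine:
    """Hypothetical stand-in for the rollout engine (e.g. SGLang)."""

    async def resume_memory_occupation(self):
        print("inference engine: re-allocating weights/kv-cache on GPU")

    async def update_weights(self, params):
        print(f"inference engine: loaded {len(params)} tensors")


def load_model_to_gpu(module):  # stand-in for load_fsdp_model_to_gpu
    print("train module: params -> GPU")


def offload_model_to_cpu(module):  # stand-in for offload_fsdp_model_to_cpu
    print("train module: params -> CPU")


async def wake_up(module, engine, offload_param=True):
    # Gather the training weights while the FSDP params are on GPU.
    if offload_param:
        load_model_to_gpu(module)
    params = module.state_dict()

    # Key change: offload training params *before* waking the inference engine,
    # so the engine's buffers are not allocated while the FSDP params still sit
    # on GPU (lower peak memory during wake_up).
    if offload_param:
        offload_model_to_cpu(module)

    await engine.resume_memory_occupation()
    await engine.update_weights(params)  # copy, not share memory
    del params


asyncio.run(wake_up(FakeTrainModule(), FakeInferenceEngine()))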