Skip to content

Commit da7fc8e

Browse files
authored
[rollout,trainer] feat: offload param before wake up inference engine (#2977)
1 parent beb6246 commit da7fc8e

File tree

2 files changed

+16
-12
lines changed

2 files changed

+16
-12
lines changed

verl/workers/sharding_manager/fsdp_sglang.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -126,14 +126,6 @@ async def release_memory(self):
126126
async def wake_up(self):
127127
get_torch_device().empty_cache()
128128

129-
if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
130-
if self.multi_stage_wake_up:
131-
await self.inference_engine.resume_memory_occupation(tags=["weights"])
132-
log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger)
133-
else:
134-
await self.inference_engine.resume_memory_occupation()
135-
log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger)
136-
137129
log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger)
138130
if self.offload_param:
139131
load_fsdp_model_to_gpu(self.module)
@@ -147,13 +139,24 @@ async def wake_up(self):
147139
# convert weight keys to match the model config
148140
params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))
149141

142+
if self.offload_param:
143+
offload_fsdp_model_to_cpu(self.module)
144+
145+
log_gpu_memory_usage("After offload_param in sharding manager memory", logger=logger)
146+
147+
if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
148+
if self.multi_stage_wake_up:
149+
await self.inference_engine.resume_memory_occupation(tags=["weights"])
150+
log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger)
151+
else:
152+
await self.inference_engine.resume_memory_occupation()
153+
log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger)
154+
150155
# Copy, not share memory
151156
await self.update_weights(params)
152157
log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)
153158

154159
del params
155-
if self.offload_param:
156-
offload_fsdp_model_to_cpu(self.module)
157160
get_torch_device().empty_cache()
158161
log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)
159162

verl/workers/sharding_manager/fsdp_vllm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,9 @@ def __collect_lora_params() -> OrderedDict:
205205
else:
206206
params = self.module.state_dict()
207207
params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))
208+
209+
if self.offload_param:
210+
offload_fsdp_model_to_cpu(self.module)
208211
log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)
209212

210213
if self.rollout_config.free_cache_engine:
@@ -217,8 +220,6 @@ def __collect_lora_params() -> OrderedDict:
217220
self.update_params(params, peft_config=peft_config)
218221
log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)
219222
del params
220-
if self.offload_param:
221-
offload_fsdp_model_to_cpu(self.module)
222223
get_torch_device().empty_cache()
223224

224225
if (

0 commit comments

Comments (0)