verl/workers/sharding_manager/fsdp_sglang.py: 23 changes (13 additions, 10 deletions)
@@ -126,14 +126,6 @@ async def release_memory(self):
     async def wake_up(self):
         get_torch_device().empty_cache()

-        if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
-            if self.multi_stage_wake_up:
-                await self.inference_engine.resume_memory_occupation(tags=["weights"])
-                log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger)
-            else:
-                await self.inference_engine.resume_memory_occupation()
-                log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger)
-
         log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger)
         if self.offload_param:
             load_fsdp_model_to_gpu(self.module)
@@ -147,13 +139,24 @@ async def wake_up(self):
         # convert weight keys to match the model config
         params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))

+        if self.offload_param:
+            offload_fsdp_model_to_cpu(self.module)
+
+        log_gpu_memory_usage("After offload_param in sharding manager memory", logger=logger)
+
+        if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
+            if self.multi_stage_wake_up:
+                await self.inference_engine.resume_memory_occupation(tags=["weights"])
+                log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger)
+            else:
+                await self.inference_engine.resume_memory_occupation()
+                log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger)
+
         # Copy, not share memory
         await self.update_weights(params)
         log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)

         del params
-        if self.offload_param:
-            offload_fsdp_model_to_cpu(self.module)
         get_torch_device().empty_cache()
         log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)

verl/workers/sharding_manager/fsdp_vllm.py: 5 changes (3 additions, 2 deletions)
@@ -205,6 +205,9 @@ def __collect_lora_params() -> OrderedDict:
         else:
             params = self.module.state_dict()
         params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))
+
+        if self.offload_param:
+            offload_fsdp_model_to_cpu(self.module)
         log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)
Comment on lines +209 to 211
Severity: high

The log message "After state_dict()" on line 211 is misleading: it is now emitted after the parameters may already have been offloaded to CPU, which can cause confusion when debugging memory usage. Renaming it to "After offload_param..." makes it accurate and consistent with the new log message added in fsdp_sglang.py.

For even better diagnostics, consider separate log points after state_dict and after offload_param, similar to the implementation in fsdp_sglang.py; a sketch of that variant follows the suggestion below.

Suggested change
-        if self.offload_param:
-            offload_fsdp_model_to_cpu(self.module)
-        log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)
+        if self.offload_param:
+            offload_fsdp_model_to_cpu(self.module)
+        log_gpu_memory_usage("After offload_param in sharding manager memory", logger=logger)
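
A minimal sketch of that two-log-point variant, reusing only helpers already present in this file (illustrative of the reviewer's suggestion, not code from the PR):

    # Sketch: same method body as above; log once right after state_dict(),
    # then again after the optional CPU offload, mirroring fsdp_sglang.py.
    log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)
    if self.offload_param:
        offload_fsdp_model_to_cpu(self.module)
    log_gpu_memory_usage("After offload_param in sharding manager memory", logger=logger)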


         if self.rollout_config.free_cache_engine:
@@ -217,8 +220,6 @@ def __collect_lora_params() -> OrderedDict:
         self.update_params(params, peft_config=peft_config)
         log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)
         del params
-        if self.offload_param:
-            offload_fsdp_model_to_cpu(self.module)
         get_torch_device().empty_cache()

         if (