
Conversation

@chenhaiq (Collaborator) commented on Aug 8, 2025

What does this PR do?

Move the parameter offloading step to before the inference engine is woken up, reducing peak GPU memory usage.

Changed the vLLM and SGLang sharding managers for FSDP. Megatron is left unchanged because a similar change there may result in an illegal memory access.

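At a high level, the enter path of the FSDP sharding manager now gathers the state dict, offloads the FSDP parameters to CPU, and only then wakes the inference engine. A minimal sketch of the reordered flow, assuming the helper names that appear in the review diff below (offload_fsdp_model_to_cpu, log_gpu_memory_usage, self.module, self.offload_param); the wake-up and weight-sync calls are illustrative placeholders, not the exact verl API:

```python
# Sketch of the reordered enter path (illustrative, not the exact verl code).
def __enter__(self):
    log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger)
    params = self.module.state_dict()  # gather FSDP weights while the engine is still asleep
    log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)

    # NEW: offload FSDP parameters to CPU *before* waking the inference engine,
    # so training parameters and inference weights do not peak on the GPU together.
    if self.offload_param:
        offload_fsdp_model_to_cpu(self.module)
    log_gpu_memory_usage("After offload_param in sharding manager memory", logger=logger)

    self.wake_up_inference_engine()  # placeholder for resuming vLLM/SGLang memory
    self.sync_model_weights(params)  # placeholder for copying weights into the engine
```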

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

The logs below show that this change reduces peak GPU memory.

sglang

Now:
(WorkerDict pid=12436) [2025-08-08 07:16:19] Before state_dict() in sharding manager memory, memory allocated (GB): 0.02, memory reserved (GB): 0.87, device memory used/total (GB): 5.50/79.15
(WorkerDict pid=12436) [2025-08-08 07:16:19] After state_dict() in sharding manager memory, memory allocated (GB): 1.57, memory reserved (GB): 3.44, device memory used/total (GB): 8.07/79.15
(WorkerDict pid=12436) [2025-08-08 07:16:19] After offload_param in sharding manager memory, memory allocated (GB): 0.85, memory reserved (GB): 2.52, device memory used/total (GB): 7.14/79.15 <---
(WorkerDict pid=12436) [2025-08-08 07:16:19] Before resume SGLang weights + kv_cache in sharding manager, memory allocated (GB): 0.85, memory reserved (GB): 2.52, device memory used/total (GB): 53.72/79.15. <--

Before:
(WorkerDict pid=31787) [2025-08-08 07:31:42] Before state_dict() in sharding manager memory, memory allocated (GB): 0.02, memory reserved (GB): 0.87, device memory used/total (GB): 52.06/79.15
(WorkerDict pid=31787) [2025-08-08 07:31:42] After state_dict() in sharding manager memory, memory allocated (GB): 1.57, memory reserved (GB): 3.44, device memory used/total (GB): 54.63/79.15. <---
(WorkerDict pid=31787) [2025-08-08 07:31:43] After sync model weights in sharding manager, memory allocated (GB): 2.44, memory reserved (GB): 3.44, device memory used/total (GB): 54.61/79.15
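
Reading the two traces together: offload_param frees 8.07 − 7.14 ≈ 0.9 GB of device memory before SGLang resumes its weights and KV cache, whereas in the old flow those parameters were still resident (54.63 GB used) when the engine re-acquired memory.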

vllm

Now:

(WorkerDict pid=87197) DEBUG:2025-08-07 11:37:15,191:Before state_dict() in sharding manager memory, memory allocated (GB): 45.21, memory reserved (GB): 45.33, device memory used/total (GB): 2.74/79.15
(WorkerDict pid=87197) DEBUG:2025-08-07 11:37:15,472:After state_dict() in sharding manager memory, memory allocated (GB): 46.06, memory reserved (GB): 47.72, device memory used/total (GB): 5.14/79.15
(WorkerDict pid=87197) DEBUG:2025-08-07 11:37:15,637:After sync model weights in sharding manager, memory allocated (GB): 46.06, memory reserved (GB): 47.72, device memory used/total (GB): 5.92/79.15 <--
(WorkerDict pid=87197) DEBUG:2025-08-07 11:37:15,791:After del state_dict and empty_cache in sharding manager, memory allocated (GB): 45.21, memory reserved (GB): 45.33, device memory used/total (GB): 47.94/79.15

Before:
(WorkerDict pid=104544) DEBUG:2025-08-07 11:41:46,431:Before state_dict() in sharding manager memory, memory allocated (GB): 45.21, memory reserved (GB): 45.33, device memory used/total (GB): 2.74/79.15
(WorkerDict pid=104544) DEBUG:2025-08-07 11:41:46,628:After state_dict() in sharding manager memory, memory allocated (GB): 46.78, memory reserved (GB): 48.76, device memory used/total (GB): 6.17/79.15 <--
(WorkerDict pid=104544) DEBUG:2025-08-07 11:41:46,790:After sync model weights in sharding manager, memory allocated (GB): 46.78, memory reserved (GB): 48.76, device memory used/total (GB): 6.96/79.15
(WorkerDict pid=104544) DEBUG:2025-08-07 11:41:47,073:After del state_dict and empty_cache in sharding manager, memory allocated (GB): 45.21, memory reserved (GB): 45.33, device memory used/total (GB): 47.94/79.15
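
For vLLM, the "After state_dict()" log already includes the offload in the new flow (hence the review comment below), and device memory at that point drops from 6.17 GB to 5.14 GB, again roughly 1 GB.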

API and Usage Example

No public API changes: the reordering is internal to the FSDP sharding managers for vLLM and SGLang.
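
For context, a sketch of how such a sharding manager typically wraps rollout; this is a hedged illustration, not the exact verl call sites (rollout_sharding_manager, rollout, and generate_sequences are assumed names):

```python
# Hypothetical usage: the sharding manager is a context manager around rollout.
with rollout_sharding_manager:  # enter: state_dict -> offload -> wake engine -> sync weights
    outputs = rollout.generate_sequences(prompts)
# exit: inference engine memory is released and FSDP parameters are restored for training
```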

Design & Code Changes

The enter path of the FSDP sharding managers in fsdp_sglang.py and fsdp_vllm.py now offloads FSDP parameters to CPU immediately after state_dict() and before the inference engine is woken up and the weights are synced. The Megatron path is unchanged.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request refactors the parameter handling logic in fsdp_sglang.py and fsdp_vllm.py to reduce peak GPU memory usage. The change involves offloading the FSDP model's parameters to the CPU before waking up the inference engine. This is a sensible optimization that should prevent having both the training model parameters and inference model parameters on the GPU simultaneously. The implementation looks correct and consistent across both modified files. I have one suggestion for fsdp_vllm.py to improve the clarity of a log message, which will aid in future debugging of memory-related issues.

Comment on lines +209 to 211
```python
if self.offload_param:
    offload_fsdp_model_to_cpu(self.module)
log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)
```

Severity: high

The log message "After state_dict()" on line 211 is misleading: it is now emitted after the parameters may already have been offloaded to the CPU, which can cause confusion when debugging memory usage. Renaming it to "After offload_param..." makes it accurate and consistent with the new log message added in fsdp_sglang.py.

For even better diagnostics, you could consider having separate log points for after state_dict and after offload_param, similar to the implementation in fsdp_sglang.py.

Suggested change
```diff
 if self.offload_param:
     offload_fsdp_model_to_cpu(self.module)
-log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)
+log_gpu_memory_usage("After offload_param in sharding manager memory", logger=logger)
```

@vermouth1992 merged commit da7fc8e into volcengine:main on Aug 8, 2025
34 of 38 checks passed
@hebiao064 (Collaborator) commented

Sorry, just noticed it; the PR LGTM!

Do you have an idea of why it failed for Megatron?

@hebiao064 (Collaborator) commented

And what's your setup? It shows that it only saved <1 GB; I wonder how many GPUs and what model you are using.

@chenhaiq (Collaborator, Author) commented on Aug 11, 2025

> And what's your setup? It shows that it only saved <1 GB; I wonder how many GPUs and what model you are using.

Qwen2.5-7B-Instruct on 8 A800 GPUs.

@chenhaiq (Collaborator, Author) commented

> Sorry, just noticed it; the PR LGTM!
>
> Do you have an idea of why it failed for Megatron?

The failing test is an FSDP-only test case. It has no relation to rollout (including vLLM and SGLang), so my change does not affect it.
