[rollout,trainer] feat: offload param before wake up inference engine #2977
Conversation
Code Review
This pull request refactors the parameter handling logic in fsdp_sglang.py and fsdp_vllm.py to reduce peak GPU memory usage. The change involves offloading the FSDP model's parameters to the CPU before waking up the inference engine. This is a sensible optimization that should prevent having both the training model parameters and inference model parameters on the GPU simultaneously. The implementation looks correct and consistent across both modified files. I have one suggestion for fsdp_vllm.py to improve the clarity of a log message, which will aid in future debugging of memory-related issues.
```python
if self.offload_param:
    offload_fsdp_model_to_cpu(self.module)
log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)
```
The log message "After state_dict()" on line 211 is misleading, as it's now called after the parameters may have been offloaded to CPU. This can cause confusion when debugging memory usage. Renaming it to "After offload_param..." makes it accurate and consistent with the new log message added in fsdp_sglang.py.
For even better diagnostics, you could consider having separate log points for after state_dict and after offload_param, similar to the implementation in fsdp_sglang.py; see the sketch after the suggestion below.
Suggested change:
```diff
 if self.offload_param:
     offload_fsdp_model_to_cpu(self.module)
-log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)
+log_gpu_memory_usage("After offload_param in sharding manager memory", logger=logger)
```
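For reference, the split-log variant would look roughly like this (a sketch only; the `state_dict()` call is inferred from the log message and is not part of the quoted hunk):

```python
state_dict = self.module.state_dict()
log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)
if self.offload_param:
    offload_fsdp_model_to_cpu(self.module)
    # Separate log point so the memory released by the offload is visible on its own.
    log_gpu_memory_usage("After offload_param in sharding manager memory", logger=logger)
```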
Sorry, just noticed it, the PR LGTM! Do you have any idea why it failed for Megatron?
And what's your setup? It shows that it only saved <1 GB; I wonder how many GPUs and what model you are using.
Qwen2.5-7B-Instruct with 8 A800 GPUs.
The failing test is an FSDP-only test case. It has no relation to rollout (neither vLLM nor SGLang), so my change has no effect on it.
What does this PR do?
Move the parameter offloading step before waking up the inference engine to reduce peak GPU memory usage.
Changed the vLLM and SGLang paths with FSDP.
Megatron is left unchanged because a similar change there may cause illegal memory access.
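Conceptually, the reordered enter path of the sharding manager looks like the sketch below. This is a minimal illustration, not verl's exact code: `wake_up`, `update_params`, and the surrounding method structure are assumptions, while `offload_fsdp_model_to_cpu` and `log_gpu_memory_usage` are the identifiers from the diff.

```python
import torch

def __enter__(self):
    log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger)
    state_dict = self.module.state_dict()  # gather full FSDP params on the GPU
    log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)

    if self.offload_param:
        # The key change: move the training copy to CPU *before* the inference
        # engine re-acquires GPU memory, so the two copies never coexist.
        offload_fsdp_model_to_cpu(self.module)
        log_gpu_memory_usage("After offload_param in sharding manager memory", logger=logger)

    self.inference_engine.wake_up()  # assumed: resume inference weights + kv_cache
    self.update_params(state_dict)   # assumed: copy gathered weights into the engine
    del state_dict
    torch.cuda.empty_cache()
```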
Test
With this change, the peak GPU memory is reduced, as the logs below show.
sglang
Now:
(WorkerDict pid=12436) [2025-08-08 07:16:19] Before state_dict() in sharding manager memory, memory allocated (GB): 0.02, memory reserved (GB): 0.87, device memory used/total (GB): 5.50/79.15
(WorkerDict pid=12436) [2025-08-08 07:16:19] After state_dict() in sharding manager memory, memory allocated (GB): 1.57, memory reserved (GB): 3.44, device memory used/total (GB): 8.07/79.15
(WorkerDict pid=12436) [2025-08-08 07:16:19] After offload_param in sharding manager memory, memory allocated (GB): 0.85, memory reserved (GB): 2.52, device memory used/total (GB): 7.14/79.15 <---
(WorkerDict pid=12436) [2025-08-08 07:16:19] Before resume SGLang weights + kv_cache in sharding manager, memory allocated (GB): 0.85, memory reserved (GB): 2.52, device memory used/total (GB): 53.72/79.15. <--
Before:
(WorkerDict pid=31787) [2025-08-08 07:31:42] Before state_dict() in sharding manager memory, memory allocated (GB): 0.02, memory reserved (GB): 0.87, device memory used/total (GB): 52.06/79.15
(WorkerDict pid=31787) [2025-08-08 07:31:42] After state_dict() in sharding manager memory, memory allocated (GB): 1.57, memory reserved (GB): 3.44, device memory used/total (GB): 54.63/79.15. <---
(WorkerDict pid=31787) [2025-08-08 07:31:43] After sync model weights in sharding manager, memory allocated (GB): 2.44, memory reserved (GB): 3.44, device memory used/total (GB): 54.61/79.15
vllm
Now:
(WorkerDict pid=87197) DEBUG:2025-08-07 11:37:15,191:Before state_dict() in sharding manager memory, memory allocated (GB): 45.21, memory reserved (GB): 45.33, device memory used/total (GB): 2.74/79.15
(WorkerDict pid=87197) DEBUG:2025-08-07 11:37:15,472:After state_dict() in sharding manager memory, memory allocated (GB): 46.06, memory reserved (GB): 47.72, device memory used/total (GB): 5.14/79.15
(WorkerDict pid=87197) DEBUG:2025-08-07 11:37:15,637:After sync model weights in sharding manager, memory allocated (GB): 46.06, memory reserved (GB): 47.72, device memory used/total (GB): 5.92/79.15 <--
(WorkerDict pid=87197) DEBUG:2025-08-07 11:37:15,791:After del state_dict and empty_cache in sharding manager, memory allocated (GB): 45.21, memory reserved (GB): 45.33, device memory used/total (GB): 47.94/79.15
Before:
(WorkerDict pid=104544) DEBUG:2025-08-07 11:41:46,431:Before state_dict() in sharding manager memory, memory allocated (GB): 45.21, memory reserved (GB): 45.33, device memory used/total (GB): 2.74/79.15
(WorkerDict pid=104544) DEBUG:2025-08-07 11:41:46,628:After state_dict() in sharding manager memory, memory allocated (GB): 46.78, memory reserved (GB): 48.76, device memory used/total (GB): 6.17/79.15 <--
(WorkerDict pid=104544) DEBUG:2025-08-07 11:41:46,790:After sync model weights in sharding manager, memory allocated (GB): 46.78, memory reserved (GB): 48.76, device memory used/total (GB): 6.96/79.15
(WorkerDict pid=104544) DEBUG:2025-08-07 11:41:47,073:After del state_dict and empty_cache in sharding manager, memory allocated (GB): 45.21, memory reserved (GB): 45.33, device memory used/total (GB): 47.94/79.15
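For anyone reproducing these numbers, the three figures per line correspond to PyTorch's allocator and driver-level queries. A minimal sketch of what a `log_gpu_memory_usage` helper could look like (an assumption about its implementation, not verl's actual code):

```python
import logging
import torch

def log_gpu_memory_usage(head: str, logger: logging.Logger | None = None) -> None:
    gib = 2**30
    allocated = torch.cuda.memory_allocated() / gib  # tensors currently allocated
    reserved = torch.cuda.memory_reserved() / gib    # allocator's cached pool
    free, total = torch.cuda.mem_get_info()          # driver-level view, in bytes
    message = (
        f"{head}, memory allocated (GB): {allocated:.2f}, "
        f"memory reserved (GB): {reserved:.2f}, "
        f"device memory used/total (GB): {(total - free) / gib:.2f}/{total / gib:.2f}"
    )
    (logger or logging.getLogger(__name__)).info(message)
```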