Skip to content

Commit fa02416

Browse files
[rollout] feat: Support Multi-stage Awake for SGLang (verl-project#1911)
Co-authored with: MrAta (immrata@gmail.com) ### Checklist Before Starting - [x] Search for similar PR(s). ### What does this PR do? ### Motivation In RL Ecosystem which use colocate design like [verl](https://github.com/volcengine/verl/tree/main), we need to offload training model and load serving model & KV Cache frequently. #### Background - Currently SGLang is using [torch_memory_saver](https://github.com/fzyzcjy/torch_memory_saver) to pause and resume. - [torch_memory_saver](https://github.com/fzyzcjy/torch_memory_saver) is a open source repo that provided easy to use api to hack **cudaMalloc** and **cudaFree** to make sure the virtual address could be consistent after pause and resume, which is critical to ensure CUDA Graph work. - CUDA Graph is critical to make sure SGLang runs faster in decoding phases. #### Here is the current behavior of VERL + SGLang ![Image](https://github.com/user-attachments/assets/e87e7dd6-f223-4de6-8f07-915eb2030ea8) 1. During Training, we have training model and optimizer state in the GPU Memory, and once training is done, we will offload optimizer state to cpu and keep the model weights in GPU, which is needed in Update Weight. 2. During Update Weight, we awake the SGLang engine, so those paused memory of Model Weights and KV Cache will come back. Then we update model from training model to serving model on the fly using the api: `update_weights_in_tensor` 3. After Model being updated, we delete the training model from GPU Memory. Above design works pretty well so far, however, this would waste a big chunk of GPU Memory during rollout, which could cause a few issues we've seen so far: - **Small KV Cache**: We need to use relative lower number of mem fraction ratio (e.g: 0.6), hence our KV Cache has less tokens. Given KV Cache has less tokens, we will hit `RuntimeError: Prefill out of memory. Try to lower your batch size.` when we try prefill large number of requests. - **Out of Memory**: If we use mem fraction ratio 0.8 and run RL for 32B model on 8 H100, it will OOM during update weight #### Challenge - `torch_memory_saver` currently only supports Singleton, hence SGLang will pause and resume KV Cache + Weights together, they are treated as the same group of memory controlled by the singleton `torch_memory_saver` instance #### Proposal ![Image](https://github.com/user-attachments/assets/7fda9638-0dc2-4c14-bc64-cd20616f350f) 1. During Training, we do the same 2. During Update Weight Stage 1, we awake the model weights from SGLang and then update weights 3. During Update Weight Stage 2, we delete the training model weights from GPU Memory 4. Awake the SGLang's KV Cache ![Image](https://github.com/user-attachments/assets/f3dab327-dc2e-4ed8-88d7-15e383f77d25) ### Benefit With above feature, we can train larger model with same GPU, we can also make training/rollout more efficient given we can allocate larger KV Cache ### Solution: Keep using Singleton and provide tag based pause/resume - [x] Support tag based resume/pause: fzyzcjy/torch_memory_saver#20 - [x] Support Multiple Stage Awake in SGLang: sgl-project/sglang#7099 - [ ] Support Multiple Stage Awake in verl: verl-project#1911 ### High-Level Design > Demonstrate the high-level design if this PR is complex. ### Specific Changes > List the specific changes. ### API > Demonstrate how the API changes if any. ### Usage Example > Provide usage example(s) for easier usage. ```python # Add code snippet or script demonstrating how to use this ``` ### Test ![Screenshot 2025-06-19 at 12 16 19 PM](https://github.com/user-attachments/assets/a95dd57e-43e1-4f28-8a84-003ec5c043fc) ![Screenshot 2025-06-19 at 12 13 14 PM](https://github.com/user-attachments/assets/f1f4a8a8-1845-4fad-9424-5526d4154dd0) ### Additional Info. - **Issue Number**: Fixes issue # or discussion # if any. - **Training**: [Note which backend this PR will affect: FSDP, Megatron, both, or none] - **Inference**: [Note which backend this PR will affect: vLLM, SGLang, both, or none] ### Checklist Before Submitting - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide). - [ ] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting). - [ ] Add `[BREAKING]` to the PR title if it breaks any API. - [ ] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs). - [ ] New CI unit test(s) are added to cover the code path. - [ ] Rely on existing unit tests on CI that covers the code path. --------- Co-authored-by: Chayenne <zhaochen20@outlook.com>
1 parent d084f2d commit fa02416

4 files changed

Lines changed: 31 additions & 9 deletions

File tree

verl/trainer/config/ppo_trainer.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,9 @@ actor_rollout_ref:
465465
# number of responses (i.e. num sample times). > 1 for grpo
466466
n: 1
467467

468+
# Whether to wake up inference engine in multi-stage. (Wake up model weights first, then resume kv cache)
469+
multi_stage_wake_up: false
470+
468471
# Extra inference engine arguments (vllm, sglang).
469472
engine_kwargs:
470473

verl/workers/fsdp_workers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,7 @@ def _build_rollout(self, trust_remote_code=False):
484484
full_params="hf" in self.config.rollout.load_format,
485485
device_mesh=rollout_device_mesh,
486486
offload_param=self._is_offload_param,
487+
multi_stage_wake_up=self.config.rollout.multi_stage_wake_up,
487488
)
488489
log_gpu_memory_usage("After building sharding manager", logger=logger)
489490

verl/workers/rollout/sglang_rollout/sglang_rollout.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,21 +132,27 @@ def __init__(self, **kwargs):
132132
# default to use dummy load format, which need to reload weights in first time
133133
self._need_reload = True
134134

135-
async def release_memory_occupation(self):
135+
async def release_memory_occupation(self, tags: Optional[list[str]] = None):
136136
"""Release GPU occupation temporarily."""
137-
obj = ReleaseMemoryOccupationReqInput()
137+
if tags is None:
138+
obj = ReleaseMemoryOccupationReqInput()
139+
else:
140+
obj = ReleaseMemoryOccupationReqInput(tags=tags)
138141
return await self.tokenizer_manager.release_memory_occupation(obj, None)
139142

140-
async def resume_memory_occupation(self):
143+
async def resume_memory_occupation(self, tags: Optional[list[str]] = None):
141144
"""Resume GPU occupation."""
142-
143145
# because __init__ is a sync method, it can not call the async release_memory_occupation
144146
# have to move release_memory_occupation from __init__ to here
147+
# For multi-stage awake, we run release weight and kv_cache when we resume weights for the first time.
145148
if self._need_reload:
146149
await self.release_memory_occupation()
147150
self._need_reload = False
148151

149-
obj = ResumeMemoryOccupationReqInput()
152+
if tags is None:
153+
obj = ResumeMemoryOccupationReqInput()
154+
else:
155+
obj = ResumeMemoryOccupationReqInput(tags=tags)
150156
return await self.tokenizer_manager.resume_memory_occupation(obj, None)
151157

152158
async def update_weights_from_tensor(

verl/workers/sharding_manager/fsdp_sglang.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,14 @@ def __init__(
5959
full_params: bool = False,
6060
device_mesh: DeviceMesh = None,
6161
offload_param: bool = False,
62+
multi_stage_wake_up: bool = False,
6263
):
6364
self.module = module
6465
self.inference_engine = inference_engine
6566
self.model_config = model_config
6667
self.device_mesh = device_mesh
6768
self.offload_param = offload_param
69+
self.multi_stage_wake_up = multi_stage_wake_up
6870

6971
# Full params
7072
self.full_params = full_params
@@ -95,7 +97,17 @@ def __init__(
9597
def __enter__(self):
9698
self.timing = {}
9799
with simple_timer("reshard", self.timing):
100+
loop = asyncio.get_event_loop()
101+
102+
if self.device_mesh["infer_tp"].get_local_rank() == 0:
103+
if self.multi_stage_wake_up:
104+
loop.run_until_complete(self.inference_engine.resume_memory_occupation(tags=["weights"]))
105+
log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger)
106+
else:
107+
loop.run_until_complete(self.inference_engine.resume_memory_occupation())
108+
log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger)
98109
get_torch_device().empty_cache()
110+
99111
log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger)
100112
if self.offload_param:
101113
load_fsdp_model_to_gpu(self.module)
@@ -105,7 +117,6 @@ def __enter__(self):
105117
params = {k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items()}
106118
params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))
107119
# Copy, not share memory
108-
loop = asyncio.get_event_loop()
109120
loop.run_until_complete(self.update_weights(params))
110121
log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)
111122

@@ -115,6 +126,10 @@ def __enter__(self):
115126
get_torch_device().empty_cache()
116127
log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)
117128

129+
if self.multi_stage_wake_up:
130+
loop.run_until_complete(self.inference_engine.resume_memory_occupation(tags=["kv_cache"]))
131+
log_gpu_memory_usage("After resume SGLang kv_cache in sharding manager", logger=logger)
132+
118133
# important: need to manually set the random states of each tp to be identical.
119134
if self.device_mesh is not None:
120135
self.torch_random_states = get_torch_device().get_rng_state()
@@ -138,9 +153,6 @@ def __exit__(self, exc_type, exc_value, traceback):
138153
get_torch_device().set_rng_state(self.torch_random_states)
139154

140155
async def update_weights(self, params):
141-
if self.device_mesh["infer_tp"].get_local_rank() == 0:
142-
await self.inference_engine.resume_memory_occupation()
143-
144156
# Most naive implementation, can optimize a lot if it is bottleneck from sglang Engine weight update
145157
named_tensors = [(k, v) for k, v in params.items()]
146158
load_format = None

0 commit comments

Comments
 (0)