Commit cd734a4

[rollout] fix: sglang async fail with Multi-stage Awake feature (#2365)
### What does this PR do?

Fix a regression from verl-project/verl#1911: that PR did not update the sglang async branch. CI did not catch the error because it runs only 1 step, while this error occurs on the second step, so the test cases are updated to run 2 steps.

To reproduce the bug, run the test:

```
TOTAL_TRAIN_STEPS=2 ENGINE=sglang ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh
```

It fails with:

```
(WorkerDict pid=1257286) Total steps: 2, num_warmup_steps: 0
(WorkerDict pid=1257286) Actor use_remove_padding=True
(WorkerDict pid=1257286) Actor use_fused_kernels=False
(AsyncSglangServer pid=1260392) FastAPI listen on 192.168.111.48:40451
(WorkerDict pid=1257286) terminate called after throwing an instance of 'c10::Error'
(WorkerDict pid=1257286)   what():  CUDA error: an illegal memory access was encountered
(WorkerDict pid=1257286) CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
(WorkerDict pid=1257286) For debugging consider passing CUDA_LAUNCH_BLOCKING=1
(WorkerDict pid=1257286) Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
(WorkerDict pid=1257286)
(WorkerDict pid=1257286) Exception raised from c10_cuda_check_implementation at /pytorch/c10/cuda/CUDAException.cpp:43 (most recent call first):
(WorkerDict pid=1257286) frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7fbf6036c1b6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
(WorkerDict pid=1257286) frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7fbf60315a76 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
(WorkerDict pid=1257286) frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7fbf6080d918 in
```

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: https://github.com/volcengine/verl/issues?q=is%3Aissue%20state%3Aopen%20an%20illegal%20memory%20access%20was%20encountered
- [x] Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

```
(TaskRunner pid=1647269) step:2 - global_seqlen/min:13075 - global_seqlen/max:14837 - global_seqlen/minmax_diff:1762 - global_seqlen/balanced_min:14231 - global_seqlen/balanced_max:14232 - global_seqlen/mean:14231.5 - actor/entropy:2.0606913566589355 - critic/vf_loss:8.7157882153
```

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### High-Level Design

> Demonstrate the high-level design if this PR is complex.

### Specific Changes

> List the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
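The underlying fix moves the resume/offload logic into the sharding managers' async `wake_up()` / `sleep()` coroutines, so the synchronous context-manager path and the async server path execute the same multi-stage-aware code instead of a stale copy. A minimal, self-contained sketch of that pattern, with `MockEngine` and `ShardingManager` as toy stand-ins rather than verl's real classes:

```python
import asyncio

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

class MockEngine:
    """Toy stand-in for the SGLang engine; records which memory pools were resumed."""
    def __init__(self):
        self.resumed = []

    async def resume_memory_occupation(self, tags=None):
        self.resumed.append(tags or ["weights", "kv_cache"])

    async def release_memory_occupation(self):
        self.resumed.clear()

class ShardingManager:
    """Sync context manager that delegates to async wake_up()/sleep(),
    so an async caller can await the very same coroutines directly."""
    def __init__(self, engine, multi_stage_wake_up=True):
        self.engine = engine
        self.multi_stage_wake_up = multi_stage_wake_up

    async def wake_up(self):
        if self.multi_stage_wake_up:
            # stage 1: weights only; kv_cache comes back in a later stage
            await self.engine.resume_memory_occupation(tags=["weights"])
            await self.engine.resume_memory_occupation(tags=["kv_cache"])
        else:
            await self.engine.resume_memory_occupation()

    async def sleep(self):
        await self.engine.release_memory_occupation()

    def __enter__(self):
        asyncio.get_event_loop().run_until_complete(self.wake_up())
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        asyncio.get_event_loop().run_until_complete(self.sleep())

engine = MockEngine()
with ShardingManager(engine):
    stages = list(engine.resumed)
print(stages)  # [['weights'], ['kv_cache']]
```

Because both entry points funnel through the same coroutine, a feature added to `wake_up()` (such as multi-stage wake-up) cannot silently miss the async branch again.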
1 parent be9c8cb commit cd734a4

File tree

4 files changed (+34, -96 lines)


.github/workflows/e2e_ppo_trainer.yml

Lines changed: 2 additions & 2 deletions
```diff
@@ -292,13 +292,13 @@ jobs:
       - name: Running GSM8K E2E training tests on sglang async
         run: |
           ray stop --force
-          ENGINE=sglang ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh
+          TOTAL_TRAIN_STEPS=2 ENGINE=sglang ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh
       - name: Running GSM8K E2E training tests on vllm async
         run: |
           ray stop --force
           export VLLM_USE_V1=1
           ray start --head
-          ENGINE=vllm ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh
+          TOTAL_TRAIN_STEPS=2 ENGINE=vllm ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh

   e2e_ppo_trainer_sglang_multiturn_with_tool:
     runs-on: [L20x8]
```
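The only workflow change is exporting `TOTAL_TRAIN_STEPS=2` before invoking the test script, so the regression that only appears on step 2 is exercised. Presumably the script falls back to a single step when the variable is unset; a hypothetical sketch of that override pattern (the real `run_function_reward.sh` may read it differently):

```shell
# CI exports the override before invoking the script:
TOTAL_TRAIN_STEPS=2
# inside the script, a default of 1 step would apply when the variable is not set
STEPS="${TOTAL_TRAIN_STEPS:-1}"
echo "trainer.total_training_steps=${STEPS}"
```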

verl/workers/megatron_workers.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -19,7 +19,7 @@
 import logging
 import os
 import time
-from typing import Optional, Union
+from typing import Any, Dict, List, Optional, Union

 import psutil
 import torch
@@ -692,6 +692,11 @@ async def chat_completion(self, json_request):
         ret = await self.rollout.chat_completion(json_request)
         return ret

+    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False)
+    async def generate(self, prompt_ids: List[int], sampling_params: Dict[str, Any], request_id: str) -> List[int]:
+        ret = await self.rollout.generate(prompt_ids, sampling_params, request_id)
+        return ret
+
     @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
     async def wake_up(self):
         if self.config.rollout.free_cache_engine:
```
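The new `generate` method is a thin async passthrough to the rollout backend, which the async server calls per request. A hedged sketch of that delegation, with `Rollout` and `Worker` as hypothetical stand-ins (the real method is additionally wrapped by verl's `@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False)` decorator, which is omitted here):

```python
import asyncio
from typing import Any, Dict, List

class Rollout:
    """Hypothetical stand-in for the SGLang rollout backend."""
    async def generate(self, prompt_ids, sampling_params, request_id):
        # pretend the engine appends one generated token id
        return prompt_ids + [0]

class Worker:
    """Passthrough mirroring the method added in this diff."""
    def __init__(self):
        self.rollout = Rollout()

    async def generate(
        self, prompt_ids: List[int], sampling_params: Dict[str, Any], request_id: str
    ) -> List[int]:
        return await self.rollout.generate(prompt_ids, sampling_params, request_id)

out = asyncio.run(Worker().generate([1, 2, 3], {"temperature": 1.0}, "req-0"))
print(out)  # [1, 2, 3, 0]
```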

verl/workers/sharding_manager/fsdp_sglang.py

Lines changed: 16 additions & 56 deletions
```diff
@@ -32,7 +32,6 @@
 from verl.protocol import all_gather_data_proto
 from verl.utils.device import get_device_id, get_torch_device
 from verl.utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu
-from verl.utils.model import convert_weight_keys
 from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer
 from verl.utils.torch_functional import check_device_is_available

@@ -101,65 +100,13 @@ def __init__(
     def __enter__(self):
         self.timing = {}
         with simple_timer("reshard", self.timing):
-            get_torch_device().empty_cache()
-
             loop = asyncio.get_event_loop()
-
-            if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
-                if self.multi_stage_wake_up:
-                    loop.run_until_complete(self.inference_engine.resume_memory_occupation(tags=["weights"]))
-                    log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger)
-                else:
-                    loop.run_until_complete(self.inference_engine.resume_memory_occupation())
-                    log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger)
-                get_torch_device().empty_cache()
-
-            log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger)
-            if self.offload_param:
-                load_fsdp_model_to_gpu(self.module)
-            params = self.module.state_dict()
-            log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)
-            device = get_device_id()  # used when fsdp2 set cpu_offload_policy
-            params = {
-                k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items()
-            }
-            params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))
-            # Copy, not share memory
-            loop.run_until_complete(self.update_weights(params))
-            log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)
-
-            del params
-            if self.offload_param:
-                offload_fsdp_model_to_cpu(self.module)
-            get_torch_device().empty_cache()
-            log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)
-
-            if self.multi_stage_wake_up and self.rollout_config.free_cache_engine:
-                loop.run_until_complete(self.inference_engine.resume_memory_occupation(tags=["kv_cache"]))
-                log_gpu_memory_usage("After resume SGLang kv_cache in sharding manager", logger=logger)
-
-            # important: need to manually set the random states of each tp to be identical.
-            if self.device_mesh is not None:
-                self.torch_random_states = get_torch_device().get_rng_state()
-                get_torch_device().set_rng_state(self.gen_random_states)
+            loop.run_until_complete(self.wake_up())

     @GPUMemoryLogger(role="FSDPSGLangShardingManager exit", logger=logger)
     def __exit__(self, exc_type, exc_value, traceback):
-        if self.rollout_config.free_cache_engine:
-            log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger)
-            loop = asyncio.get_event_loop()
-            loop.run_until_complete(self.release_memory())
-            log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger)
-
-        self.module.train()
-
-        # add empty cache after each compute
-        get_torch_device().empty_cache()
-
-        # restore random states
-        if self.device_mesh is not None:
-            self.gen_random_states = get_torch_device().get_rng_state()
-            get_torch_device().set_rng_state(self.torch_random_states)
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self.sleep())

     async def update_weights(self, params):
         # Most naive implementation, can optimize a lot if it is bottleneck from sglang Engine weight update
@@ -207,6 +154,15 @@ async def wake_up(self):
         params = {
             k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items()
         }
+
+        if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
+            if self.multi_stage_wake_up:
+                await self.inference_engine.resume_memory_occupation(tags=["weights"])
+                log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger)
+            else:
+                await self.inference_engine.resume_memory_occupation()
+                log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger)
+
         # Copy, not share memory
         await self.update_weights(params)
         log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)
@@ -217,6 +173,10 @@ async def wake_up(self):
         get_torch_device().empty_cache()
         log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)

+        if self.multi_stage_wake_up and self.rollout_config.free_cache_engine:
+            await self.inference_engine.resume_memory_occupation(tags=["kv_cache"])
+            log_gpu_memory_usage("After resume SGLang kv_cache in sharding manager", logger=logger)
+
         # important: need to manually set the random states of each tp to be identical.
         if self.device_mesh is not None:
             self.torch_random_states = get_torch_device().get_rng_state()
```
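Why the multi-stage order matters: weights are resumed first, the FSDP params are gathered and synced, the params are freed, and only then is the kv_cache restored, so the training params and the kv_cache are never resident at the same time. A toy memory-accounting sketch of that effect (the sizes are made up for illustration, not measurements):

```python
# Simulated allocation sizes (arbitrary units, not real numbers)
WEIGHTS, KV_CACHE, TRAIN_PARAMS = 10, 30, 10

def peak(multi_stage: bool) -> int:
    """Track peak simulated memory for the two wake-up orderings."""
    used, peak_used = 0, 0

    def alloc(n):
        nonlocal used, peak_used
        used += n
        peak_used = max(peak_used, used)

    alloc(WEIGHTS)          # resume engine weights
    if not multi_stage:
        alloc(KV_CACHE)     # single-stage: kv_cache comes back immediately
    alloc(TRAIN_PARAMS)     # gather training params for the weight sync
    used -= TRAIN_PARAMS    # del params after update_weights
    if multi_stage:
        alloc(KV_CACHE)     # multi-stage: kv_cache only after params are freed
    return peak_used

print(peak(multi_stage=True), peak(multi_stage=False))  # 40 50
```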

verl/workers/sharding_manager/megatron_sglang.py

Lines changed: 10 additions & 37 deletions
```diff
@@ -114,45 +114,13 @@ def __init__(
     def __enter__(self):
         self.timing = {}
         with simple_timer("reshard", self.timing):
-            if self.offload_param:
-                load_megatron_model_to_gpu(self.actor_module)
-            if self.bridge is not None:
-                per_tensor_param = self.bridge.export_weights(self.actor_module)
-            else:
-                per_tensor_param = per_tensor_generator(
-                    self.actor_module,
-                    self.model_config,
-                    self.weight_converter,
-                    self.transformer_config,
-                    self.layer_name_mapping,
-                )
             loop = asyncio.get_event_loop()
-            loop.run_until_complete(self.update_weights(per_tensor_param))
-            if self.offload_param:
-                offload_megatron_model_to_cpu(self.actor_module)
-            get_torch_device().empty_cache()
-            # important: need to manually set the random states of each tp to be identical.
-            if self.device_mesh is not None:
-                self.torch_random_states = get_torch_device().get_rng_state()
-                get_torch_device().set_rng_state(self.gen_random_states)
+            loop.run_until_complete(self.wake_up())

     @GPUMemoryLogger(role="MegatronSGLangShardingManager exit", logger=logger)
     def __exit__(self, exc_type, exc_value, traceback):
-        if self.rollout_config.free_cache_engine:
-            log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger)
-            loop = asyncio.get_event_loop()
-            loop.run_until_complete(self.release_memory())
-            log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger)
-
-        for model in self.actor_module:
-            model.train()
-        # add empty cache after each compute
-        get_torch_device().empty_cache()
-
-        # restore random states
-        if self.device_mesh is not None:
-            self.gen_random_states = get_torch_device().get_rng_state()
-            get_torch_device().set_rng_state(self.torch_random_states)
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self.sleep())

     async def update_weights(self, params):
         if self.device_mesh["tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
@@ -182,8 +150,10 @@ async def release_memory(self):
         if self.device_mesh["tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
             await self.inference_engine.release_memory_occupation()

-    @GPUMemoryLogger(role="FSDPSGLangShardingManager enter", logger=logger)
+    @GPUMemoryLogger(role="MegatronSGLangShardingManager enter", logger=logger)
     async def wake_up(self):
+        if self.offload_param:
+            load_megatron_model_to_gpu(self.actor_module)
         if self.bridge is not None:
             per_tensor_param = self.bridge.export_weights(self.actor_module)
         else:
@@ -195,12 +165,15 @@ async def wake_up(self):
                 self.layer_name_mapping,
             )
         await self.update_weights(per_tensor_param)
+        if self.offload_param:
+            offload_megatron_model_to_cpu(self.actor_module)
+        get_torch_device().empty_cache()
         # important: need to manually set the random states of each tp to be identical.
         if self.device_mesh is not None:
             self.torch_random_states = get_torch_device().get_rng_state()
             get_torch_device().set_rng_state(self.gen_random_states)

-    @GPUMemoryLogger(role="FSDPSGLangShardingManager exit", logger=logger)
+    @GPUMemoryLogger(role="MegatronSGLangShardingManager exit", logger=logger)
     async def sleep(self):
         if self.rollout_config.free_cache_engine:
             log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger)
```
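Both sharding managers also keep the RNG bookkeeping inside `wake_up()`/`sleep()`: save the trainer's random state and switch to a dedicated generation state on entry, then swap back on exit, so rollout sampling stays identical across tp ranks. A toy analogue using Python's `random` module in place of `get_torch_device().get_rng_state()`:

```python
import random

class RngSwapper:
    """Toy analogue of the gen_random_states / torch_random_states swap."""
    def __init__(self, seed=1234):
        random.seed(seed)
        self.gen_random_states = random.getstate()   # generation stream
        self.torch_random_states = None              # trainer stream

    def wake_up(self):
        # stash the trainer state, switch to the generation stream
        self.torch_random_states = random.getstate()
        random.setstate(self.gen_random_states)

    def sleep(self):
        # stash the generation state, restore the trainer stream
        self.gen_random_states = random.getstate()
        random.setstate(self.torch_random_states)

s = RngSwapper()
s.wake_up()
first = random.random()    # first draw from the generation stream
s.sleep()
s.wake_up()
second = random.random()   # stream resumes exactly where it left off
```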
