Merged
Changes from 5 commits
3 changes: 3 additions & 0 deletions verl/trainer/config/ppo_trainer.yaml
@@ -388,6 +388,9 @@ actor_rollout_ref:
# number of responses (i.e. num sample times). > 1 for grpo
n: 1

+# Whether to wake up the inference engine in multiple stages. (Wake up model weights first, then resume the kv cache.)
+multi_stage_wake_up: false

# Extra inference engine arguments (vllm, sglang).
engine_kwargs:

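For reference, the new flag sits under the rollout section of the trainer config. A minimal sketch of toggling it at load time, assuming the usual OmegaConf-based handling of ppo_trainer.yaml (the override path follows the hunk context actor_rollout_ref.rollout; it is illustrative, not part of this PR):

    # Hypothetical override; assumes ppo_trainer.yaml is loaded via OmegaConf.
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml")
    # Enable the two-stage wake-up: model weights first, then the kv cache.
    cfg.actor_rollout_ref.rollout.multi_stage_wake_up = True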
1 change: 1 addition & 0 deletions verl/workers/fsdp_workers.py
@@ -463,6 +463,7 @@ def _build_rollout(self, trust_remote_code=False):
full_params="hf" in self.config.rollout.load_format,
device_mesh=rollout_device_mesh,
offload_param=self._is_offload_param,
+multi_stage_wake_up=self.config.rollout.multi_stage_wake_up,
)
log_gpu_memory_usage("After building sharding manager", logger=logger)

12 changes: 7 additions & 5 deletions verl/workers/rollout/sglang_rollout/sglang_rollout.py
@@ -139,21 +139,23 @@ def __init__(self, **kwargs):
# default to the dummy load format, which needs to reload weights the first time
self._need_reload = True

-async def release_memory_occupation(self):
+async def release_memory_occupation(self, tags: Optional[list[str]] = None):
"""Release GPU occupation temporarily."""
-obj = ReleaseMemoryOccupationReqInput()
+print(f"release_memory_occupation with tags: {tags}")
+obj = ReleaseMemoryOccupationReqInput(tags=tags)
return await self.tokenizer_manager.release_memory_occupation(obj, None)

-async def resume_memory_occupation(self):
+async def resume_memory_occupation(self, tags: Optional[list[str]] = None):
"""Resume GPU occupation."""

+print(f"resume_memory_occupation with tags: {tags}")
# because __init__ is a sync method, it cannot call the async release_memory_occupation,
# so release_memory_occupation has to be moved from __init__ to here
+# For multi-stage wake-up, we release both weights and kv_cache the first time we resume weights.
if self._need_reload:
await self.release_memory_occupation()
self._need_reload = False

-obj = ResumeMemoryOccupationReqInput()
+obj = ResumeMemoryOccupationReqInput(tags=tags)
return await self.tokenizer_manager.resume_memory_occupation(obj, None)

async def update_weights_from_tensor(
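For context, the tags argument above lets the engine's weights and kv_cache pools be released and resumed independently; passing no tags keeps the old all-at-once behavior. A minimal usage sketch (hypothetical engine handle and driver function, not part of the PR; the tag strings follow the ones used in fsdp_sglang.py below):

    # Hypothetical driver for the tag-based API.
    async def offload_then_wake(engine):
        # Hand the GPU back to training: drop the kv_cache first, then the weights.
        await engine.release_memory_occupation(tags=["kv_cache"])
        await engine.release_memory_occupation(tags=["weights"])
        # Wake up in stages: weights come back first so they can be refreshed,
        # and the kv_cache is resumed only after the weight sync is done.
        await engine.resume_memory_occupation(tags=["weights"])
        await engine.resume_memory_occupation(tags=["kv_cache"])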
25 changes: 20 additions & 5 deletions verl/workers/sharding_manager/fsdp_sglang.py
@@ -59,12 +59,14 @@ def __init__(
full_params: bool = False,
device_mesh: DeviceMesh = None,
offload_param: bool = False,
+multi_stage_wake_up: bool = False,
):
self.module = module
self.inference_engine = inference_engine
self.model_config = model_config
self.device_mesh = device_mesh
self.offload_param = offload_param
+self.multi_stage_wake_up = multi_stage_wake_up

# Full params
self.full_params = full_params
@@ -96,6 +98,15 @@ def __enter__(self):
self.timing = {}
with _timer("reshard", self.timing):
torch.cuda.empty_cache()
+loop = asyncio.get_event_loop()

+if self.device_mesh["infer_tp"].get_local_rank() == 0:
+if self.multi_stage_wake_up:
+loop.run_until_complete(self.inference_engine.resume_memory_occupation(tags=["weights"]))
+log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger)
+else:
+loop.run_until_complete(self.inference_engine.resume_memory_occupation())
+log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger)
log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger)
if self.offload_param:
load_fsdp_model_to_gpu(self.module)
@@ -105,7 +116,6 @@ def __enter__(self):
params = {k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items()}
params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))
# Copy, not share memory
-loop = asyncio.get_event_loop()
loop.run_until_complete(self.update_weights(params))
log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)

@@ -115,6 +125,10 @@ def __enter__(self):
torch.cuda.empty_cache()
log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)

+if self.multi_stage_wake_up:
+loop.run_until_complete(self.inference_engine.resume_memory_occupation(tags=["kv_cache"]))
+log_gpu_memory_usage("After resume SGLang kv_cache in sharding manager", logger=logger)

# important: need to manually set the random states of each tp to be identical.
if self.device_mesh is not None:
self.torch_random_states = torch.cuda.get_rng_state()
@@ -138,9 +152,6 @@ def __exit__(self, exc_type, exc_value, traceback):
torch.cuda.set_rng_state(self.torch_random_states)

async def update_weights(self, params):
-if self.device_mesh["infer_tp"].get_local_rank() == 0:
-await self.inference_engine.resume_memory_occupation()

# Most naive implementation, can optimize a lot if it is bottleneck from sglang Engine weight update
named_tensors = [(k, v) for k, v in params.items()]
load_format = None
@@ -172,7 +183,11 @@ async def update_weights(self, params):

async def release_memory(self):
if self.device_mesh["infer_tp"].get_local_rank() == 0:
-await self.inference_engine.release_memory_occupation()
+if self.multi_stage_wake_up:
+await self.inference_engine.release_memory_occupation(tags=["kv_cache"])
+await self.inference_engine.release_memory_occupation(tags=["weights"])
+else:
+await self.inference_engine.release_memory_occupation()

@GPUMemoryLogger(role="FSDPSGLangShardingManager enter", logger=logger)
async def wake_up(self):
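Putting the __enter__ changes together: the multi-stage path resumes only the weights before the FSDP state_dict is gathered and synced, frees the temporary copies, and resumes the kv_cache last, so the weight copies and the kv cache never have to coexist at peak memory. A condensed sketch of that ordering (hypothetical helper, with the engine and the sync step passed in as callables; not the PR's code):

    # Sketch of the multi-stage wake-up ordering used in __enter__.
    async def enter_rollout(engine, gather_fsdp_params, sync_weights):
        await engine.resume_memory_occupation(tags=["weights"])   # stage 1: weights only
        params = gather_fsdp_params()                              # FSDP state_dict on GPU
        await sync_weights(params)                                 # copy into the SGLang engine
        del params                                                 # free the training-side copies
        await engine.resume_memory_occupation(tags=["kv_cache"])  # stage 2: kv cache last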