Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions verl/trainer/config/ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,9 @@ actor_rollout_ref:
# number of responses (i.e. num sample times). > 1 for grpo
n: 1

# Whether to wake up the inference engine in multiple stages:
# wake up the model weights first, then resume the KV cache.
multi_stage_wake_up: false

# Extra inference engine arguments (vllm, sglang).
engine_kwargs:

Expand Down
2 changes: 1 addition & 1 deletion verl/trainer/ppo/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,7 @@ def fit(self):

# perform validation before training
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
val_metrics = self._validate()
assert val_metrics, f"{val_metrics=}"
pprint(f"Initial validation metrics: {val_metrics}")
Expand Down
1 change: 1 addition & 0 deletions verl/workers/fsdp_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ def _build_rollout(self, trust_remote_code=False):
full_params="hf" in self.config.rollout.load_format,
device_mesh=rollout_device_mesh,
offload_param=self._is_offload_param,
multi_stage_wake_up=self.config.rollout.multi_stage_wake_up,
)
log_gpu_memory_usage("After building sharding manager", logger=logger)

Expand Down
12 changes: 7 additions & 5 deletions verl/workers/rollout/sglang_rollout/sglang_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def _set_envs_and_config(server_args: ServerArgs):
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
os.environ["CUDA_MODULE_LOADING"] = "AUTO"
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 forces every CUDA kernel launch to be
# synchronous — useful only for debugging, severely degrades throughput.
# This looks like a debug leftover; remove before merge.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Set prometheus env vars
if server_args.enable_metrics:
Expand Down Expand Up @@ -139,21 +140,22 @@ def __init__(self, **kwargs):
# default to use dummy load format, which need to reload weights in first time
self._need_reload = True

async def release_memory_occupation(self):
async def release_memory_occupation(self, tags: Optional[List[str]] = None):
    """Temporarily release GPU memory held by the inference engine.

    Args:
        tags: Optional list of memory groups to release (e.g. ``["weights"]``
            or ``["kv_cache"]``). ``None`` releases everything the engine
            supports, preserving the previous no-argument behavior.

    Returns:
        The awaited result of the tokenizer manager's release request.
    """
    # Debug print removed — this runs on every rollout/train switch.
    req = ReleaseMemoryOccupationReqInput(tags=tags)
    return await self.tokenizer_manager.release_memory_occupation(req, None)

async def resume_memory_occupation(self):
async def resume_memory_occupation(self, tags: Optional[List[str]] = None):
    """Resume GPU memory occupation for the inference engine.

    Args:
        tags: Optional list of memory groups to resume (e.g. ``["weights"]``
            or ``["kv_cache"]``). ``None`` resumes everything the engine
            supports, preserving the previous no-argument behavior.

    Returns:
        The awaited result of the tokenizer manager's resume request.
    """
    # __init__ is a sync method and cannot await the async
    # release_memory_occupation, so the one-time initial release (needed
    # because the engine loads with a dummy format and must reload weights
    # on first use) is deferred to the first resume call.
    if self._need_reload:
        await self.release_memory_occupation()
        self._need_reload = False

    # Debug print removed — this runs on every rollout/train switch.
    req = ResumeMemoryOccupationReqInput(tags=tags)
    return await self.tokenizer_manager.resume_memory_occupation(req, None)

async def update_weights_from_tensor(
Expand Down
21 changes: 19 additions & 2 deletions verl/workers/sharding_manager/fsdp_sglang.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,14 @@ def __init__(
full_params: bool = False,
device_mesh: DeviceMesh = None,
offload_param: bool = False,
multi_stage_wake_up: bool = False,
):
self.module = module
self.inference_engine = inference_engine
self.model_config = model_config
self.device_mesh = device_mesh
self.offload_param = offload_param
self.multi_stage_wake_up = multi_stage_wake_up

# Full params
self.full_params = full_params
Expand Down Expand Up @@ -115,6 +117,9 @@ def __enter__(self):
torch.cuda.empty_cache()
log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)

if self.multi_stage_wake_up:
loop.run_until_complete(self.inference_engine.resume_memory_occupation(tags=["kv_cache"]))

# important: need to manually set the random states of each tp to be identical.
if self.device_mesh is not None:
self.torch_random_states = torch.cuda.get_rng_state()
Expand All @@ -138,8 +143,15 @@ def __exit__(self, exc_type, exc_value, traceback):
torch.cuda.set_rng_state(self.torch_random_states)

async def update_weights(self, params):
log_gpu_memory_usage("Before resume_memory_occupation in update_weights", logger=logger)
if self.device_mesh["infer_tp"].get_local_rank() == 0:
await self.inference_engine.resume_memory_occupation()
# await self.inference_engine.resume_memory_occupation()
if self.multi_stage_wake_up:
await self.inference_engine.resume_memory_occupation(tags=["weights"])
else:
await self.inference_engine.resume_memory_occupation()

log_gpu_memory_usage("After resume_memory_occupation in update_weights", logger=logger)

# Most naive implementation, can optimize a lot if it is bottleneck from sglang Engine weight update
named_tensors = [(k, v) for k, v in params.items()]
Expand Down Expand Up @@ -172,7 +184,12 @@ async def update_weights(self, params):

async def release_memory(self):
    """Release the rollout engine's GPU memory (TP local rank 0 only).

    When ``multi_stage_wake_up`` is enabled, the KV cache and the model
    weights are released as separate tagged groups so they can later be
    resumed independently (weights first for the update, KV cache after);
    otherwise a single untagged call releases everything at once.
    """
    if self.device_mesh["infer_tp"].get_local_rank() == 0:
        if self.multi_stage_wake_up:
            # Debug print removed; release KV cache, then weights.
            await self.inference_engine.release_memory_occupation(tags=["kv_cache"])
            await self.inference_engine.release_memory_occupation(tags=["weights"])
        else:
            await self.inference_engine.release_memory_occupation()

@GPUMemoryLogger(role="FSDPSGLangShardingManager enter", logger=logger)
async def wake_up(self):
Expand Down