4 changes: 4 additions & 0 deletions examples/grpo_trainer/run_qwen2_5_vl-7b-sglang.sh
@@ -41,6 +41,10 @@ python3 -m verl.trainer.main_ppo \
trainer.logger='["console","wandb"]' \
trainer.project_name='verl_grpo_example_geo3k' \
trainer.experiment_name='qwen2_5_vl_7b_function_rm' \
actor_rollout_ref.ref.strategy=fsdp2 \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
reward_model.strategy=fsdp2 \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=20 \
4 changes: 4 additions & 0 deletions examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh
@@ -55,6 +55,10 @@ python3 -m verl.trainer.main_ppo \
trainer.save_freq=-1 \
trainer.test_freq=20 \
trainer.val_before_train=True \
actor_rollout_ref.ref.strategy=fsdp2 \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
reward_model.strategy=fsdp2 \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \
9 changes: 6 additions & 3 deletions verl/workers/sharding_manager/fsdp_sglang.py
@@ -28,6 +28,7 @@
from verl.protocol import all_gather_data_proto
from verl.utils.device import get_device_id, get_torch_device, set_expandable_segments
from verl.utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu
from verl.utils.memory_utils import aggressive_empty_cache
from verl.utils.model import convert_weight_keys
from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer
from verl.utils.torch_functional import check_device_is_available
@@ -124,7 +125,7 @@ async def release_memory(self):

@GPUMemoryLogger(role="FSDPSGLangShardingManager enter", logger=logger)
async def wake_up(self):
get_torch_device().empty_cache()
aggressive_empty_cache(force_sync=True)
Contributor

high

While using a more aggressive cache-clearing function is a good approach to mitigating OOM errors, the implementation of aggressive_empty_cache in verl/utils/memory_utils.py appears to have a potential flaw in its loop termination logic.

The loop terminates if the amount of reserved memory freed in an iteration is less than 1 GB (reserved_freed < 1024**3). This has two potential drawbacks:

  1. Ineffectiveness: The loop might terminate prematurely if it frees just under 1 GB, even though more could be freed in subsequent iterations. This could leave the OOM error unresolved in some cases.
  2. Performance: If more than 1 GB is freed, the loop continues, potentially for up to max_retries iterations, which could introduce unnecessary latency when a single pass was sufficient.

A more robust approach would be to continue looping as long as a significant amount of memory is being freed, and stop only when an iteration frees little to no memory. This would make the function both more effective at preventing OOMs and more efficient.

Since aggressive_empty_cache is a new core utility for this fix, it would be beneficial to refine its logic to be more reliable before merging.
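
For illustration, here is a minimal sketch of the termination behaviour this comment suggests, assuming the utility wraps the accelerator module returned by get_torch_device() (imported in the diff above). The name aggressive_empty_cache_sketch, the 128 MB threshold, and the retry default are illustrative assumptions, not the actual verl/utils/memory_utils.py implementation:

import gc

from verl.utils.device import get_torch_device

_MIN_FREED_BYTES = 128 * 1024**2  # assumed proxy for "little to no memory"


def aggressive_empty_cache_sketch(force_sync: bool = False, max_retries: int = 3) -> None:
    """Release cached blocks repeatedly while each pass still frees a meaningful amount."""
    device = get_torch_device()  # e.g. torch.cuda
    for _ in range(max_retries):
        reserved_before = device.memory_reserved()
        gc.collect()              # drop dangling Python references first
        device.empty_cache()      # return cached allocator blocks to the driver
        if force_sync:
            device.synchronize()  # wait until pending frees have actually landed
        reserved_freed = reserved_before - device.memory_reserved()
        # Stop only once an iteration frees (almost) nothing, rather than
        # stopping at anything under a fixed 1 GB.
        if reserved_freed < _MIN_FREED_BYTES:
            break

The exact threshold is tunable; the key change is that the loop exits as soon as a pass stops freeing meaningful memory, instead of cutting off at an arbitrary 1 GB.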

Collaborator

Should we also modify megatron_sglang.py?

Collaborator Author

Let me do it


log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger)
if self.offload_param:
@@ -146,6 +147,8 @@ async def wake_up(self):

# sglang need to set _set_allocator_settings to False
logger.debug("fsdp sglang sharding_manager _set_allocator_settings to False")
# Note(chenyang): SGLang is using torch memory pool to manage memory
# which is incompatible with expandable segments
set_expandable_segments(False)

if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
@@ -161,7 +164,7 @@
log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)

del params
get_torch_device().empty_cache()
aggressive_empty_cache(force_sync=True)
log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)

if (
@@ -187,7 +190,7 @@ async def sleep(self):
self.module.train()

# add empty cache after each compute
get_torch_device().empty_cache()
aggressive_empty_cache(force_sync=True)

# always set _set_allocator_settings to True when using sglang
# it is required by fsdp2 to avoid oom
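
As background for the set_expandable_segments(False) / True toggling in this file: a hypothetical sketch of what such a helper might do, assuming it relies on PyTorch's private torch.cuda.memory._set_allocator_settings API (the same _set_allocator_settings named in the comments above). The actual helper lives in verl/utils/device.py and may differ:

import torch


def set_expandable_segments_sketch(enable: bool) -> None:
    """Toggle the CUDA caching allocator's expandable_segments option at runtime (sketch)."""
    if torch.cuda.is_available():
        # Accepts the same key:value syntax as PYTORCH_CUDA_ALLOC_CONF.
        # SGLang manages memory through a torch memory pool that is incompatible
        # with expandable segments, hence False around wake_up and True again in
        # sleep, where FSDP2 needs it to avoid OOM.
        torch.cuda.memory._set_allocator_settings(f"expandable_segments:{enable}")
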
7 changes: 5 additions & 2 deletions verl/workers/sharding_manager/megatron_sglang.py
@@ -34,6 +34,7 @@
offload_megatron_model_to_cpu,
per_tensor_generator,
)
from verl.utils.memory_utils import aggressive_empty_cache
from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer
from verl.workers.rollout.sglang_rollout.utils import get_named_tensor_buckets

@@ -163,6 +164,8 @@ async def release_memory(self):

@GPUMemoryLogger(role="MegatronSGLangShardingManager enter", logger=logger)
async def wake_up(self):
aggressive_empty_cache(force_sync=True)

if self.offload_param:
load_megatron_model_to_gpu(self.actor_module, load_grad=False)
if self.bridge is not None:
@@ -178,7 +181,7 @@ async def wake_up(self):
await self.update_weights(per_tensor_param)
if self.offload_param:
offload_megatron_model_to_cpu(self.actor_module)
get_torch_device().empty_cache()
aggressive_empty_cache(force_sync=True)
# important: need to manually set the random states of each tp to be identical.
if self.device_mesh is not None:
self.torch_random_states = get_torch_device().get_rng_state()
@@ -194,7 +197,7 @@ async def sleep(self):
for model in self.actor_module:
model.train()
# add empty cache after each compute
get_torch_device().empty_cache()
aggressive_empty_cache(force_sync=True)

# restore random states
if self.device_mesh is not None: