verl/workers/sharding_manager/megatron_sglang.py (6 additions, 1 deletion)
```diff
@@ -28,7 +28,7 @@
 from torch.distributed.device_mesh import DeviceMesh
 
 from verl.protocol import DataProto, all_gather_data_proto
-from verl.utils.device import get_torch_device
+from verl.utils.device import get_torch_device, set_expandable_segments
 from verl.utils.megatron_utils import (
     load_megatron_model_to_gpu,
     offload_megatron_model_to_cpu,
@@ -178,6 +178,9 @@ async def wake_up(self):
             self.transformer_config,
             self.layer_name_mapping,
         )
+
+        set_expandable_segments(False)
+
         await self.update_weights(per_tensor_param)
         if self.offload_param:
             offload_megatron_model_to_cpu(self.actor_module)
```

Contributor review comment on the added `set_expandable_segments(False)` call (severity: high):

> `set_expandable_segments` uses `torch.cuda.memory._set_allocator_settings`, an internal PyTorch API. Relying on private APIs is risky: they are not guaranteed to be stable and can be changed or removed without notice in future PyTorch releases, which could lead to unexpected failures. To make this more robust, `set_expandable_segments` itself should be modified to include error handling, such as a `try`/`except AttributeError` block, to gracefully handle cases where the internal API is no longer available.
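
A minimal sketch of the defensive wrapper the reviewer is asking for, assuming `set_expandable_segments` lives in `verl/utils/device.py` and passes a `PYTORCH_CUDA_ALLOC_CONF`-style string to the private allocator hook; the exact current implementation is not shown in this diff:

```python
# Sketch of the hardened helper the reviewer suggests; illustrative only,
# not the actual verl implementation.
import logging

import torch

logger = logging.getLogger(__name__)


def set_expandable_segments(enable: bool) -> None:
    """Toggle the CUDA caching allocator's expandable_segments option.

    Wraps torch.cuda.memory._set_allocator_settings, a private PyTorch
    API, so any failure is logged as a warning instead of raised.
    """
    if not torch.cuda.is_available():
        return
    try:
        # Assumes the hook parses the same "key:value" syntax as
        # PYTORCH_CUDA_ALLOC_CONF, e.g. "expandable_segments:True".
        torch.cuda.memory._set_allocator_settings(f"expandable_segments:{enable}")
    except (AttributeError, RuntimeError) as exc:
        logger.warning("Could not set expandable_segments=%s: %s", enable, exc)
```

Catching `RuntimeError` alongside `AttributeError` also covers PyTorch versions where the hook still exists but rejects the setting at runtime.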
```diff
@@ -199,6 +202,8 @@ async def sleep(self):
         # add empty cache after each compute
         aggressive_empty_cache(force_sync=True)
 
+        set_expandable_segments(True)
+
         # restore random states
         if self.device_mesh is not None:
             self.gen_random_states = get_torch_device().get_rng_state()
```
verl/workers/sharding_manager/megatron_vllm.py (5 additions, 1 deletion)
```diff
@@ -30,7 +30,7 @@
 from verl.protocol import all_gather_data_proto
 from verl.third_party.vllm import LLM, VLLM_SLEEP_LEVEL
 from verl.third_party.vllm import parallel_state as vllm_ps
-from verl.utils.device import get_torch_device
+from verl.utils.device import get_torch_device, set_expandable_segments
 from verl.utils.megatron_utils import load_megatron_model_to_gpu, offload_megatron_model_to_cpu, per_tensor_generator
 from verl.utils.memory_utils import aggressive_empty_cache
 from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage
@@ -149,6 +149,8 @@ def __enter__(self):
         if self.offload_param:
             load_megatron_model_to_gpu(self.actor_module, load_grad=False)
 
+        set_expandable_segments(False)
+
         if self.rollout_config.free_cache_engine:
             if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
                 self.inference_engine.wake_up(tags=["weights"])
```

Contributor review comment on the added `set_expandable_segments(False)` call (severity: high):

> This call to `set_expandable_segments` depends on an internal PyTorch function (`_set_allocator_settings`), creating a dependency on an unstable, private API that may break in future PyTorch versions. While the call addresses memory fragmentation, it introduces a maintainability risk. It would be safer for the `set_expandable_segments` implementation to catch a potential `AttributeError` and log a warning if the API is not found; this would prevent the application from crashing due to changes in PyTorch internals.
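
Since the same disable-on-enter / enable-on-exit pairing now appears in both sharding managers, a small context manager could keep the two calls from drifting apart. A hypothetical sketch; `expandable_segments_disabled` is an illustrative name, not a helper that exists in verl:

```python
# Hypothetical convenience wrapper around the enable/disable pairing;
# shown only to illustrate the pattern this diff implements.
from contextlib import contextmanager

from verl.utils.device import set_expandable_segments


@contextmanager
def expandable_segments_disabled():
    """Disable expandable segments for the enclosed block, then re-enable.

    Mirrors the sharding managers' pairing: rollout engines run with
    expandable segments off, and training resumes with them back on.
    """
    set_expandable_segments(False)
    try:
        yield
    finally:
        set_expandable_segments(True)
```

This fits callers that wrap the whole rollout phase in a single `with` block; inside the existing `__enter__`/`__exit__` split it would need something like `contextlib.ExitStack` instead.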
```diff
@@ -196,6 +198,8 @@ def __exit__(self, exc_type, exc_value, traceback):
 
         aggressive_empty_cache(force_sync=True)
 
+        set_expandable_segments(True)
+
         # restore random states
         if self.device_mesh is not None:
             self.gen_random_states = get_torch_device().get_rng_state()
```