[megatron] feat: set_expandable_segments for megatron #3181
Changes from all commits
```diff
@@ -30,7 +30,7 @@
 from verl.protocol import all_gather_data_proto
 from verl.third_party.vllm import LLM, VLLM_SLEEP_LEVEL
 from verl.third_party.vllm import parallel_state as vllm_ps
-from verl.utils.device import get_torch_device
+from verl.utils.device import get_torch_device, set_expandable_segments
 from verl.utils.megatron_utils import load_megatron_model_to_gpu, offload_megatron_model_to_cpu, per_tensor_generator
 from verl.utils.memory_utils import aggressive_empty_cache
 from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage
```
|
|
```diff
@@ -149,6 +149,8 @@ def __enter__(self):
         if self.offload_param:
             load_megatron_model_to_gpu(self.actor_module, load_grad=False)
 
+        set_expandable_segments(False)
+
         if self.rollout_config.free_cache_engine:
             if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
                 self.inference_engine.wake_up(tags=["weights"])
```
|
|
```diff
@@ -196,6 +198,8 @@ def __exit__(self, exc_type, exc_value, traceback):
         aggressive_empty_cache(force_sync=True)
 
+        set_expandable_segments(True)
+
         # restore random states
         if self.device_mesh is not None:
             self.gen_random_states = get_torch_device().get_rng_state()
```
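The two hunks above disable expandable segments for the duration of the rollout (`__enter__`) and re-enable them afterwards (`__exit__`). That enter/exit pairing can be sketched as a small context manager; `set_fn` here is a hypothetical stand-in for verl's `set_expandable_segments` helper, not the PR's actual class:

```python
from contextlib import contextmanager

@contextmanager
def expandable_segments_disabled(set_fn):
    """Disable expandable segments on entry, re-enable on exit.

    set_fn stands in for verl's set_expandable_segments helper; the
    finally block mirrors __exit__ restoring the setting even if the
    body (the rollout) raises.
    """
    set_fn(False)
    try:
        yield
    finally:
        set_fn(True)

# Usage sketch: record the calls to show the ordering.
calls = []
with expandable_segments_disabled(calls.append):
    calls.append("rollout")  # rollout would run here
print(calls)  # [False, 'rollout', True]
```

The `finally` matters: `__exit__` runs on exceptions too, so the allocator setting cannot be left stuck in the disabled state.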
A Contributor left a review comment on the `set_expandable_segments(False)` line:

> The `set_expandable_segments` function uses `torch.cuda.memory._set_allocator_settings`, which is an internal PyTorch API. Relying on such private APIs is risky, as they are not guaranteed to be stable and can be changed or removed without notice in future PyTorch releases. This could lead to unexpected failures. To make this more robust, I recommend that the `set_expandable_segments` function itself be modified to include error handling, such as a `try-except AttributeError` block, to gracefully handle cases where this internal API is no longer available.
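A defensive helper along the lines the reviewer suggests might look like the sketch below. This is an illustration, not verl's actual implementation; the private `torch.cuda.memory._set_allocator_settings` call is the one the comment flags:

```python
import warnings

def set_expandable_segments(enable: bool) -> None:
    """Toggle the CUDA caching allocator's expandable_segments option.

    Wraps the private torch.cuda.memory._set_allocator_settings API in
    try/except so that a missing or renamed internal symbol degrades to
    a warning instead of a crash, as the review comment recommends.
    """
    try:
        import torch
        setter = torch.cuda.memory._set_allocator_settings  # private API
    except (ImportError, AttributeError):
        warnings.warn("expandable_segments toggle unavailable in this torch build")
        return
    try:
        setter(f"expandable_segments:{enable}")
    except (AttributeError, RuntimeError) as exc:
        # e.g. CPU-only build or no CUDA context; degrade gracefully
        warnings.warn(f"could not set expandable_segments: {exc}")
```

With this shape, a future PyTorch release that removes or moves the private setter only produces a warning and the rollout proceeds with default allocator behavior.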