[megatron] feat: a bunch of optimization on vram, sequence packing #2678
@@ -0,0 +1,137 @@

```python
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import logging

import torch

from verl.utils.device import get_torch_device

logger = logging.getLogger(__name__)


def aggressive_empty_cache(force_sync: bool = True, max_retries: int = 3) -> None:
    """
    More aggressive GPU memory cleanup function, tries to release PyTorch reserved but unallocated memory.

    Args:
        force_sync: Whether to force CUDA synchronization
        max_retries: Maximum number of retries
    """
    if not torch.cuda.is_available():
        return

    device = get_torch_device()

    for attempt in range(max_retries):
        # Record memory status before cleanup
        before_reserved = device.memory_reserved(0)
        before_allocated = device.memory_allocated(0)

        # Run garbage collection
        gc.collect()

        # Clear PyTorch cache
        device.empty_cache()

        # Force synchronization (optional)
        if force_sync:
            device.synchronize()

        # Record memory status after cleanup
        after_reserved = device.memory_reserved(0)
        after_allocated = device.memory_allocated(0)
```
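For orientation, here is a minimal usage sketch of the helper added above. The import path is hypothetical, since the new file's location in the repository is not visible in this excerpt.

```python
# Hypothetical import path -- the new file's location is not shown in this excerpt.
from verl.utils.memory_utils import aggressive_empty_cache

# Between memory-heavy phases (e.g. after rollout generation and before the
# training step), try to return reserved-but-unallocated VRAM to the driver
# so the next phase starts from a less fragmented allocator state.
aggressive_empty_cache(force_sync=True, max_retries=3)
```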
Suggested change:

```diff
-        after_reserved = device.memory_reserved(0)
-        after_allocated = device.memory_allocated(0)
+        after_reserved = device.memory_reserved()
+        after_allocated = device.memory_allocated()
```
This function hardcodes the device index 0 for all memory statistics. This will report incorrect information on multi-GPU systems where the current process is not on device 0.
Most `torch.cuda` memory functions default to the current device if no index is provided. For `get_device_properties`, you need to explicitly pass the current device index, which you can get via `device.current_device()`.
```python
device = get_torch_device()
device_id = device.current_device()
return {
    "total_memory_gb": device.get_device_properties(device_id).total_memory / 1024**3,
    "reserved_memory_gb": device.memory_reserved() / 1024**3,
    "allocated_memory_gb": device.memory_allocated() / 1024**3,
    "cached_memory_gb": (device.memory_reserved() - device.memory_allocated()) / 1024**3,
    "max_memory_allocated_gb": device.max_memory_allocated() / 1024**3,
    "max_memory_reserved_gb": device.max_memory_reserved() / 1024**3,
}
```
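For completeness, a self-contained sketch that wraps the snippet above into a full helper; the name `get_memory_info` is an assumption for illustration, not code from this PR.

```python
import logging

from verl.utils.device import get_torch_device

logger = logging.getLogger(__name__)


def get_memory_info() -> dict:
    """Memory statistics for the current device (hypothetical helper name)."""
    device = get_torch_device()
    # get_device_properties needs an explicit index; the other queries default
    # to the current device when no index is passed.
    device_id = device.current_device()
    return {
        "total_memory_gb": device.get_device_properties(device_id).total_memory / 1024**3,
        "reserved_memory_gb": device.memory_reserved() / 1024**3,
        "allocated_memory_gb": device.memory_allocated() / 1024**3,
        "cached_memory_gb": (device.memory_reserved() - device.memory_allocated()) / 1024**3,
        "max_memory_allocated_gb": device.max_memory_allocated() / 1024**3,
        "max_memory_reserved_gb": device.max_memory_reserved() / 1024**3,
    }
```

A caller could log this dictionary before and after `aggressive_empty_cache()` to verify how much reserved memory each cleanup pass actually returns.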
Hardcoding device index `0` for memory statistics can lead to incorrect behavior in multi-GPU environments where the current device is not GPU 0. The memory functions should operate on the current device. The `device.memory_reserved()` and `device.memory_allocated()` functions default to the current device when no device index is provided. Please remove the hardcoded `0`.
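Applied to the retry loop from the diff, the reviewers' suggestion would look roughly like the sketch below; it reuses `device`, `force_sync`, `max_retries`, `gc`, and `logger` from the file above, and the debug logging line is illustrative rather than part of the PR.

```python
for attempt in range(max_retries):
    # No index argument: torch.cuda-style memory queries default to the current
    # device, which is the correct behavior on multi-GPU nodes.
    before_reserved = device.memory_reserved()

    gc.collect()              # drop unreachable Python objects first
    device.empty_cache()      # then release cached allocator blocks back to the driver
    if force_sync:
        device.synchronize()

    after_reserved = device.memory_reserved()
    freed_gb = (before_reserved - after_reserved) / 1024**3
    logger.debug("cleanup attempt %d released %.2f GiB of reserved memory", attempt, freed_gb)
```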