54 changes: 37 additions & 17 deletions verl/models/mcore/util.py
@@ -39,32 +39,47 @@ def preprocess_packed_seqs(

pad_size = (align_size - seqlens_in_batch % align_size) % align_size
seqlens_in_batch_padded = seqlens_in_batch + pad_size

cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)
cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0)
cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)
cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0)
max_seqlen_in_batch = seqlens_in_batch_padded.max().item()

# ----------------------------------------------------------------------------
# Move the index information needed in the subsequent loop to the CPU at once,
# to avoid frequent .item() calls in the loop that cause D2H synchronization
# ----------------------------------------------------------------------------
seqlens_in_batch_cpu: list[int] = seqlens_in_batch.tolist() # original valid lengths
seqlens_in_batch_padded_cpu: list[int] = seqlens_in_batch_padded.tolist() # lengths after padding
cu_seqlens_padded_cpu: list[int] = cu_seqlens_padded.tolist() # start positions (after padding)

# Pure Python int calculation to avoid further synchronization
max_seqlen_in_batch = max(seqlens_in_batch_padded_cpu)

shape = list(input_ids.shape[1:])
shape[0] = seqlens_in_batch_padded.sum().item() // cp_size
shape[0] = sum(seqlens_in_batch_padded_cpu) // cp_size
if pre_process:
input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device)
for i in range(batch_size):
# Use Python int, so no GPU→CPU sync in the loop
if cp_size <= 1:
seqlen = seqlens_in_batch[i]
input_ids_rmpad[cu_seqlens_padded[i] : cu_seqlens_padded[i] + seqlen] = input_ids[i, attention_mask[i]]
seqlen = seqlens_in_batch_cpu[i]
start_idx = cu_seqlens_padded_cpu[i]
input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[i, attention_mask[i]]
continue
seqlen = seqlens_in_batch_padded[i] // cp_size

seqlen_padded_i = seqlens_in_batch_padded_cpu[i]
seqlen = seqlen_padded_i // cp_size
half_seqlen = seqlen // 2
start_idx = cu_seqlens_padded[i] // cp_size
start_idx = cu_seqlens_padded_cpu[i] // cp_size
# split to 2 chunks
d = input_ids[i, attention_mask[i]]
input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[
half_seqlen * cp_rank : half_seqlen * (cp_rank + 1)
]

remain_start = seqlens_in_batch_padded[i] - half_seqlen * (cp_rank + 1)
remain_end = seqlens_in_batch_padded[i] - half_seqlen * cp_rank
remain_start = seqlen_padded_i - half_seqlen * (cp_rank + 1)
remain_end = seqlen_padded_i - half_seqlen * cp_rank
remain_end = min(remain_end, d.shape[0])
remain_len = remain_end - remain_start
if remain_len > 0:
@@ -100,6 +115,14 @@ def postprocess_packed_seqs(
"""
if not post_process:
return output

# -------------------------------------------------------------------------
# Move the lengths and offsets needed for subsequent Python-level indexing to the CPU in advance,
# to avoid a large number of .item() calls in the loop
# -------------------------------------------------------------------------
cu_padded_cpu: list[int] = packed_seq_params.cu_seqlens_q_padded.tolist()
seq_lens_cpu: list[int] = attention_mask.sum(dim=1, dtype=torch.int32).cpu().tolist()

shape = [batch_size, seq_len] + list(output.shape[2:]) # 1,packed, dim -> batch_size, seq_len, dim
output_new = torch.zeros(shape, dtype=output.dtype, device=output.device)

@@ -115,22 +138,19 @@
output_list = [output]
for i in range(batch_size):
if cp_size <= 1:
s = attention_mask[i].sum().item()
output_new[i, attention_mask[i]] = output[0][
packed_seq_params.cu_seqlens_q_padded[i] : packed_seq_params.cu_seqlens_q_padded[i] + s
]
s = seq_lens_cpu[i]
start_idx = cu_padded_cpu[i]
output_new[i, attention_mask[i]] = output[0][start_idx : start_idx + s]
continue
s_len_padded_chunk = (
packed_seq_params.cu_seqlens_q_padded[i + 1] - packed_seq_params.cu_seqlens_q_padded[i]
) // cp_size
s_len_padded_chunk = (cu_padded_cpu[i + 1] - cu_padded_cpu[i]) // cp_size
half_seqlen = s_len_padded_chunk // 2
s_len = attention_mask[i].sum().item()
s_len = seq_lens_cpu[i]
s_len_padded = s_len_padded_chunk * cp_size
tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device)
for j in range(cp_size):
o = output_list[j][0]
# split to 2 chunks
packed_start_idx = packed_seq_params.cu_seqlens_q_padded[i] // cp_size
packed_start_idx = cu_padded_cpu[i] // cp_size
o0, o1 = (
o[packed_start_idx : packed_start_idx + half_seqlen],
o[packed_start_idx + half_seqlen : packed_start_idx + s_len_padded_chunk],
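The comment block added to this hunk states the motivation: the per-sequence lengths and offsets are copied to the CPU once with `.tolist()`, so the loops index with plain Python ints instead of paying a device-to-host synchronization for every `.item()` call. A minimal standalone sketch of that pattern, illustrative only and not code from this PR:

```python
import torch

def lengths_with_per_item_sync(seqlens: torch.Tensor) -> list[int]:
    # Each .item() blocks the host until the GPU value is ready, so a batch
    # of size B pays for B separate device-to-host synchronizations.
    return [int(seqlens[i].item()) for i in range(seqlens.numel())]

def lengths_with_single_copy(seqlens: torch.Tensor) -> list[int]:
    # One bulk device-to-host copy; everything downstream is pure Python.
    return seqlens.tolist()
```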
137 changes: 137 additions & 0 deletions verl/utils/memory_utils.py
@@ -0,0 +1,137 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import logging

import torch

from verl.utils.device import get_torch_device

logger = logging.getLogger(__name__)


def aggressive_empty_cache(force_sync: bool = True, max_retries: int = 3) -> None:
"""
More aggressive GPU memory cleanup function, tries to release PyTorch reserved but unallocated memory.

Args:
force_sync: Whether to force CUDA synchronization
max_retries: Maximum number of retries
"""
if not torch.cuda.is_available():
return

device = get_torch_device()

for attempt in range(max_retries):
# Record memory status before cleanup
before_reserved = device.memory_reserved(0)
before_allocated = device.memory_allocated(0)
Contributor review comment (severity: high):

Hardcoding device index 0 for memory statistics can lead to incorrect behavior in multi-GPU environments where the current device is not GPU 0. The memory functions should operate on the current device. The device.memory_reserved() and device.memory_allocated() functions default to the current device when no device index is provided. Please remove the hardcoded 0.

Suggested change
-        before_reserved = device.memory_reserved(0)
-        before_allocated = device.memory_allocated(0)
+        before_reserved = device.memory_reserved()
+        before_allocated = device.memory_allocated()


# Run garbage collection
gc.collect()

# Clear PyTorch cache
device.empty_cache()

# Force synchronization (optional)
if force_sync:
device.synchronize()

# Record memory status after cleanup
after_reserved = device.memory_reserved(0)
after_allocated = device.memory_allocated(0)
Contributor review comment (severity: high):

Similar to the previous point, the device index 0 is hardcoded here. This should be removed to ensure the function queries memory statistics for the current device.

Suggested change
-        after_reserved = device.memory_reserved(0)
-        after_allocated = device.memory_allocated(0)
+        after_reserved = device.memory_reserved()
+        after_allocated = device.memory_allocated()


# Calculate freed memory
reserved_freed = before_reserved - after_reserved
allocated_freed = before_allocated - after_allocated

logger.info(
f"Memory cleanup attempt {attempt + 1}: Freed {reserved_freed / 1024**3:.2f} GB reserved, "
f"{allocated_freed / 1024**3:.2f} GB allocated"
)

# Stop retrying if little memory was freed
if reserved_freed < 1024**3: # less than 1GB
break


def reset_memory_stats() -> None:
"""Reset GPU memory statistics"""
if torch.cuda.is_available():
device = get_torch_device()
device.reset_peak_memory_stats()
device.reset_accumulated_memory_stats()


def get_memory_info() -> dict:
"""Get detailed GPU memory information"""
if not torch.cuda.is_available():
return {}

device = get_torch_device()

return {
"total_memory_gb": device.get_device_properties(0).total_memory / 1024**3,
"reserved_memory_gb": device.memory_reserved(0) / 1024**3,
"allocated_memory_gb": device.memory_allocated(0) / 1024**3,
"cached_memory_gb": (device.memory_reserved(0) - device.memory_allocated(0)) / 1024**3,
"max_memory_allocated_gb": device.max_memory_allocated(0) / 1024**3,
"max_memory_reserved_gb": device.max_memory_reserved(0) / 1024**3,
Contributor review comment (severity: high):

This function hardcodes the device index 0 for all memory statistics. This will report incorrect information on multi-GPU systems where the current process is not on device 0.

Most torch.cuda memory functions default to the current device if no index is provided. For get_device_properties, you need to explicitly pass the current device index, which you can get via device.current_device().

    device = get_torch_device()
    device_id = device.current_device()

    return {
        "total_memory_gb": device.get_device_properties(device_id).total_memory / 1024**3,
        "reserved_memory_gb": device.memory_reserved() / 1024**3,
        "allocated_memory_gb": device.memory_allocated() / 1024**3,
        "cached_memory_gb": (device.memory_reserved() - device.memory_allocated()) / 1024**3,
        "max_memory_allocated_gb": device.max_memory_allocated() / 1024**3,
        "max_memory_reserved_gb": device.max_memory_reserved() / 1024**3,
    }

}


def log_memory_usage(stage: str = "current") -> None:
"""Log GPU memory usage"""
if not torch.cuda.is_available():
return

info = get_memory_info()
logger.info(
f"Memory usage [{stage}]: "
f"Total: {info['total_memory_gb']:.2f} GB, "
f"Allocated: {info['allocated_memory_gb']:.2f} GB, "
f"Reserved: {info['reserved_memory_gb']:.2f} GB, "
f"Cached: {info['cached_memory_gb']:.2f} GB"
)


def optimize_memory_for_inference() -> None:
"""Optimize GPU memory usage for inference"""
if not torch.cuda.is_available():
return

# Set a more aggressive memory allocation policy
torch.cuda.set_per_process_memory_fraction(0.95) # Use 95% of GPU memory

# Clear cache
aggressive_empty_cache(force_sync=True)

logger.info("Optimized GPU memory usage for inference")


def optimize_memory_for_training() -> None:
"""Optimize GPU memory usage for training"""
if not torch.cuda.is_available():
return

# Set a moderate memory allocation policy
torch.cuda.set_per_process_memory_fraction(0.9) # Use 90% of GPU memory

# Clear cache
aggressive_empty_cache(force_sync=False)

logger.info("Optimized GPU memory usage for training")
30 changes: 18 additions & 12 deletions verl/workers/megatron_workers.py
@@ -43,6 +43,7 @@
offload_megatron_model_to_cpu,
offload_megatron_optimizer,
)
from verl.utils.memory_utils import aggressive_empty_cache
from verl.utils.model import get_hf_model_path, load_mcore_dist_weights, load_megatron_gptmodel_weights
from verl.utils.profiler import (
DistProfiler,
@@ -176,6 +177,10 @@ def _build_model_optimizer(self, model_path, optim_config, override_model_config
)
self.generation_config = get_generation_config(self.local_path)

override_ddp_config = OmegaConf.to_container(
self.config.actor.megatron.get("override_ddp_config", OmegaConf.create()), resolve=True
)

def make_model(wrap_with_ddp=False):
if self.bridge is not None:
from verl.models.mcore.mbridge import freeze_moe_router
@@ -184,7 +189,9 @@ def make_model(wrap_with_ddp=False):
if override_model_config.get("moe_config", {}).get("freeze_moe_router", False):
post_model_creation_callbacks.append(freeze_moe_router)
return self.bridge.get_model(
post_model_creation_callbacks=post_model_creation_callbacks, wrap_with_ddp=wrap_with_ddp
post_model_creation_callbacks=post_model_creation_callbacks,
wrap_with_ddp=wrap_with_ddp,
optimizer_config=override_ddp_config,
)
else:

@@ -203,9 +210,6 @@ def megatron_actor_model_provider(pre_process, post_process):
parallel_model.to(get_device_name())
return parallel_model

override_ddp_config = OmegaConf.to_container(
self.config.actor.megatron.get("override_ddp_config", OmegaConf.create()), resolve=True
)
return get_model(
megatron_actor_model_provider,
wrap_with_ddp=wrap_with_ddp,
@@ -536,7 +540,7 @@ def update_actor(self, data: DataProto):
offload_megatron_optimizer(self.actor_optimizer)
log_gpu_memory_usage("After offload actor optimizer during update_actor", logger=logger)

get_torch_device().empty_cache()
aggressive_empty_cache(force_sync=True)
return output

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
@@ -573,7 +577,7 @@ def generate_sequences(self, prompts: DataProto):
output.meta_info["timing"] = timing_generate
output = output.to("cpu")
# clear kv cache
get_torch_device().empty_cache()
aggressive_empty_cache(force_sync=True)
return output

@register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
@@ -596,7 +600,7 @@ def compute_ref_log_prob(self, data: DataProto):
if self._ref_is_offload_param:
offload_megatron_model_to_cpu(self.ref_module)
log_gpu_memory_usage("After offload ref params and grad during compute_ref_log_prob", logger=logger)
get_torch_device().empty_cache()
aggressive_empty_cache(force_sync=True)
return output

@register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
@@ -623,7 +627,7 @@ def compute_log_prob(self, data: DataProto):
if self._is_offload_param:
offload_megatron_model_to_cpu(self.actor_module)
log_gpu_memory_usage("After offload actor params and grad during compute_log_prob", logger=logger)
get_torch_device().empty_cache()
aggressive_empty_cache(force_sync=True)
return output

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
@@ -784,14 +788,19 @@ def _build_critic_model_optimizer(
self.config.megatron.use_mbridge,
)

override_ddp_config = OmegaConf.to_container(
self.config.megatron.get("override_ddp_config", OmegaConf.create()), resolve=True
)
if self.bridge is not None:
from verl.models.mcore.mbridge import freeze_moe_router, make_value_model

post_model_creation_callbacks = [make_value_model]
if override_model_config.get("moe_config", {}).get("freeze_moe_router", False):
post_model_creation_callbacks.append(freeze_moe_router)
critic_module = self.bridge.get_model(
post_model_creation_callbacks=post_model_creation_callbacks, wrap_with_ddp=True
post_model_creation_callbacks=post_model_creation_callbacks,
wrap_with_ddp=True,
optimizer_config=override_ddp_config,
)
else:

@@ -810,9 +819,6 @@ def megatron_critic_model_provider(pre_process, post_process):
parallel_model.to(get_device_name())
return parallel_model

override_ddp_config = OmegaConf.to_container(
self.config.megatron.get("override_ddp_config", OmegaConf.create()), resolve=True
)
# Step 3: initialize the megatron model
critic_module = get_model(
model_provider_func=megatron_critic_model_provider,
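The `override_ddp_config` lookup now happens before `make_model` is defined, so the resolved dict can also be forwarded to `self.bridge.get_model(..., optimizer_config=override_ddp_config)` on the mbridge path. A small sketch of the resolution step in isolation; the `overlap_grad_reduce` value is a hypothetical override, not taken from this PR:

```python
from omegaconf import OmegaConf

# Hypothetical actor config fragment carrying a DDP override.
cfg = OmegaConf.create(
    {"actor": {"megatron": {"override_ddp_config": {"overlap_grad_reduce": True}}}}
)

# Mirrors the diff: resolve once into a plain dict, then reuse it for both the
# mbridge path (optimizer_config=...) and the plain get_model(...) path.
override_ddp_config = OmegaConf.to_container(
    cfg.actor.megatron.get("override_ddp_config", OmegaConf.create()), resolve=True
)
print(override_ddp_config)  # {'overlap_grad_reduce': True}
```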
9 changes: 5 additions & 4 deletions verl/workers/sharding_manager/megatron_vllm.py
@@ -32,6 +32,7 @@
from verl.third_party.vllm import parallel_state as vllm_ps
from verl.utils.device import get_torch_device
from verl.utils.megatron_utils import load_megatron_model_to_gpu, offload_megatron_model_to_cpu, per_tensor_generator
from verl.utils.memory_utils import aggressive_empty_cache
from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage
from verl.utils.profiler.performance import simple_timer
from verl.utils.torch_functional import check_device_is_available
@@ -143,11 +144,11 @@ def __init__(
def __enter__(self):
self.timing = {}
with simple_timer("reshard", self.timing):
get_torch_device().empty_cache()
aggressive_empty_cache(force_sync=True)

log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger)
if self.offload_param:
load_megatron_model_to_gpu(self.actor_module)
load_megatron_model_to_gpu(self.actor_module, load_grad=False)

if self.rollout_config.free_cache_engine:
if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
@@ -172,7 +173,7 @@ def __enter__(self):

if self.offload_param:
offload_megatron_model_to_cpu(self.actor_module)
get_torch_device().empty_cache()
aggressive_empty_cache(force_sync=True)

if (
self.rollout_config.free_cache_engine
@@ -192,7 +193,7 @@ def __exit__(self, exc_type, exc_value, traceback):
for model in self.actor_module:
model.train()

get_torch_device().empty_cache()
aggressive_empty_cache(force_sync=True)

# restore random states
if self.device_mesh is not None: